Skip to content

Commit bee5939

Browse files
Merge pull request #324 from scaleapi/dm/add-shell-tool-support
Add ShellTool support to TemporalStreamingModel
2 parents ced40bb + 4c26908 commit bee5939

5 files changed

Lines changed: 109 additions & 18 deletions

File tree

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,23 +25,23 @@ dependencies = [
2525
"pyyaml>=6.0.2,<7",
2626
"jsonschema>=4.23.0,<5",
2727
"jsonref>=1.1.0,<2",
28-
"temporalio>=1.18.2,<2",
28+
"temporalio>=1.26.0,<2",
2929
"aiohttp>=3.10.10,<4",
3030
"redis>=5.2.0,<6",
3131
"litellm>=1.83.0,<2",
3232
"kubernetes>=25.0.0,<36.0.0",
3333
"jinja2>=3.1.3,<4",
3434
"mcp[cli]>=1.4.1",
3535
"scale-gp>=0.1.0a59",
36-
"openai-agents==0.4.2",
36+
"openai-agents==0.14.1",
3737
"tzlocal>=5.3.1",
3838
"tzdata>=2025.2",
3939
"pytest>=8.4.0",
4040
"json_log_formatter>=1.1.1",
4141
"pytest-asyncio>=1.0.0",
4242
"scale-gp-beta>=0.1.0a20",
4343
"ipykernel>=6.29.5",
44-
"openai>=2.2,<3", # Required by openai-agents 0.4.2; litellm now supports openai 2.x (issue #13711 resolved: https://github.com/BerriAI/litellm/issues/13711)
44+
"openai>=2.2,<3", # Required by openai-agents; litellm now supports openai 2.x (issue #13711 resolved: https://github.com/BerriAI/litellm/issues/13711)
4545
"cloudpickle>=3.1.1",
4646
"datadog>=0.52.1",
4747
"ddtrace>=3.13.0",

requirements-dev.lock

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@ click==8.3.1
6464
# via uvicorn
6565
cloudpickle==3.1.2
6666
# via agentex-sdk
67-
colorama==0.4.6
68-
# via griffe
6967
colorlog==6.10.1
7068
# via nox
7169
comm==0.2.3
@@ -114,7 +112,7 @@ fsspec==2026.3.0
114112
# via huggingface-hub
115113
google-auth==2.49.1
116114
# via kubernetes
117-
griffe==1.15.0
115+
griffelib==2.0.2
118116
# via openai-agents
119117
h11==0.16.0
120118
# via httpcore
@@ -194,7 +192,7 @@ langgraph-checkpoint==4.0.1
194192
# via agentex-sdk
195193
langsmith==0.7.22
196194
# via langchain-core
197-
litellm==1.82.6
195+
litellm==1.83.0
198196
# via agentex-sdk
199197
markdown-it-py==3.0.0
200198
# via rich
@@ -229,7 +227,7 @@ openai==2.30.0
229227
# via agentex-sdk
230228
# via litellm
231229
# via openai-agents
232-
openai-agents==0.4.2
230+
openai-agents==0.14.1
233231
# via agentex-sdk
234232
opentelemetry-api==1.40.0
235233
# via agentex-sdk
@@ -391,7 +389,7 @@ stack-data==0.6.3
391389
starlette==0.46.2
392390
# via fastapi
393391
# via mcp
394-
temporalio==1.24.0
392+
temporalio==1.26.0
395393
# via agentex-sdk
396394
tenacity==9.1.4
397395
# via langchain-core
@@ -478,6 +476,8 @@ wcwidth==0.6.0
478476
# via prompt-toolkit
479477
websocket-client==1.9.0
480478
# via kubernetes
479+
websockets==15.0.1
480+
# via openai-agents
481481
wrapt==2.1.2
482482
# via ddtrace
483483
xxhash==3.6.0

requirements.lock

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,6 @@ click==8.3.1
6161
# via uvicorn
6262
cloudpickle==3.1.2
6363
# via agentex-sdk
64-
colorama==0.4.6
65-
# via griffe
6664
comm==0.2.3
6765
# via ipykernel
6866
cryptography==46.0.6
@@ -101,7 +99,7 @@ fsspec==2026.3.0
10199
# via huggingface-hub
102100
google-auth==2.49.1
103101
# via kubernetes
104-
griffe==1.15.0
102+
griffelib==2.0.2
105103
# via openai-agents
106104
h11==0.16.0
107105
# via httpcore
@@ -178,7 +176,7 @@ langgraph-checkpoint==4.0.1
178176
# via agentex-sdk
179177
langsmith==0.7.22
180178
# via langchain-core
181-
litellm==1.82.6
179+
litellm==1.83.0
182180
# via agentex-sdk
183181
markdown-it-py==4.0.0
184182
# via rich
@@ -207,7 +205,7 @@ openai==2.30.0
207205
# via agentex-sdk
208206
# via litellm
209207
# via openai-agents
210-
openai-agents==0.4.2
208+
openai-agents==0.14.1
211209
# via agentex-sdk
212210
opentelemetry-api==1.40.0
213211
# via agentex-sdk
@@ -359,7 +357,7 @@ stack-data==0.6.3
359357
starlette==0.46.2
360358
# via fastapi
361359
# via mcp
362-
temporalio==1.24.0
360+
temporalio==1.26.0
363361
# via agentex-sdk
364362
tenacity==9.1.4
365363
# via langchain-core
@@ -441,6 +439,8 @@ wcwidth==0.6.0
441439
# via prompt-toolkit
442440
websocket-client==1.9.0
443441
# via kubernetes
442+
websockets==15.0.1
443+
# via openai-agents
444444
wrapt==2.1.2
445445
# via ddtrace
446446
xxhash==3.6.0

src/agentex/lib/core/temporal/plugins/openai_agents/models/temporal_streaming_model.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,12 @@
2727
CodeInterpreterTool,
2828
ImageGenerationTool,
2929
)
30+
from agents.computer import Computer, AsyncComputer
31+
32+
try:
33+
from agents.tool import ShellTool # type: ignore[attr-defined]
34+
except ImportError:
35+
ShellTool = None # type: ignore[assignment,misc]
3036
from agents.usage import Usage, InputTokensDetails, OutputTokensDetails # type: ignore[attr-defined]
3137
from agents.model_settings import MCPToolChoice
3238
from openai.types.responses import (
@@ -303,11 +309,28 @@ def _convert_tools(self, tools: list[Tool], handoffs: list[Handoff]) -> tuple[Li
303309
tool_includes.append("file_search_call.results")
304310

305311
elif isinstance(tool, ComputerTool):
312+
# In newer openai-agents, tool.computer may be a factory
313+
# (ComputerCreate/ComputerProvider). Only concrete Computer
314+
# / AsyncComputer instances expose environment/dimensions.
315+
computer = tool.computer
316+
if not isinstance(computer, (Computer, AsyncComputer)):
317+
raise ValueError(
318+
"ComputerTool.computer must be a Computer or AsyncComputer "
319+
"instance for Responses API serialization; got "
320+
f"{type(computer).__name__}"
321+
)
322+
environment = computer.environment
323+
dimensions = computer.dimensions
324+
if environment is None or dimensions is None:
325+
raise ValueError(
326+
"ComputerTool requires `environment` and `dimensions` on the "
327+
"Computer/AsyncComputer implementation."
328+
)
306329
response_tools.append({
307330
"type": "computer_use_preview",
308-
"environment": tool.computer.environment,
309-
"display_width": tool.computer.dimensions[0],
310-
"display_height": tool.computer.dimensions[1],
331+
"environment": environment,
332+
"display_width": dimensions[0],
333+
"display_height": dimensions[1],
311334
})
312335

313336
elif isinstance(tool, HostedMCPTool):
@@ -326,6 +349,13 @@ def _convert_tools(self, tools: list[Tool], handoffs: list[Handoff]) -> tuple[Li
326349
"type": "local_shell",
327350
})
328351

352+
elif ShellTool is not None and isinstance(tool, ShellTool):
353+
environment = dict(tool.environment) if tool.environment else {"type": "local"}
354+
response_tools.append({
355+
"type": "shell",
356+
"environment": environment,
357+
})
358+
329359
else:
330360
logger.warning(f"Unknown tool type: {type(tool).__name__}, skipping")
331361

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""Unit tests for TemporalStreamingModel._convert_tools tool serialization."""
2+
3+
from unittest.mock import MagicMock, patch
4+
5+
import pytest
6+
7+
from agentex.lib.core.temporal.plugins.openai_agents.models import (
8+
temporal_streaming_model as tsm_module,
9+
)
10+
from agentex.lib.core.temporal.plugins.openai_agents.models.temporal_streaming_model import (
11+
TemporalStreamingModel,
12+
)
13+
14+
15+
@pytest.fixture
def model():
    """Provide a TemporalStreamingModel with the Agentex client factory stubbed out.

    The async client factory is patched so constructing the model performs no I/O;
    the OpenAI client is a plain MagicMock since these tests never call it.
    """
    factory_path = "agentex.lib.core.temporal.plugins.openai_agents.models.temporal_streaming_model.create_async_agentex_client"
    with patch(factory_path):
        return TemporalStreamingModel(model_name="gpt-4o", openai_client=MagicMock())
21+
22+
23+
class _FakeShellTool:
24+
"""Stand-in for agents.tool.ShellTool for environments where it isn't installed."""
25+
26+
def __init__(self, environment):
27+
self.environment = environment
28+
29+
30+
def test_shell_tool_local_environment(model, monkeypatch):
    """ShellTool with a local environment should serialize to a 'shell' payload."""
    monkeypatch.setattr(tsm_module, "ShellTool", _FakeShellTool)

    env = {"type": "local", "skills": ["git"]}
    converted, _ = model._convert_tools([_FakeShellTool(environment=env)], handoffs=[])

    assert converted == [{"type": "shell", "environment": env}]
38+
39+
40+
def test_shell_tool_defaults_environment_when_missing(model, monkeypatch):
    """ShellTool with environment=None should fall back to {'type': 'local'}."""
    monkeypatch.setattr(tsm_module, "ShellTool", _FakeShellTool)

    converted, _ = model._convert_tools(
        [_FakeShellTool(environment=None)], handoffs=[]
    )

    assert converted == [{"type": "shell", "environment": {"type": "local"}}]
48+
49+
50+
def test_shell_tool_unavailable_falls_through(model, monkeypatch, caplog):
    """If ShellTool isn't installed, an unknown tool should log a warning and be skipped."""
    monkeypatch.setattr(tsm_module, "ShellTool", None)

    class _NotAShellTool:
        pass

    with caplog.at_level("WARNING"):
        converted, _ = model._convert_tools([_NotAShellTool()], handoffs=[])

    assert converted == []
    warning_messages = [record.message for record in caplog.records]
    assert any("Unknown tool type" in message for message in warning_messages)

0 commit comments

Comments
 (0)