Skip to content

Commit 0c16595

Browse files
declan-scalestainless-app[bot]
authored andcommitted
feat(adk): Revamp run_claude_agent_activity to use more streaming (#309)
1 parent 464cd2d commit 0c16595

4 files changed

Lines changed: 1073 additions & 232 deletions

File tree

src/agentex/lib/core/temporal/plugins/claude_agents/activities.py

Lines changed: 235 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
"""Temporal activities for Claude Agents SDK integration."""
1+
"""Temporal activities for Claude Agents SDK integration.
2+
3+
Processes all content blocks from the AssistantMessage stream in iteration order
4+
(TextBlock, ThinkingBlock, ToolUseBlock) with correct timestamps. Tool results
5+
come from PostToolUse/PostToolUseFailure hooks which fire between message yields.
6+
"""
27

38
from __future__ import annotations
49

@@ -8,10 +13,25 @@
813

914
from temporalio import activity
1015
from claude_agent_sdk import AgentDefinition, ClaudeSDKClient, ClaudeAgentOptions
16+
from claude_agent_sdk.types import (
17+
HookEvent,
18+
TextBlock,
19+
HookMatcher,
20+
ToolUseBlock,
21+
ResultMessage,
22+
SystemMessage,
23+
ThinkingBlock,
24+
AssistantMessage,
25+
)
1126

27+
from agentex.lib import adk
28+
from agentex.types.text_delta import TextDelta
1229
from agentex.lib.utils.logging import make_logger
13-
from agentex.lib.core.temporal.plugins.claude_agents.hooks import create_streaming_hooks
14-
from agentex.lib.core.temporal.plugins.claude_agents.message_handler import ClaudeMessageHandler
30+
from agentex.types.text_content import TextContent
31+
from agentex.types.reasoning_content import ReasoningContent
32+
from agentex.types.task_message_update import StreamTaskMessageFull, StreamTaskMessageDelta
33+
from agentex.types.tool_request_content import ToolRequestContent
34+
from agentex.lib.core.temporal.plugins.claude_agents.hooks.hooks import create_streaming_hooks
1535
from agentex.lib.core.temporal.plugins.openai_agents.interceptors.context_interceptor import (
1636
streaming_task_id,
1737
streaming_trace_id,
@@ -72,10 +92,10 @@ def _reconstruct_agent_defs(agents: dict[str, Any] | None) -> dict[str, AgentDef
7292
agent_defs[name] = agent_data
7393
else:
7494
agent_defs[name] = AgentDefinition(
75-
description=agent_data.get('description', ''),
76-
prompt=agent_data.get('prompt', ''),
77-
tools=agent_data.get('tools'),
78-
model=agent_data.get('model'),
95+
description=agent_data.get("description", ""),
96+
prompt=agent_data.get("prompt", ""),
97+
tools=agent_data.get("tools"),
98+
model=agent_data.get("model"),
7999
)
80100
return agent_defs
81101

@@ -115,12 +135,11 @@ async def run_claude_agent_activity(
115135
) -> dict[str, Any]:
116136
"""Execute Claude SDK - wrapped in Temporal activity.
117137
118-
This activity:
119-
1. Gets task_id from ContextVar (set by ContextInterceptor)
120-
2. Configures Claude with workspace isolation and session resume
121-
3. Runs Claude SDK and processes messages via ClaudeMessageHandler
122-
4. Streams messages to UI in real-time
123-
5. Returns session_id, usage, and cost for next turn
138+
Streams all content block types to the Agentex UI:
139+
- TextBlock → streamed as text deltas (from message stream)
140+
- ThinkingBlock → streamed as ReasoningContent (from message stream)
141+
- ToolUseBlock → streamed as tool_request (from message stream)
142+
- Tool results → streamed as tool_response (from PostToolUse hook)
124143
125144
Args:
126145
prompt: User message to send to Claude
@@ -157,15 +176,19 @@ async def run_claude_agent_activity(
157176
agent_defs = _reconstruct_agent_defs(agents)
158177

159178
# Only include explicit params that were actually supplied (non-None),
160-
# so claude_options values for system_prompt/resume/agents are not masked.
161-
explicit_params: dict[str, Any] = {k: v for k, v in {
162-
"cwd": workspace_path,
163-
"allowed_tools": allowed_tools,
164-
"permission_mode": permission_mode,
165-
"system_prompt": system_prompt,
166-
"resume": resume_session_id,
167-
"agents": agent_defs,
168-
}.items() if v is not None}
179+
# so claude_options values are not masked.
180+
explicit_params: dict[str, Any] = {
181+
k: v
182+
for k, v in {
183+
"cwd": workspace_path,
184+
"allowed_tools": allowed_tools,
185+
"permission_mode": permission_mode,
186+
"system_prompt": system_prompt,
187+
"resume": resume_session_id,
188+
"agents": agent_defs,
189+
}.items()
190+
if v is not None
191+
}
169192

170193
# Merge in any additional claude_options (explicit params take precedence)
171194
if claude_options:
@@ -176,61 +199,219 @@ async def run_claude_agent_activity(
176199
else:
177200
options_dict = explicit_params
178201

179-
# Apply default for permission_mode if neither source supplied a value
180202
if "permission_mode" not in options_dict:
181203
options_dict["permission_mode"] = "acceptEdits"
182204

183-
# Create hooks for streaming tool calls and subagent execution
184-
streaming_hooks = create_streaming_hooks(
205+
# Shared subagent span tracking — hooks and message-level streaming both use this
206+
subagent_spans: dict[str, Any] = {}
207+
208+
# PreToolUse: auto-allow permissions
209+
# PostToolUse/PostToolUseFailure: stream tool results (richer than ToolResultBlock)
210+
# Subagent spans tracked for Task tool tracing
211+
activity_hooks: dict[HookEvent, list[HookMatcher]] = create_streaming_hooks(
185212
task_id=task_id,
186213
trace_id=trace_id,
187214
parent_span_id=parent_span_id,
215+
subagent_spans=subagent_spans,
188216
)
189217

190-
# Merge streaming hooks with any user-provided hooks from claude_options
218+
# Merge with any user-provided hooks from claude_options
191219
user_hooks = options_dict.pop("hooks", None)
192220
if user_hooks:
193-
merged_hooks = dict(streaming_hooks)
194221
for event, matchers in user_hooks.items():
195-
if event in merged_hooks:
196-
merged_hooks[event] = merged_hooks[event] + matchers
222+
if event in activity_hooks:
223+
activity_hooks[event] = activity_hooks[event] + matchers # type: ignore[operator]
197224
else:
198-
merged_hooks[event] = matchers
199-
options_dict["hooks"] = merged_hooks
200-
else:
201-
options_dict["hooks"] = streaming_hooks
225+
activity_hooks[event] = matchers # type: ignore[assignment]
202226

203-
# Construct ClaudeAgentOptions — any SDK field works via claude_options
227+
options_dict["hooks"] = activity_hooks
204228
options = ClaudeAgentOptions(**options_dict)
205229

206-
# Create message handler for streaming
207-
handler = ClaudeMessageHandler(
208-
task_id=task_id,
209-
trace_id=trace_id,
210-
parent_span_id=parent_span_id,
211-
)
230+
text_streaming_cm: Any = None # the context manager itself
231+
text_streaming_ctx: Any = None # the value returned by __aenter__
232+
session_id: str | None = None
233+
usage_info: dict[str, Any] | None = None
234+
cost_info: float | None = None
235+
serialized_messages: list[dict[str, Any]] = []
236+
237+
async def close_text_stream() -> None:
238+
nonlocal text_streaming_cm, text_streaming_ctx
239+
if text_streaming_ctx and text_streaming_cm:
240+
try:
241+
await text_streaming_cm.__aexit__(None, None, None)
242+
except Exception as e:
243+
logger.warning(f"Failed to close text stream: {e}")
244+
text_streaming_cm = None
245+
text_streaming_ctx = None
246+
247+
async def ensure_text_stream() -> Any:
248+
nonlocal text_streaming_cm, text_streaming_ctx
249+
if text_streaming_ctx is None and task_id:
250+
text_streaming_cm = adk.streaming.streaming_task_message_context(
251+
task_id=task_id,
252+
initial_content=TextContent(author="agent", content="", format="markdown"),
253+
)
254+
text_streaming_ctx = await text_streaming_cm.__aenter__()
255+
return text_streaming_ctx
256+
257+
async def stream_text_delta(text: str) -> None:
258+
if not text:
259+
return
260+
ctx = await ensure_text_stream()
261+
if not ctx:
262+
return
263+
try:
264+
await ctx.stream_update(
265+
StreamTaskMessageDelta(
266+
parent_task_message=ctx.task_message,
267+
delta=TextDelta(type="text", text_delta=text),
268+
type="delta",
269+
)
270+
)
271+
except Exception as e:
272+
logger.warning(f"Failed to stream text delta: {e}")
273+
274+
async def stream_tool_request(block: ToolUseBlock) -> None:
275+
await close_text_stream()
276+
277+
# Subagent tracing
278+
if block.name == "Task" and trace_id and parent_span_id:
279+
subagent_type = block.input.get("subagent_type", "unknown")
280+
logger.info(f"Subagent started: {subagent_type}")
281+
subagent_ctx = adk.tracing.span(
282+
trace_id=trace_id,
283+
parent_id=parent_span_id,
284+
name=f"Subagent: {subagent_type}",
285+
input=block.input,
286+
)
287+
subagent_span = await subagent_ctx.__aenter__()
288+
subagent_spans[block.id] = (subagent_ctx, subagent_span)
289+
290+
if not task_id:
291+
return
292+
try:
293+
async with adk.streaming.streaming_task_message_context(
294+
task_id=task_id,
295+
initial_content=ToolRequestContent(
296+
author="agent",
297+
name=block.name,
298+
arguments=block.input,
299+
tool_call_id=block.id,
300+
),
301+
) as ctx:
302+
await ctx.stream_update(
303+
StreamTaskMessageFull(
304+
parent_task_message=ctx.task_message,
305+
content=ToolRequestContent(
306+
author="agent",
307+
name=block.name,
308+
arguments=block.input,
309+
tool_call_id=block.id,
310+
),
311+
type="full",
312+
)
313+
)
314+
except Exception as e:
315+
logger.warning(f"Failed to stream tool request: {e}")
316+
317+
async def stream_reasoning(block: ThinkingBlock) -> None:
318+
if not task_id or not block.thinking:
319+
return
320+
lines = block.thinking.strip().split("\n", 1)
321+
summary = [lines[0]]
322+
content = ReasoningContent(
323+
author="agent",
324+
summary=summary,
325+
content=[block.thinking],
326+
style="static",
327+
type="reasoning",
328+
)
329+
try:
330+
async with adk.streaming.streaming_task_message_context(
331+
task_id=task_id,
332+
initial_content=content,
333+
) as ctx:
334+
await ctx.stream_update(
335+
StreamTaskMessageFull(
336+
parent_task_message=ctx.task_message,
337+
content=content,
338+
type="full",
339+
)
340+
)
341+
except Exception as e:
342+
logger.warning(f"Failed to stream reasoning: {e}")
343+
344+
async def handle_assistant_message(message: AssistantMessage) -> None:
345+
text_parts: list[str] = []
346+
for block in message.content:
347+
if isinstance(block, TextBlock):
348+
await stream_text_delta(block.text)
349+
if block.text:
350+
text_parts.append(block.text)
351+
352+
elif isinstance(block, ThinkingBlock):
353+
if block.thinking:
354+
await close_text_stream()
355+
await stream_reasoning(block)
356+
357+
elif isinstance(block, ToolUseBlock):
358+
await stream_tool_request(block)
359+
360+
# ToolResultBlock skipped — tool results come from PostToolUse hook
361+
362+
if text_parts:
363+
serialized_messages.append(
364+
{
365+
"role": "assistant",
366+
"content": "\n".join(text_parts),
367+
}
368+
)
212369

213-
# Run Claude and process messages
214-
try:
215-
await handler.initialize()
370+
async def handle_system_message(message: SystemMessage) -> None:
371+
nonlocal session_id
372+
if message.subtype == "init":
373+
session_id = message.data.get("session_id")
374+
logger.debug(f"Session initialized: {session_id[:16] if session_id else 'unknown'}...")
375+
376+
async def handle_result_message(message: ResultMessage) -> None:
377+
nonlocal session_id, usage_info, cost_info
378+
usage_info = message.usage
379+
cost_info = message.total_cost_usd
380+
if message.session_id:
381+
session_id = message.session_id
382+
logger.info(f"Cost: ${cost_info:.4f}, Duration: {message.duration_ms}ms, Turns: {message.num_turns}")
216383

384+
try:
217385
async with ClaudeSDKClient(options=options) as client:
218386
await client.query(prompt)
219-
220-
# Use receive_response() instead of receive_messages()
221-
# receive_response() yields messages until ResultMessage, then stops
222-
# receive_messages() is infinite and never completes!
223387
async for message in client.receive_response():
224-
await handler.handle_message(message)
225-
226-
logger.debug(f"Message loop completed, cleaning up...")
227-
await handler.cleanup()
228-
229-
results = handler.get_results()
388+
if isinstance(message, AssistantMessage):
389+
await handle_assistant_message(message)
390+
elif isinstance(message, SystemMessage):
391+
await handle_system_message(message)
392+
elif isinstance(message, ResultMessage):
393+
await handle_result_message(message)
394+
395+
logger.debug("Message loop completed, cleaning up...")
396+
await close_text_stream()
397+
398+
results = {
399+
"messages": serialized_messages,
400+
"task_id": task_id,
401+
"session_id": session_id,
402+
"usage": usage_info,
403+
"cost_usd": cost_info,
404+
}
230405
logger.debug(f"Returning results with keys: {results.keys()}")
231406
return results
232407

233408
except Exception as e:
234409
logger.error(f"[run_claude_agent_activity] Error: {e}", exc_info=True)
235-
await handler.cleanup()
410+
await close_text_stream()
411+
for _ctx, _span in list(subagent_spans.values()):
412+
try:
413+
await _ctx.__aexit__(None, None, None)
414+
except Exception:
415+
pass
416+
subagent_spans.clear()
236417
raise

0 commit comments

Comments
 (0)