Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions ask_user_render.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import ast
import json


DEFAULT_ASK_USER_QUESTION = "请提供下一步信息:"
DEFAULT_ASK_USER_INTRO = "🙋 需要你来决定下一步"
DEFAULT_ASK_USER_FOOTER = "请直接回复你的选择,或补充新的说明。"


def _truncate(text, max_len):
text = str(text or "").strip()
if max_len <= 0 or len(text) <= max_len:
return text
return text[: max_len - 3].rstrip() + "..."


def _normalize_candidates(raw_candidates):
if not isinstance(raw_candidates, (list, tuple)):
return []
candidates = []
for candidate in raw_candidates:
if candidate is None:
continue
text = str(candidate).strip()
if text:
candidates.append(text)
return candidates


def coerce_ask_user_data(data):
if not isinstance(data, dict):
return None
question = str(data.get("question") or DEFAULT_ASK_USER_QUESTION).strip() or DEFAULT_ASK_USER_QUESTION
return {"question": question, "candidates": _normalize_candidates(data.get("candidates") or [])}


def extract_ask_user_event(exit_reason):
payload = exit_reason
if isinstance(exit_reason, dict) and "result" in exit_reason and "data" in exit_reason:
if exit_reason.get("result") != "EXITED":
return None
payload = exit_reason.get("data")
if not isinstance(payload, dict):
return None
if payload.get("status") != "INTERRUPT" or payload.get("intent") != "HUMAN_INTERVENTION":
return None
return coerce_ask_user_data(payload.get("data"))


def extract_ask_user_event_from_text(raw_text):
text = str(raw_text or "").strip()
if not text:
return None
for parser in (json.loads, ast.literal_eval):
try:
parsed = parser(text)
except Exception:
continue
event = extract_ask_user_event(parsed)
if event:
return event
return None


def summarize_ask_user_event(event, max_len=120):
if not event:
return ""
question = str(event.get("question") or DEFAULT_ASK_USER_QUESTION).strip() or DEFAULT_ASK_USER_QUESTION
candidates = _normalize_candidates(event.get("candidates") or [])
summary = f"等待用户回复:{question}"
if candidates:
preview = " / ".join(candidates[:3])
if len(candidates) > 3:
preview += " / ..."
summary += f"(选项:{preview})"
return _truncate(summary, max_len)


def format_ask_user_message(event, intro=DEFAULT_ASK_USER_INTRO, footer=DEFAULT_ASK_USER_FOOTER):
normalized = coerce_ask_user_data(event)
if not normalized:
normalized = {"question": DEFAULT_ASK_USER_QUESTION, "candidates": []}
lines = [intro, "", normalized["question"]]
if normalized["candidates"]:
lines.extend(["", "可选项:"])
for idx, candidate in enumerate(normalized["candidates"], start=1):
lines.append(f"{idx}. {candidate}")
if footer:
lines.extend(["", footer])
return "\n".join(lines).strip()


def summarize_tool_args(name, args, max_len=120):
clean_args = {k: v for k, v in (args or {}).items() if not str(k).startswith("_")}
if name == "ask_user":
event = coerce_ask_user_data(clean_args)
return summarize_ask_user_event(event, max_len=max_len)
try:
rendered = json.dumps(clean_args, ensure_ascii=False)
except TypeError:
rendered = str(clean_args)
return _truncate(rendered, max_len)
23 changes: 22 additions & 1 deletion frontends/chatapp_common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import ast, asyncio, glob, json, os, queue as Q, re, socket, sys, time

from ask_user_render import extract_ask_user_event, extract_ask_user_event_from_text, format_ask_user_message

HELP_COMMANDS = (
("/help", "显示帮助"),
("/status", "查看状态"),
Expand Down Expand Up @@ -193,6 +195,9 @@ def format_restore():


def build_done_text(raw_text):
ask_user_event = extract_ask_user_event_from_text(raw_text)
if ask_user_event:
return format_ask_user_message(ask_user_event)
files = [p for p in extract_files(raw_text) if os.path.exists(p)]
body = strip_files(clean_reply(raw_text))
if files:
Expand Down Expand Up @@ -305,6 +310,17 @@ async def handle_command(self, chat_id, cmd, **ctx):
async def run_agent(self, chat_id, text, **ctx):
state = {"running": True}
self.user_tasks[chat_id] = state
ask_user_state = {}
hook_key = f"{self.source}_ask_user_{chat_id}_{time.time_ns()}"
if not hasattr(self.agent, "_turn_end_hooks"):
self.agent._turn_end_hooks = {}

def _capture_ask_user(ctx_data):
event = extract_ask_user_event((ctx_data or {}).get("exit_reason"))
if event:
ask_user_state["event"] = event

self.agent._turn_end_hooks[hook_key] = _capture_ask_user
try:
await self.send_text(chat_id, "思考中...", **ctx)
dq = self.agent.put_task(f"{FILE_HINT}\n\n{text}", source=self.source)
Expand All @@ -318,7 +334,11 @@ async def run_agent(self, chat_id, text, **ctx):
last_ping = time.time()
continue
if "done" in item:
await self.send_done(chat_id, item.get("done", ""), **ctx)
ask_user_event = ask_user_state.get("event")
if ask_user_event:
await self.send_text(chat_id, format_ask_user_message(ask_user_event), **ctx)
else:
await self.send_done(chat_id, item.get("done", ""), **ctx)
break
if not state["running"]:
await self.send_text(chat_id, "⏹️ 已停止", **ctx)
Expand All @@ -328,6 +348,7 @@ async def run_agent(self, chat_id, text, **ctx):
traceback.print_exc()
await self.send_text(chat_id, f"❌ 错误: {e}", **ctx)
finally:
self.agent._turn_end_hooks.pop(hook_key, None)
self.user_tasks.pop(chat_id, None)


Expand Down
24 changes: 23 additions & 1 deletion frontends/fsapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
sys.path.insert(0, PROJECT_ROOT)
os.chdir(PROJECT_ROOT)
from agentmain import GeneraticAgent
from ask_user_render import extract_ask_user_event, format_ask_user_message, summarize_tool_args
from frontends.chatapp_common import format_restore
from frontends.continue_cmd import handle_frontend_command as handle_continue_frontend, reset_conversation
from llmcore import mykeys
Expand Down Expand Up @@ -452,7 +453,12 @@ def _build_user_message(message):
def _fmt_tool_call(tc):
name = tc.get('tool_name', '?')
args = {k: v for k, v in (tc.get('args') or {}).items() if not k.startswith('_')}
return f"- `{name}`({json.dumps(args, ensure_ascii=False)[:200]})"
preview = summarize_tool_args(name, args, max_len=200)
if not preview:
return f"- `{name}`"
if name == 'ask_user':
return f"- `{name}` · {preview}"
return f"- `{name}`({preview})"


def _build_step_detail(resp, tool_calls):
Expand Down Expand Up @@ -551,6 +557,17 @@ def done(self, text):
self.final = (text or "_(无文本输出)_")[:self._FINAL_LIMIT]
self._push()

def ask_user(self, text):
self.status = "🙋 等待用户回复"
self.final = (text or "_(等待用户输入)_")[:self._FINAL_LIMIT]
ok, limit = self._push()
if limit:
self._rollover()
self.steps = []
self.turn_base = self.turn_no + 1
self.final = (text or "_(等待用户输入)_")[:self._FINAL_LIMIT]
self._push()

def fail(self, msg):
self.status = f"❌ {msg}"
self._push()
Expand All @@ -561,6 +578,11 @@ def _make_task_hook(card, done_event, on_final):
def hook(ctx):
try:
if ctx.get('exit_reason'):
ask_user_event = extract_ask_user_event(ctx.get('exit_reason'))
if ask_user_event:
card.ask_user(format_ask_user_message(ask_user_event))
done_event.set()
return
resp = ctx.get('response')
raw = resp.content if hasattr(resp, 'content') else str(resp)
card.done(_display_text(raw))
Expand Down
14 changes: 12 additions & 2 deletions frontends/wecomapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class TurnContext(TypedDict, total=False):

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from agentmain import GeneraticAgent
from ask_user_render import extract_ask_user_event, format_ask_user_message, summarize_tool_args
from chatapp_common import (AgentChatMixin, FILE_HINT, build_done_text, clean_reply,
ensure_single_instance, extract_files, public_access,
redirect_log, require_runtime, split_text, strip_files)
Expand Down Expand Up @@ -52,7 +53,12 @@ def _tprint(*a, **kw):
def _fmt_tool(tc):
name = tc.get("tool_name", "?")
args = {k: v for k, v in (tc.get("args") or {}).items() if not k.startswith("_")}
return f"{name}({str(args)[:120]})"
preview = summarize_tool_args(name, args, max_len=120)
if not preview:
return name
if name == "ask_user":
return f"{name}: {preview}"
return f"{name}({preview})"

# ── WeComApp ────────────────────────────────────────────────────────
class WeComApp(AgentChatMixin):
Expand Down Expand Up @@ -162,6 +168,7 @@ def _on_turn(ctx):
"""Turn-end callback injected into agent. ctx = locals() from ga.py."""
try:
if ctx.get("exit_reason"):
result["ask_user"] = extract_ask_user_event(ctx.get("exit_reason"))
resp = ctx.get("response")
result["raw"] = resp.content if hasattr(resp, "content") else str(resp)
result["summary"] = ctx.get("summary")
Expand Down Expand Up @@ -198,7 +205,10 @@ def _on_turn(ctx):

if result.get("raw") is not None:
self._stats["completed"] += 1
await self.send_done(chat_id, result["raw"])
if result.get("ask_user"):
await self.send_text(chat_id, format_ask_user_message(result["ask_user"]))
else:
await self.send_done(chat_id, result["raw"])
label = result.get("summary") or f'{len(result["raw"])} 字'
_tprint(f"[{_ts()}] ✅ Done ({chat_id}) — {label}")
elif not state["running"]:
Expand Down
9 changes: 8 additions & 1 deletion ga.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,14 @@ def turn_end_callback(self, response, tool_calls, tool_results, turn, next_promp
else:
tc = tool_calls[0]; tool_name, args = tc['tool_name'], tc['args'] # at least one because no_tool
clean_args = {k: v for k, v in args.items() if not k.startswith('_')}
summary = f"调用工具{tool_name}, args: {clean_args}"
if tool_name == 'ask_user':
question = str(clean_args.get('question') or '请提供输入:').strip() or '请提供输入:'
candidates = [str(c).strip() for c in (clean_args.get('candidates') or []) if str(c).strip()]
summary = f"等待用户回复:{question}"
if candidates:
summary += f"({len(candidates)} 个选项)"
else:
summary = f"调用工具{tool_name}, args: {clean_args}"
if tool_name == 'no_tool': summary = "直接回答了用户问题"
next_prompt += "\n\n\nUSER: <summary>呢???!\n\n"
summary = smart_format(summary.replace('\n', ''), max_str_len=80)
Expand Down
73 changes: 73 additions & 0 deletions tests/test_ask_user_render.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import os
import sys
import unittest

ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, ROOT)

from ask_user_render import (
extract_ask_user_event,
extract_ask_user_event_from_text,
format_ask_user_message,
summarize_tool_args,
)


class AskUserRenderTests(unittest.TestCase):
def setUp(self):
self.exit_reason = {
"result": "EXITED",
"data": {
"status": "INTERRUPT",
"intent": "HUMAN_INTERVENTION",
"data": {
"question": "下一步怎么做?",
"candidates": ["继续部署", "先看日志", "回滚"],
},
},
}

def test_extract_ask_user_event(self):
event = extract_ask_user_event(self.exit_reason)
self.assertEqual(
event,
{
"question": "下一步怎么做?",
"candidates": ["继续部署", "先看日志", "回滚"],
},
)

def test_extract_ask_user_event_from_text_python_repr(self):
text = str(self.exit_reason["data"])
event = extract_ask_user_event_from_text(text)
self.assertEqual(event["question"], "下一步怎么做?")
self.assertEqual(event["candidates"][1], "先看日志")

def test_format_ask_user_message(self):
message = format_ask_user_message(extract_ask_user_event(self.exit_reason))
self.assertIn("🙋 需要你来决定下一步", message)
self.assertIn("下一步怎么做?", message)
self.assertIn("1. 继续部署", message)
self.assertIn("3. 回滚", message)

def test_summarize_tool_args_for_ask_user(self):
summary = summarize_tool_args(
"ask_user",
{"question": "选择数据库", "candidates": ["MySQL", "Postgres"]},
max_len=200,
)
self.assertIn("等待用户回复", summary)
self.assertIn("选择数据库", summary)
self.assertIn("MySQL", summary)

def test_format_from_python_repr_payload(self):
raw_text = str(self.exit_reason["data"])
event = extract_ask_user_event_from_text(raw_text)
rendered = format_ask_user_message(event)
self.assertIn("🙋 需要你来决定下一步", rendered)
self.assertIn("下一步怎么做?", rendered)
self.assertNotIn("'status': 'INTERRUPT'", rendered)


if __name__ == "__main__":
unittest.main()