Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 81 additions & 11 deletions openkb/agent/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,41 @@ def _make_prompt_session(session: ChatSession, style: Style, use_color: bool) ->
)


async def _run_turn(agent: Any, session: ChatSession, user_input: str, style: Style) -> None:
def _make_rich_console() -> Any:
"""Create a Rich Console with a Claude-Code-like Markdown theme."""
from rich.console import Console
from rich.theme import Theme

theme = Theme({
# Headings: bold with blue tint
"markdown.h1": "bold #5fa0e0",
"markdown.h2": "bold #5fa0e0",
"markdown.h3": "bold #7ab0e8",
"markdown.h4": "bold #8abae0",
# Code
"markdown.code": "#e8c87a on #1e1e1e",
# Links
"markdown.link": "underline #5fa0e0",
"markdown.link_url": "#5fa0e0",
# Emphasis
"markdown.bold": "bold #e0e0e0",
"markdown.italic": "italic #c0c0c0",
# Lists and block quotes
"markdown.item.bullet": "#6ac0a0",
"markdown.item.number": "#6ac0a0",
"markdown.block_quote": "italic #8a8a8a",
# Horizontal rule
"markdown.hr": "#4a4a4a",
# Paragraphs — ensure normal text is visible
"markdown.paragraph": "#d0d0d0",
})
return Console(theme=theme)


async def _run_turn(
agent: Any, session: ChatSession, user_input: str, style: Style,
*, use_color: bool = True,
) -> None:
"""Run one agent turn with streaming output and persist the new history."""
from agents import (
RawResponsesStreamEvent,
Expand All @@ -202,39 +236,75 @@ async def _run_turn(agent: Any, session: ChatSession, user_input: str, style: St

result = Runner.run_streamed(agent, new_input, max_turns=MAX_TURNS)

sys.stdout.write("\n")
sys.stdout.flush()
print()
collected: list[str] = []
last_was_text = False
need_blank_before_text = False

if use_color:
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown

console = _make_rich_console()
else:
console = None # type: ignore[assignment]

def _start_live() -> Any:
if console is None:
return None
lv = Live(console=console, vertical_overflow="visible")
lv.start()
return lv

live = _start_live()

try:
async for event in result.stream_events():
if isinstance(event, RawResponsesStreamEvent):
if isinstance(event.data, ResponseTextDeltaEvent):
text = event.data.delta
if text:
if need_blank_before_text:
sys.stdout.write("\n")
if live:
live.stop()
live = None
print()
live = _start_live()
else:
sys.stdout.write("\n")
need_blank_before_text = False
sys.stdout.write(text)
sys.stdout.flush()
collected.append(text)
last_was_text = True
if live:
live.update(Markdown("".join(collected), code_theme="monokai"))
else:
sys.stdout.write(text)
sys.stdout.flush()
elif isinstance(event, RunItemStreamEvent):
item = event.item
if item.type == "tool_call_item":
if last_was_text:
sys.stdout.write("\n")
sys.stdout.flush()
if live:
live.stop()
live = None
else:
sys.stdout.write("\n")
sys.stdout.flush()
last_was_text = False
raw = item.raw_item
name = getattr(raw, "name", "?")
args = getattr(raw, "arguments", "") or ""
if live:
live.stop()
live = None
_fmt(style, ("class:tool", _format_tool_line(name, args) + "\n"))
live = _start_live()
need_blank_before_text = True
finally:
sys.stdout.write("\n\n")
sys.stdout.flush()
if live:
live.stop()
print()

answer = "".join(collected).strip()
if not answer:
Expand Down Expand Up @@ -371,7 +441,7 @@ async def run_chat(

append_log(kb_dir / "wiki", "query", user_input)
try:
await _run_turn(agent, session, user_input, style)
await _run_turn(agent, session, user_input, style, use_color=use_color)
except KeyboardInterrupt:
_fmt(style, ("class:error", "\n[aborted]\n"))
except Exception as exc:
Expand Down
68 changes: 49 additions & 19 deletions openkb/agent/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,25 +120,55 @@ async def run_query(question: str, kb_dir: Path, model: str, stream: bool = Fals
result = await Runner.run(agent, question, max_turns=MAX_TURNS)
return result.final_output or ""

import os
use_color = sys.stdout.isatty() and not os.environ.get("NO_COLOR", "")

if use_color:
from rich.live import Live
from rich.markdown import Markdown
from openkb.agent.chat import _make_rich_console
console = _make_rich_console()
else:
console = None # type: ignore[assignment]

def _start_live() -> Live | None:
if console is None:
return None
lv = Live(console=console, vertical_overflow="visible")
lv.start()
return lv

live = _start_live()

result = Runner.run_streamed(agent, question, max_turns=MAX_TURNS)
collected = []
async for event in result.stream_events():
if isinstance(event, RawResponsesStreamEvent):
if isinstance(event.data, ResponseTextDeltaEvent):
text = event.data.delta
if text:
sys.stdout.write(text)
collected: list[str] = []
try:
async for event in result.stream_events():
if isinstance(event, RawResponsesStreamEvent):
if isinstance(event.data, ResponseTextDeltaEvent):
text = event.data.delta
if text:
collected.append(text)
if live:
live.update(Markdown("".join(collected), code_theme="monokai"))
else:
sys.stdout.write(text)
sys.stdout.flush()
elif isinstance(event, RunItemStreamEvent):
item = event.item
if item.type == "tool_call_item":
raw = item.raw_item
args = getattr(raw, "arguments", "{}")
if live:
live.stop()
live = None
sys.stdout.write(f"\n[tool call] {raw.name}({args})\n\n")
sys.stdout.flush()
collected.append(text)
elif isinstance(event, RunItemStreamEvent):
item = event.item
if item.type == "tool_call_item":
raw = item.raw_item
args = getattr(raw, "arguments", "{}")
sys.stdout.write(f"\n[tool call] {raw.name}({args})\n\n")
sys.stdout.flush()
elif item.type == "tool_call_output_item":
pass
sys.stdout.write("\n")
sys.stdout.flush()
live = _start_live()
elif item.type == "tool_call_output_item":
pass
finally:
if live:
live.stop()
print()
return "".join(collected) if collected else result.final_output or ""
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ dependencies = [
"python-dotenv",
"json-repair",
"prompt_toolkit>=3.0",
"rich>=13.0",
]

[project.urls]
Expand Down