Skip to content

Commit 822dbc3

Browse files
committed
Sanitize persisted chat image history
1 parent a227740 commit 822dbc3

2 files changed

Lines changed: 151 additions & 4 deletions

File tree

openkb/agent/chat_session.py

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""Chat session persistence for `openkb chat`.
22
3-
Each session lives in ``<kb>/.openkb/chats/<id>.json`` and stores the full
3+
Each session lives in ``<kb>/.openkb/chats/<id>.json`` and stores a sanitized
44
agent-SDK history (from ``RunResult.to_input_list()``) alongside the user
55
messages and full assistant replies kept as plain strings for display and
6-
export.
6+
export. Large tool-returned image payloads are replaced with lightweight
7+
references before the history is reused or persisted.
78
"""
89
from __future__ import annotations
910

@@ -17,6 +18,11 @@
1718
from typing import Any
1819

1920

21+
# Leading sentence of the text part that replaces an inline ``data:`` image in
# stored history; ``_image_history_placeholder`` appends the source path (when
# known) and a re-fetch hint to it.
_IMAGE_HISTORY_NOTE = (
    "Image output omitted from chat history to avoid persisting raw data URLs."
)
24+
25+
2026
def _utcnow_iso() -> str:
2127
return datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
2228

@@ -38,6 +44,71 @@ def _title_from(msg: str, limit: int = 60) -> str:
3844
return msg[: limit - 1] + "\u2026"
3945

4046

47+
def _image_history_placeholder(image_path: str | None) -> dict[str, str]:
    """Build the text part that stands in for an omitted inline image.

    Args:
        image_path: Path originally passed to ``get_image``; when falsy
            (``None`` or empty) the "Source path" sentence is left out.

    Returns:
        An ``input_text`` part whose text is the omission note, the optional
        source path, and a hint to call ``get_image`` again.
    """
    segments = [_IMAGE_HISTORY_NOTE]
    if image_path:
        segments.append(f" Source path: {image_path}.")
    segments.append(" Call get_image again if you need to inspect it.")
    return {"type": "input_text", "text": "".join(segments)}
53+
54+
55+
def _extract_get_image_path(item: dict[str, Any]) -> str | None:
56+
if item.get("type") != "function_call" or item.get("name") != "get_image":
57+
return None
58+
arguments = item.get("arguments")
59+
if not isinstance(arguments, str):
60+
return None
61+
try:
62+
payload = json.loads(arguments)
63+
except json.JSONDecodeError:
64+
return None
65+
image_path = payload.get("image_path")
66+
if isinstance(image_path, str) and image_path:
67+
return image_path
68+
return None
69+
70+
71+
def _sanitize_history_value(value: Any, image_path: str | None = None) -> Any:
72+
if isinstance(value, list):
73+
return [_sanitize_history_value(item, image_path) for item in value]
74+
if not isinstance(value, dict):
75+
return value
76+
77+
if value.get("type") == "input_image":
78+
image_url = value.get("image_url")
79+
if isinstance(image_url, str) and image_url.startswith("data:"):
80+
return _image_history_placeholder(image_path)
81+
82+
return {
83+
key: _sanitize_history_value(item, image_path)
84+
for key, item in value.items()
85+
}
86+
87+
88+
def sanitize_history(history: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Strip large image payloads from model history while keeping a re-fetch hint.

    The ``image_path`` of each ``get_image`` function call is remembered under
    its ``call_id``. When a later ``function_call_output`` item carries the
    same ``call_id``, that path is handed to the sanitizer so the placeholder
    can name the image that was omitted.

    Args:
        history: History items in order. Items that are not dicts are copied
            to the result unchanged.

    Returns:
        A new list with one sanitized entry per input item.
    """
    paths_by_call: dict[str, str] = {}
    cleaned: list[dict[str, Any]] = []

    for entry in history:
        if not isinstance(entry, dict):
            cleaned.append(entry)
            continue

        requested_path = _extract_get_image_path(entry)
        call_id = entry.get("call_id")
        has_call_id = isinstance(call_id, str)

        if requested_path and has_call_id:
            paths_by_call[call_id] = requested_path

        # Only tool outputs get a path hint; every other item is sanitized
        # without one.
        is_tool_output = entry.get("type") == "function_call_output"
        hint = paths_by_call.get(call_id) if is_tool_output and has_call_id else None

        cleaned.append(_sanitize_history_value(entry, hint))

    return cleaned
110+
111+
41112
@dataclass
42113
class ChatSession:
43114
id: str
@@ -99,7 +170,7 @@ def record_turn(
99170
assistant_text: str,
100171
new_history: list[dict[str, Any]],
101172
) -> None:
102-
self.history = new_history
173+
self.history = sanitize_history(new_history)
103174
self.user_turns.append(user_message)
104175
self.assistant_texts.append(assistant_text)
105176
self.turn_count = len(self.user_turns)
@@ -120,7 +191,7 @@ def load_session(kb_dir: Path, session_id: str) -> ChatSession:
120191
language=data.get("language", "en"),
121192
title=data.get("title", ""),
122193
turn_count=data.get("turn_count", 0),
123-
history=data.get("history", []),
194+
history=sanitize_history(data.get("history", [])),
124195
user_turns=data.get("user_turns", []),
125196
assistant_texts=data.get("assistant_texts", []),
126197
path=path,

tests/test_chat_session.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Tests for chat session persistence."""
2+
from __future__ import annotations
3+
4+
import json
5+
6+
from openkb.agent.chat_session import ChatSession, load_session
7+
8+
9+
def _image_history() -> list[dict[str, object]]:
10+
return [
11+
{"role": "user", "content": "Describe the diagram."},
12+
{
13+
"type": "function_call",
14+
"call_id": "call_123",
15+
"name": "get_image",
16+
"arguments": '{"image_path":"sources/images/doc/figure-1.png"}',
17+
},
18+
{
19+
"type": "function_call_output",
20+
"call_id": "call_123",
21+
"output": [
22+
{
23+
"type": "input_image",
24+
"image_url": "data:image/png;base64,AAAA",
25+
}
26+
],
27+
},
28+
]
29+
30+
31+
def test_record_turn_replaces_data_image_with_text_reference(tmp_path):
    """Recording a turn must leave no raw data URL in the session file."""
    session = ChatSession.new(tmp_path, "gpt-4o-mini", "en")
    history = _image_history()

    session.record_turn("Describe the diagram.", "It is a flow chart.", history)

    raw_text = session.path.read_text(encoding="utf-8")
    saved = json.loads(raw_text)
    output_part = saved["history"][2]["output"][0]
    placeholder_text = output_part["text"]

    assert output_part["type"] == "input_text"
    assert "data:image/png;base64,AAAA" not in raw_text
    assert "sources/images/doc/figure-1.png" in placeholder_text
    assert "Call get_image again" in placeholder_text
47+
48+
49+
def test_load_session_sanitizes_legacy_image_history(tmp_path):
    """Loading a session file that still holds a data URL must sanitize it."""
    session = ChatSession.new(tmp_path, "gpt-4o-mini", "en")
    raw_history = _image_history()

    # Write an unsanitized payload straight to disk, bypassing record_turn.
    legacy_payload = {
        "id": session.id,
        "created_at": session.created_at,
        "updated_at": session.updated_at,
        "model": session.model,
        "language": session.language,
        "title": "",
        "turn_count": 1,
        "history": raw_history,
        "user_turns": ["Describe the diagram."],
        "assistant_texts": ["It is a flow chart."],
    }
    session.path.parent.mkdir(parents=True, exist_ok=True)
    session.path.write_text(json.dumps(legacy_payload), encoding="utf-8")

    loaded = load_session(tmp_path, session.id)

    output_part = loaded.history[2]["output"][0]
    placeholder_text = output_part["text"]
    assert output_part["type"] == "input_text"
    assert "data:image/png;base64,AAAA" not in placeholder_text
    assert "sources/images/doc/figure-1.png" in placeholder_text

0 commit comments

Comments
 (0)