Skip to content

Commit 7e50d3a

Browse files
authored
Token optimization (#368)
* Token optimization * Add tests
1 parent 2ab5bf6 commit 7e50d3a

7 files changed

Lines changed: 256 additions & 9 deletions

File tree

services/chatbot/src/chatbot/agent_utils.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import logging
23

34
from langchain.agents import AgentState
45
from langchain.agents.middleware.types import before_model
@@ -7,7 +8,11 @@
78

89
from .config import Config
910

11+
logger = logging.getLogger(__name__)
12+
1013
INDIVIDUAL_MIN_LENGTH = 100
14+
# Approximate characters per token across providers
15+
CHARS_PER_TOKEN = 4
1116

1217

1318
def collect_long_strings(obj):
@@ -88,3 +93,53 @@ def truncate_tool_messages(state: AgentState, runtime: Runtime) -> AgentState:
8893
else:
8994
modified_messages.append(msg)
9095
return {"messages": modified_messages}
96+
97+
98+
def _estimate_tokens(text):
    """Approximate the token count of *text* using the chars-per-token heuristic."""
    # Integer division intentionally floors the estimate.
    char_count = len(text)
    return char_count // CHARS_PER_TOKEN
101+
102+
103+
def _message_content(msg):
104+
"""Extract text content from a message dict or object."""
105+
if isinstance(msg, dict):
106+
return msg.get("content", "")
107+
return getattr(msg, "content", "")
108+
109+
110+
def trim_messages_to_token_limit(messages):
    """
    Trim conversation history from the oldest messages to fit within the token
    budget derived from MAX_CONTENT_LENGTH.

    The most recent message (the new user turn) is always kept, even when it
    alone exceeds the budget.

    Args:
        messages: Sequence of chat messages (plain dicts or message objects).

    Returns:
        The input list unchanged when it already fits the budget (or is empty);
        otherwise a new list holding the newest messages that fit.
    """
    # Token budget derived from the character limit via the shared heuristic.
    max_tokens = Config.MAX_CONTENT_LENGTH // CHARS_PER_TOKEN

    if not messages:
        return messages

    # Estimate per-message tokens once, up front.
    token_counts = [_estimate_tokens(_message_content(m)) for m in messages]
    total_tokens = sum(token_counts)

    if total_tokens <= max_tokens:
        return messages

    # Drop from the front (oldest first), always keeping the last message.
    # Maintain a running total instead of re-summing the remaining counts on
    # every iteration (the original was accidentally O(n^2)); a single slice
    # then replaces repeated O(n) pop(0) calls.
    start = 0
    remaining = total_tokens
    while start < len(messages) - 1 and remaining > max_tokens:
        remaining -= token_counts[start]
        start += 1

    trimmed = list(messages[start:])

    logger.info(
        "Trimmed conversation history from %d to %d messages "
        "(estimated tokens: %d -> %d, limit: %d)",
        len(messages),
        len(trimmed),
        total_tokens,
        remaining,
        max_tokens,
    )
    return trimmed

services/chatbot/src/chatbot/chat_api.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from quart import Blueprint, jsonify, request
66

7+
from .agent_utils import trim_messages_to_token_limit
78
from .chat_service import (delete_chat_history, get_chat_history,
89
process_user_message)
910
from .config import Config
@@ -229,8 +230,7 @@ async def state():
229230
"Provider API key for session %s: %s", session_id, provider_api_key[:5]
230231
)
231232
chat_history = await get_chat_history(session_id)
232-
# Limit chat history to last 20 messages
233-
chat_history = chat_history[-20:]
233+
chat_history = trim_messages_to_token_limit(chat_history)
234234
return (
235235
jsonify(
236236
{
@@ -259,16 +259,15 @@ async def history():
259259
provider_api_key = await get_api_key(session_id)
260260
if provider in {"openai", "anthropic"} and provider_api_key:
261261
chat_history = await get_chat_history(session_id)
262-
# Limit chat history to last 20 messages
263-
chat_history = chat_history[-20:]
262+
chat_history = trim_messages_to_token_limit(chat_history)
264263
return jsonify({"chat_history": chat_history}), 200
265264
if provider in {"openai", "anthropic"}:
266265
return (
267266
jsonify({"chat_history": []}),
268267
200,
269268
)
270269
chat_history = await get_chat_history(session_id)
271-
chat_history = chat_history[-20:] if chat_history else []
270+
chat_history = trim_messages_to_token_limit(chat_history) if chat_history else []
272271
return jsonify({"chat_history": chat_history}), 200
273272

274273

services/chatbot/src/chatbot/chat_service.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from langgraph.graph.message import Messages
55

6+
from .agent_utils import trim_messages_to_token_limit
67
from .config import Config
78
from .extensions import db
89
from .langgraph_agent import execute_langgraph_agent
@@ -80,8 +81,7 @@ async def process_user_message(session_id, user_message, api_key, model_name, us
8081
)
8182
logger.debug("Added messages to Chroma collection - session_id: %s", session_id)
8283

83-
# Limit chat history to last 20 messages
84-
history = history[-20:]
84+
history = trim_messages_to_token_limit(history)
8585
await update_chat_history(session_id, history)
8686
logger.info(
8787
"Message processing complete - session_id: %s, response_id: %s, history_count: %d",

services/chatbot/src/chatbot/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,6 @@ class Config:
3434
AWS_ROLE_SESSION_NAME = os.getenv("AWS_ROLE_SESSION_NAME", "crapi-chatbot-session")
3535
VERTEX_PROJECT = os.getenv("VERTEX_PROJECT", "")
3636
VERTEX_LOCATION = os.getenv("VERTEX_LOCATION", "")
37-
MAX_CONTENT_LENGTH = int(os.getenv("MAX_CONTENT_LENGTH", 50000))
37+
MAX_CONTENT_LENGTH = int(os.getenv("MAX_CONTENT_LENGTH", 100000))
3838
CHROMA_HOST = CHROMA_HOST
3939
CHROMA_PORT = CHROMA_PORT

services/chatbot/src/chatbot/langgraph_agent.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from langchain_mistralai import ChatMistralAI
1212
from langchain_openai import AzureChatOpenAI, ChatOpenAI
1313

14-
from .agent_utils import truncate_tool_messages
14+
from .agent_utils import trim_messages_to_token_limit, truncate_tool_messages
1515
from .aws_credentials import get_bedrock_credentials_kwargs
1616
from .config import Config
1717
from .extensions import postgresdb
@@ -263,6 +263,7 @@ async def execute_langgraph_agent(
263263
len(messages),
264264
)
265265
agent = await build_langgraph_agent(api_key, model_name, user_jwt)
266+
messages = trim_messages_to_token_limit(messages)
266267
logger.debug("Invoking agent with %d messages", len(messages))
267268
response = await agent.ainvoke({"messages": messages})
268269
logger.info(

services/chatbot/src/chatbot/tests/__init__.py

Whitespace-only changes.
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
"""Tests for trim_messages_to_token_limit and supporting helpers."""
2+
import sys
3+
from types import ModuleType
4+
from unittest.mock import MagicMock, patch
5+
6+
import pytest
7+
8+
# ---------------------------------------------------------------------------
9+
# Stub out heavy third-party deps so the module can be imported without them.
10+
# ---------------------------------------------------------------------------
11+
_STUBS = {}
12+
for mod_name in [
13+
"langchain", "langchain.agents", "langchain.agents.middleware",
14+
"langchain.agents.middleware.types",
15+
"langchain_core", "langchain_core.messages",
16+
"langgraph", "langgraph.runtime",
17+
"motor", "motor.motor_asyncio",
18+
"langchain_community", "langchain_community.agent_toolkits",
19+
"langchain_community.utilities",
20+
"pymongo",
21+
]:
22+
if mod_name not in sys.modules:
23+
stub = ModuleType(mod_name)
24+
sys.modules[mod_name] = stub
25+
_STUBS[mod_name] = stub
26+
27+
# Provide the decorator used by agent_utils at import time
28+
sys.modules["langchain.agents"].AgentState = dict
29+
sys.modules["langchain.agents.middleware.types"].before_model = (
30+
lambda **kw: (lambda fn: fn)
31+
)
32+
sys.modules["langchain_core.messages"].ToolMessage = type("ToolMessage", (), {})
33+
sys.modules["langgraph.runtime"].Runtime = type("Runtime", (), {})
34+
35+
# Stub dotenv
36+
dotenv_stub = ModuleType("dotenv")
37+
dotenv_stub.load_dotenv = lambda *a, **kw: None
38+
sys.modules["dotenv"] = dotenv_stub
39+
40+
# Stub dbconnections before config is imported
41+
db_stub = ModuleType("chatbot.dbconnections")
42+
db_stub.CHROMA_HOST = "localhost"
43+
db_stub.CHROMA_PORT = 8000
44+
db_stub.MONGO_CONNECTION_URI = "mongodb://localhost"
45+
db_stub.POSTGRES_URI = "postgresql://localhost"
46+
sys.modules["chatbot.dbconnections"] = db_stub
47+
48+
# Now safe to import the module under test
49+
from chatbot.agent_utils import (
50+
CHARS_PER_TOKEN,
51+
_estimate_tokens,
52+
_message_content,
53+
trim_messages_to_token_limit,
54+
)
55+
56+
57+
# ---------------------------------------------------------------------------
58+
# Helpers
59+
# ---------------------------------------------------------------------------
60+
61+
def _make_msg(role, content):
62+
"""Return a plain dict message like those stored in chat history."""
63+
return {"role": role, "content": content}
64+
65+
66+
# ---------------------------------------------------------------------------
67+
# _estimate_tokens
68+
# ---------------------------------------------------------------------------
69+
70+
class TestEstimateTokens:
    """Unit tests for the character-based token estimator."""

    def test_empty_string(self):
        # Zero characters must map to zero tokens.
        assert _estimate_tokens("") == 0

    def test_known_length(self):
        # 400 characters floor-divide into the expected token count.
        sample = "a" * 400
        assert _estimate_tokens(sample) == 400 // CHARS_PER_TOKEN

    def test_short_string(self):
        # Strings shorter than CHARS_PER_TOKEN floor to zero (2 // 4 == 0).
        assert _estimate_tokens("hi") == 0
80+
81+
82+
# ---------------------------------------------------------------------------
83+
# _message_content
84+
# ---------------------------------------------------------------------------
85+
86+
class TestMessageContent:
    """Unit tests for extracting text from both message shapes."""

    def test_dict_message(self):
        message = {"role": "user", "content": "hello"}
        assert _message_content(message) == "hello"

    def test_dict_missing_content(self):
        # A dict lacking the "content" key falls back to an empty string.
        assert _message_content({"role": "user"}) == ""

    def test_object_message(self):
        class Msg:
            content = "from object"

        assert _message_content(Msg()) == "from object"

    def test_object_no_content(self):
        # Objects without a content attribute also yield an empty string.
        class Msg:
            pass

        assert _message_content(Msg()) == ""
102+
103+
104+
# ---------------------------------------------------------------------------
105+
# trim_messages_to_token_limit
106+
# ---------------------------------------------------------------------------
107+
108+
MAX_CONTENT_LENGTH = 100000 # default
109+
110+
111+
class TestTrimMessagesToTokenLimit:
    """Tests use a patched MAX_CONTENT_LENGTH to keep fixtures small."""

    @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", 400)
    def test_under_limit_returns_all(self):
        """Messages totalling fewer tokens than the budget are untouched."""
        # budget = 400 // 4 = 100 tokens; 2 x 100 chars = 50 tokens total
        msgs = [_make_msg("user", "a" * 100), _make_msg("assistant", "b" * 100)]
        result = trim_messages_to_token_limit(msgs)
        assert len(result) == 2
        assert result == msgs

    @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", 400)
    def test_over_limit_trims_oldest(self):
        """Oldest messages are dropped first to fit within budget."""
        # budget = 400 // 4 = 100 tokens
        # Each message = 200 chars = 50 tokens -> 3 msgs = 150 tokens > 100
        msgs = [
            _make_msg("user", "a" * 200),
            _make_msg("assistant", "b" * 200),
            _make_msg("user", "c" * 200),
        ]
        result = trim_messages_to_token_limit(msgs)
        assert len(result) < 3
        # Last message is always preserved
        assert result[-1]["content"] == "c" * 200

    @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", 400)
    def test_last_message_always_kept(self):
        """Even if a single message exceeds the budget, it is kept."""
        # 800 chars = 200 tokens, double the 100-token budget
        msgs = [_make_msg("user", "x" * 800)]
        result = trim_messages_to_token_limit(msgs)
        assert len(result) == 1
        assert result[0]["content"] == "x" * 800

    @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", 400)
    def test_trims_from_front_not_back(self):
        """Verify older messages (front) are removed, newer ones (back) stay."""
        # budget = 100 tokens; each msg = 50 tokens (prefix + padding = 200 chars)
        msgs = [
            _make_msg("user", "first-" + "a" * 194),
            _make_msg("assistant", "second-" + "b" * 193),
            _make_msg("user", "third-" + "c" * 194),
        ]
        result = trim_messages_to_token_limit(msgs)
        assert result[-1]["content"].startswith("third-")
        assert not any(m["content"].startswith("first-") for m in result)

    def test_empty_messages(self):
        # Empty history passes through unchanged.
        assert trim_messages_to_token_limit([]) == []

    @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", MAX_CONTENT_LENGTH)
    def test_default_limit_is_derived_from_max_content_length(self):
        """Token budget should be MAX_CONTENT_LENGTH // CHARS_PER_TOKEN."""
        expected_token_budget = MAX_CONTENT_LENGTH // CHARS_PER_TOKEN
        # Create messages just under the budget -> no trimming
        msg_chars = (expected_token_budget - 1) * CHARS_PER_TOKEN
        msgs = [_make_msg("user", "a" * msg_chars)]
        result = trim_messages_to_token_limit(msgs)
        assert len(result) == 1

    @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", MAX_CONTENT_LENGTH)
    def test_result_fits_within_token_budget(self):
        """After trimming, estimated tokens must be <= budget."""
        token_budget = MAX_CONTENT_LENGTH // CHARS_PER_TOKEN
        # 20 messages each ~2500 tokens = 50000 tokens, well over 25000 budget
        msgs = [_make_msg("user" if i % 2 == 0 else "assistant", "x" * 10000)
                for i in range(20)]
        result = trim_messages_to_token_limit(msgs)
        result_tokens = sum(_estimate_tokens(m["content"]) for m in result)
        assert result_tokens <= token_budget

    @patch("chatbot.agent_utils.Config.MAX_CONTENT_LENGTH", 400)
    def test_does_not_mutate_original(self):
        """The original message list must not be modified."""
        msgs = [
            _make_msg("user", "a" * 200),
            _make_msg("assistant", "b" * 200),
            _make_msg("user", "c" * 200),
        ]
        original_len = len(msgs)
        trim_messages_to_token_limit(msgs)
        assert len(msgs) == original_len

0 commit comments

Comments
 (0)