Skip to content

Commit 39c96da

Browse files
authored
Add files via upload
1 parent 69cd893 commit 39c96da

1 file changed

Lines changed: 225 additions & 0 deletions

File tree

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
"""
GenericAgent with training data saving functionality.

This module extends GenericAgent to save training data (system prompt, user prompt, and agent
output) for each step during benchmarking. This is useful for creating training datasets.
"""
8+
from copy import deepcopy
9+
from dataclasses import asdict
10+
from pathlib import Path
11+
import json
12+
import logging
13+
14+
from browsergym.experiments.agent import AgentInfo
15+
16+
from agentlab.agents import dynamic_prompting as dp
17+
from agentlab.llm.llm_utils import Discussion, ParseError, SystemMessage, BaseMessage, retry
18+
from agentlab.llm.tracking import cost_tracker_decorator
19+
20+
from .generic_agent import GenericAgent, GenericAgentArgs
21+
from .generic_agent_prompt import MainPrompt
22+
23+
logger = logging.getLogger(__name__)
24+
25+
26+
class GenericAgentWithTrainingArgs(GenericAgentArgs):
    """Argument container that constructs a :class:`GenericAgentWithTraining`."""

    def __post_init__(self):
        super().__post_init__()
        try:
            model = self.chat_model_args.model_name
        except AttributeError:
            # chat_model_args may be missing/partial (e.g. during deserialization);
            # in that case keep whatever agent_name the parent assigned.
            pass
        else:
            self.agent_name = "GenericAgentWithTraining-{}".format(model).replace("/", "_")

    def make_agent(self):
        """Instantiate the training-data-saving agent from these arguments."""
        return GenericAgentWithTraining(
            chat_model_args=self.chat_model_args,
            flags=self.flags,
            max_retry=self.max_retry,
        )
40+
41+
42+
class GenericAgentWithTraining(GenericAgent):
    """
    GenericAgent extended with training data saving functionality.

    This agent saves:
    1. System prompt (separately)
    2. User prompt (with all context: observations including environment memory, history, etc.)
    3. Agent output (thinking/reasoning + action)

    All saved in the training_data/ directory within the experiment directory.

    Saving only happens once both ``set_training_data_dir`` and ``set_current_step``
    have been called; otherwise ``get_action`` behaves exactly like the parent agent.
    """

    def __init__(
        self,
        chat_model_args,
        flags,
        max_retry: int = 4,
    ):
        super().__init__(chat_model_args, flags, max_retry)

        # Track current step for saving training data; both stay None until the
        # harness calls set_training_data_dir() / set_current_step().
        self._current_step = None
        self._training_data_dir = None

    @cost_tracker_decorator
    def get_action(self, obs):
        """Override get_action to save training data before and after the LLM call.

        The prompt-building / retry / bookkeeping logic mirrors the parent
        implementation; only the two save hooks are added around the LLM call.
        """
        self.obs_history.append(obs)
        main_prompt = MainPrompt(
            action_set=self.action_set,
            obs_history=self.obs_history,
            actions=self.actions,
            memories=self.memories,
            thoughts=self.thoughts,
            previous_plan=self.plan,
            step=self.plan_step,
            flags=self.flags,
        )

        max_prompt_tokens, max_trunc_itr = self._get_maxes()

        system_prompt = SystemMessage(dp.SystemPrompt().prompt)

        human_prompt = dp.fit_tokens(
            shrinkable=main_prompt,
            max_prompt_tokens=max_prompt_tokens,
            model_name=self.chat_model_args.model_name,
            max_iterations=max_trunc_itr,
            additional_prompts=system_prompt,
        )

        # Save system prompt and user prompt right before LLM call
        if self._training_data_dir is not None and self._current_step is not None:
            self._save_training_prompts(system_prompt, human_prompt, self._current_step)
        else:
            logger.debug(
                "Not saving training prompts: training_data_dir=%s, current_step=%s",
                self._training_data_dir,
                self._current_step,
            )

        try:
            # TODO, we would need to further shrink the prompt if the retry
            # cause it to be too long
            chat_messages = Discussion([system_prompt, human_prompt])
            ans_dict = retry(
                self.chat_llm,
                chat_messages,
                n_retry=self.max_retry,
                parser=main_prompt._parse_answer,
            )
            ans_dict["busted_retry"] = 0
            # inferring the number of retries, TODO: make this less hacky
            ans_dict["n_retry"] = (len(chat_messages) - 3) / 2
        except ParseError:
            # All retries failed to parse: record a null action and let the
            # caller decide how to handle it (matches parent behavior).
            ans_dict = dict(
                action=None,
                n_retry=self.max_retry + 1,
                busted_retry=1,
            )

        stats = self.chat_llm.get_stats()
        stats["n_retry"] = ans_dict["n_retry"]
        stats["busted_retry"] = ans_dict["busted_retry"]

        self.plan = ans_dict.get("plan", self.plan)
        self.plan_step = ans_dict.get("step", self.plan_step)
        action = ans_dict["action"]
        self.actions.append(action)
        self.memories.append(ans_dict.get("memory", None))
        self.thoughts.append(ans_dict.get("think", None))

        # Save the agent output (thinking + action)
        if self._training_data_dir is not None and self._current_step is not None:
            self._save_training_output(ans_dict, self._current_step)
        else:
            logger.debug(
                "Not saving training output: training_data_dir=%s, current_step=%s",
                self._training_data_dir,
                self._current_step,
            )

        agent_info = AgentInfo(
            think=ans_dict.get("think", None),
            chat_messages=chat_messages,
            stats=stats,
            extra_info={"chat_model_args": asdict(self.chat_model_args)},
        )
        return action, agent_info

    def _write_training_pair(self, stem: str, payload: dict, text: str) -> None:
        """Write *payload* to ``<stem>.json`` and *text* to ``<stem>.txt`` in the training dir.

        Creates the directory on demand so saving works even if the attribute
        was set without going through ``set_training_data_dir``.
        """
        training_dir = Path(self._training_data_dir)
        training_dir.mkdir(parents=True, exist_ok=True)
        with open(training_dir / f"{stem}.json", "w", encoding="utf-8") as f:
            json.dump(payload, f, indent=2, ensure_ascii=False)
        with open(training_dir / f"{stem}.txt", "w", encoding="utf-8") as f:
            f.write(text)

    def _save_training_prompts(
        self, system_prompt: SystemMessage, human_prompt: BaseMessage, step: int
    ):
        """Save system prompt and user prompt separately for training data.

        Best-effort: any failure is logged (with traceback) and swallowed so a
        save error never interrupts the benchmarking run.
        """
        try:
            logger.info(
                "Saving training prompts for step %s to %s", step, self._training_data_dir
            )
            # Save system prompt (JSON with role/content, plus plain-text dump).
            self._write_training_pair(
                f"system_prompt_step_{step}",
                {
                    "role": system_prompt.get("role", "system"),
                    "content": system_prompt.get("content", ""),
                },
                str(system_prompt),
            )
            # Save user prompt (human prompt without system).
            self._write_training_pair(
                f"user_prompt_step_{step}",
                {
                    "role": human_prompt.get("role", "user"),
                    "content": human_prompt.get("content", ""),
                },
                str(human_prompt),
            )
        except Exception:
            logger.warning("Failed to save training prompts for step %s", step, exc_info=True)

    def _save_training_output(self, ans_dict: dict, step: int):
        """Save the agent output (thinking/reasoning + action) for training data.

        Best-effort: any failure is logged (with traceback) and swallowed.
        """
        try:
            logger.info(
                "Saving training output for step %s to %s", step, self._training_data_dir
            )
            # Extract only thinking and action from ans_dict
            output_dict = {
                "think": ans_dict.get("think", None),
                "action": ans_dict.get("action", None),
            }

            # Text version: labelled sections, omitting empty/missing fields.
            output_text_parts = []
            if output_dict["think"]:
                output_text_parts.append(f"Thinking:\n{output_dict['think']}\n")
            if output_dict["action"]:
                output_text_parts.append(f"Action:\n{output_dict['action']}\n")

            self._write_training_pair(
                f"agent_output_step_{step}",
                output_dict,
                "\n".join(output_text_parts) if output_text_parts else "",
            )
        except Exception:
            logger.warning("Failed to save training output for step %s", step, exc_info=True)

    def set_training_data_dir(self, exp_dir: Path):
        """Set the directory where training data should be saved."""
        self._training_data_dir = Path(exp_dir) / "training_data"
        self._training_data_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"Training data directory set to: {self._training_data_dir}")

    def set_current_step(self, step: int):
        """Set the current step number for saving training data."""
        self._current_step = step

0 commit comments

Comments
 (0)