Skip to content

Commit 4d65315

Browse files
author
Petr Marinec
committed
Sandbox nested persona template rendering
1 parent 9d4ecbe commit 4d65315

4 files changed

Lines changed: 111 additions & 6 deletions

File tree

src/google/adk/evaluation/simulation/llm_backed_user_simulator_prompts.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,6 @@ def get_llm_backed_user_simulator_prompt(
185185
"""Formats the prompt for the llm-backed user simulator"""
186186
from jinja2 import DictLoader
187187
from jinja2 import pass_context
188-
from jinja2 import Template
189188
from jinja2.sandbox import SandboxedEnvironment
190189

191190
templates = {
@@ -200,7 +199,7 @@ def get_llm_backed_user_simulator_prompt(
200199
def _render_string_filter(context, template_string):
201200
if not template_string:
202201
return ""
203-
return Template(template_string).render(context)
202+
return template_env.from_string(template_string).render(context.get_all())
204203

205204
template_env.filters["render_string_filter"] = _render_string_filter
206205

src/google/adk/evaluation/simulation/per_turn_user_simulator_quality_prompts.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,8 @@ def get_per_turn_user_simulator_quality_prompt(
221221
):
222222
"""Formats the prompt for the per turn user simulator evaluator"""
223223
from jinja2 import DictLoader
224-
from jinja2 import Environment
225224
from jinja2 import pass_context
226-
from jinja2 import Template
225+
from jinja2.sandbox import SandboxedEnvironment
227226

228227
templates = {
229228
"verifier_instructions": (
@@ -232,13 +231,13 @@ def get_per_turn_user_simulator_quality_prompt(
232231
)
233232
),
234233
}
235-
template_env = Environment(loader=DictLoader(templates))
234+
template_env = SandboxedEnvironment(loader=DictLoader(templates))
236235

237236
@pass_context
238237
def _render_string_filter(context, template_string):
239238
if not template_string:
240239
return ""
241-
return Template(template_string).render(context)
240+
return template_env.from_string(template_string).render(context.get_all())
242241

243242
template_env.filters["render_string_filter"] = _render_string_filter
244243

tests/unittests/evaluation/simulation/test_llm_backed_user_simulator_prompts.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from google.adk.evaluation.simulation.llm_backed_user_simulator_prompts import is_valid_user_simulator_template
2222
from google.adk.evaluation.simulation.user_simulator_personas import UserBehavior
2323
from google.adk.evaluation.simulation.user_simulator_personas import UserPersona
24+
from jinja2.exceptions import SecurityError
2425
import pytest
2526

2627
_MOCK_DEFAULT_TEMPLATE = textwrap.dedent("""\
@@ -208,6 +209,57 @@ def test_get_llm_backed_user_simulator_prompt_with_persona(self, mocker):
208209
test stop""").strip()
209210
assert prompt == expected_prompt
210211

212+
def test_get_llm_backed_user_simulator_prompt_renders_persona_templates_in_sandbox(
213+
self,
214+
):
215+
user_persona = UserPersona(
216+
id="test_persona",
217+
description="Test persona description",
218+
behaviors=[
219+
UserBehavior(
220+
name="Behavior {{ stop_signal }}",
221+
description="Description {{ stop_signal }}",
222+
behavior_instructions=["instruction {{ stop_signal }}"],
223+
violation_rubrics=["rubric 1"],
224+
)
225+
],
226+
)
227+
228+
prompt = get_llm_backed_user_simulator_prompt(
229+
conversation_plan="test plan",
230+
conversation_history="test history",
231+
stop_signal="test stop",
232+
user_persona=user_persona,
233+
)
234+
235+
assert "## Behavior test stop" in prompt
236+
assert "Description test stop" in prompt
237+
assert " * instruction test stop" in prompt
238+
239+
def test_get_llm_backed_user_simulator_prompt_blocks_unsafe_persona_templates(
240+
self,
241+
):
242+
user_persona = UserPersona(
243+
id="test_persona",
244+
description="Test persona description",
245+
behaviors=[
246+
UserBehavior(
247+
name="{{ ''.__class__.__mro__ }}",
248+
description="Test behavior description",
249+
behavior_instructions=["instruction 1"],
250+
violation_rubrics=["rubric 1"],
251+
)
252+
],
253+
)
254+
255+
with pytest.raises(SecurityError):
256+
get_llm_backed_user_simulator_prompt(
257+
conversation_plan="test plan",
258+
conversation_history="test history",
259+
stop_signal="test stop",
260+
user_persona=user_persona,
261+
)
262+
211263

212264
class TestIsValidUserSimulatorTemplate:
213265
"""Test cases for is_valid_user_simulator_template."""

tests/unittests/evaluation/simulation/test_per_turn_user_simulation_quality_prompts.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from google.adk.evaluation.simulation.per_turn_user_simulator_quality_prompts import get_per_turn_user_simulator_quality_prompt
2121
from google.adk.evaluation.simulation.user_simulator_personas import UserBehavior
2222
from google.adk.evaluation.simulation.user_simulator_personas import UserPersona
23+
from jinja2.exceptions import SecurityError
24+
import pytest
2325

2426
_MOCK_DEFAULT_TEMPLATE = textwrap.dedent("""\
2527
Default template
@@ -182,3 +184,56 @@ def test_get_per_turn_user_simulator_quality_prompt_with_persona(
182184
# Stop signal
183185
stop""").strip()
184186
assert prompt == expected_prompt
187+
188+
def test_get_per_turn_user_simulator_quality_prompt_renders_persona_templates_in_sandbox(
189+
self,
190+
):
191+
persona = UserPersona(
192+
id="test_persona",
193+
description="Test persona description.",
194+
behaviors=[
195+
UserBehavior(
196+
name="criteria {{ stop_signal }}",
197+
description="Test behavior {{ stop_signal }}.",
198+
behavior_instructions=["instruction1"],
199+
violation_rubrics=["violation {{ stop_signal }}"],
200+
)
201+
],
202+
)
203+
204+
prompt = get_per_turn_user_simulator_quality_prompt(
205+
conversation_plan="plan",
206+
conversation_history="history",
207+
generated_user_response="response",
208+
stop_signal="stop",
209+
user_persona=persona,
210+
)
211+
212+
assert "## Criteria: criteria stop" in prompt
213+
assert "Test behavior stop." in prompt
214+
assert " * violation stop" in prompt
215+
216+
def test_get_per_turn_user_simulator_quality_prompt_blocks_unsafe_persona_templates(
217+
self,
218+
):
219+
persona = UserPersona(
220+
id="test_persona",
221+
description="Test persona description.",
222+
behaviors=[
223+
UserBehavior(
224+
name="{{ ''.__class__.__mro__ }}",
225+
description="Test behavior description.",
226+
behavior_instructions=["instruction1"],
227+
violation_rubrics=["violation1"],
228+
)
229+
],
230+
)
231+
232+
with pytest.raises(SecurityError):
233+
get_per_turn_user_simulator_quality_prompt(
234+
conversation_plan="plan",
235+
conversation_history="history",
236+
generated_user_response="response",
237+
stop_signal="stop",
238+
user_persona=persona,
239+
)

0 commit comments

Comments
 (0)