From 4a19bc4a188f3a8b1b3c83ca093cdf116c9a72f1 Mon Sep 17 00:00:00 2001 From: Andrej Simurka Date: Mon, 8 Jun 2026 08:53:11 +0200 Subject: [PATCH] Wired compaction turn persistence into agentic query flows --- src/app/endpoints/query.py | 2 +- src/app/endpoints/streaming_query.py | 14 +- src/models/common/turn_summary.py | 6 - .../llamastack/__init__.py | 11 +- .../llamastack/_model.py | 161 +++++++++++++++++- src/utils/agents/query.py | 10 +- src/utils/agents/streaming.py | 22 ++- src/utils/pydantic_ai.py | 15 ++ .../app/endpoints/test_streaming_query.py | 153 ----------------- tests/unit/utils/agents/test_streaming.py | 18 +- 10 files changed, 219 insertions(+), 193 deletions(-) diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 800b15374..2cdd5e326 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -233,7 +233,7 @@ async def query_endpoint_handler( responses_params, moderation_result, endpoint_path, - compaction.original_input if compaction.compacted else None, + compaction.original_input, ) if moderation_result.decision == "passed": diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index fc7d740c5..231197cad 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -338,6 +338,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals responses_params=responses_params, context=context, endpoint_path=endpoint_path, + original_input=None, ) # Combine inline RAG results (BYOK + Solr) with tool-based results @@ -353,6 +354,8 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals responses_params=responses_params, turn_summary=turn_summary, background_topic_summary_tasks=_background_topic_summary_tasks, + emit_start=True, + original_input=None, ), media_type=response_media_type, ) @@ -387,7 +390,6 @@ async def retrieve_response_generator( if context.moderation_result.decision == "blocked": turn_summary.llm_response = context.moderation_result.message turn_summary.id = context.moderation_result.moderation_id - turn_summary.output_items = [context.moderation_result.refusal_response] # In compacted mode the conversation parameter was omitted, so the # refusal turn (with the original input) is persisted by # generate_response; storing it here too would duplicate it. @@ -506,6 +508,7 @@ async def generate_response_with_compaction( responses_params=responses_params, context=context, endpoint_path=endpoint_path, + original_input=compacted_original_input, ) except HTTPException as e: yield http_exception_stream_event(e) @@ -705,7 +708,7 @@ async def generate_response( # pylint: disable=too-many-arguments,too-many-posi if original_input is not None else context.query_request.query ), - turn_summary.output_items, + [], # field was removed from TurnSummary ) except Exception: # pylint: disable=broad-except logger.exception( @@ -884,10 +887,6 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat getattr(chunk, "response"), # noqa: B009 ) turn_summary.llm_response = turn_summary.llm_response or "".join(text_parts) - # Capture structured output items for compacted-mode turn storage - # (LCORE-1572), so the persisted turn keeps non-text output items - # rather than being flattened to the response text. - turn_summary.output_items = list(latest_response_object.output or []) event_id = chunk_id chunk_id += 1 turn_summary.next_chunk_id = chunk_id @@ -906,9 +905,6 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat OpenAIResponseObject, getattr(chunk, "response"), # noqa: B009 ) - # Capture any partial output items so a compacted-mode turn is not - # persisted with empty output on these terminals (LCORE-1572). - turn_summary.output_items = list(latest_response_object.output or []) error_message = ( latest_response_object.error.message if latest_response_object.error diff --git a/src/models/common/turn_summary.py b/src/models/common/turn_summary.py index 2b342b758..6951f5ff1 100644 --- a/src/models/common/turn_summary.py +++ b/src/models/common/turn_summary.py @@ -5,7 +5,6 @@ from typing import Any, Optional -from llama_stack_api import OpenAIResponseOutput from pydantic import AnyUrl, BaseModel, Field from utils.token_counter import TokenCounter @@ -109,11 +108,6 @@ class TurnSummary(BaseModel): rag_chunks: list[RAGChunk] = Field(default_factory=list) referenced_documents: list[ReferencedDocument] = Field(default_factory=list) token_usage: TokenCounter = Field(default_factory=TokenCounter) - output_items: list[OpenAIResponseOutput] = Field( - default_factory=list, - description="Structured response output items, captured for compacted-mode " - "turn persistence (LCORE-1572). Empty on the non-compacted path.", - ) partial_tokens: list[str] = Field( default_factory=list, description="Accumulated text deltas during streaming, used to reconstruct " diff --git a/src/pydantic_ai_lightspeed/llamastack/__init__.py b/src/pydantic_ai_lightspeed/llamastack/__init__.py index fac9ee826..80cb95193 100644 --- a/src/pydantic_ai_lightspeed/llamastack/__init__.py +++ b/src/pydantic_ai_lightspeed/llamastack/__init__.py @@ -1,6 +1,13 @@ """Pydantic AI provider for Llama Stack.""" -from pydantic_ai_lightspeed.llamastack._model import LlamaStackResponsesModel +from pydantic_ai_lightspeed.llamastack._model import ( + CompactionTurnContext, + LlamaStackResponsesModel, +) from pydantic_ai_lightspeed.llamastack._provider import LlamaStackProvider -__all__ = ["LlamaStackProvider", "LlamaStackResponsesModel"] +__all__ = [ + "CompactionTurnContext", + "LlamaStackProvider", + "LlamaStackResponsesModel", +] diff --git a/src/pydantic_ai_lightspeed/llamastack/_model.py b/src/pydantic_ai_lightspeed/llamastack/_model.py index 2338d241e..f026eb78e 100644 --- a/src/pydantic_ai_lightspeed/llamastack/_model.py +++ b/src/pydantic_ai_lightspeed/llamastack/_model.py @@ -10,6 +10,11 @@ deltas must be replayed with the matching suffix so pydantic_ai can append the streamed ``tool_args`` content to the correct part. +When compaction omits the ``conversation`` parameter from inference requests, +``LlamaStackResponsesModel`` appends completed turns to the conversation via +``CompactionTurnContext`` (in :meth:`_responses_create` for non-streaming rounds, +and on ``response.completed`` for streaming rounds in :meth:`request_stream`). + This module provides ``LlamaStackResponsesModel`` which wraps the event stream to buffer those early delta events and replay them correctly once the item is announced. """ @@ -17,16 +22,18 @@ from __future__ import annotations as _annotations from collections import defaultdict -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Sequence from contextlib import asynccontextmanager -from typing import Any, cast +from dataclasses import dataclass +from typing import Any, Literal, Optional, cast, overload +from llama_stack_client import AsyncLlamaStackClient from openai import AsyncStream from openai.types import responses from pydantic_ai import UnexpectedModelBehavior from pydantic_ai._run_context import RunContext from pydantic_ai._utils import PeekableAsyncStream, Unset, number_to_datetime -from pydantic_ai.messages import ModelMessage, ModelResponse +from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse from pydantic_ai.models import ( ModelRequestParameters, StreamedResponse, @@ -38,13 +45,38 @@ OpenAIResponsesStreamedResponse, _map_api_errors, ) +from pydantic_ai.profiles import ModelProfileSpec from pydantic_ai.settings import ModelSettings from log import get_logger +from models.common.responses.types import ResponseInput +from pydantic_ai_lightspeed.llamastack._provider import LlamaStackProvider +from utils.conversations import append_turn_items_to_conversation logger = get_logger(__name__) +@dataclass +class CompactionTurnContext: + """Mutable state for manually persisting compacted agent turns. + + ``latest_round_input`` is initialized to the real user query. The create patch + leaves it unchanged on the first LLM round, then records pydantic-ai input + for follow-up rounds after that turn is persisted. + + Attributes: + client: Llama Stack client used to append conversation items. + conversation_id: Conversation to store turns against. + latest_round_input: Input stored for the current or next inference round. + original_input_persisted: Whether the first compacted round was appended. + """ + + client: AsyncLlamaStackClient + conversation_id: str + latest_round_input: ResponseInput + original_input_persisted: bool = False + + class _FilteredResponseStream: """Wraps an OpenAI AsyncStream to reorder spurious events from Llama Stack. @@ -58,13 +90,19 @@ class _FilteredResponseStream: a closing ``}`` to complete the outer JSON object that pydantic_ai opens. """ - def __init__(self, source: AsyncStream[responses.ResponseStreamEvent]) -> None: + def __init__( + self, + source: AsyncStream[responses.ResponseStreamEvent], + compaction: Optional[CompactionTurnContext] = None, + ) -> None: """Wrap an existing stream with reordering logic. Args: source: The raw OpenAI AsyncStream to reorder. + compaction: Compaction state for turn persistence, if active. """ self._source = source + self._compaction = compaction self._announced_item_ids: set[str] = set() self._buffered_deltas: dict[ str, list[responses.ResponseFunctionCallArgumentsDeltaEvent] @@ -112,6 +150,19 @@ async def _filtered_iter( self._buffered_deltas[event.item_id].append(event) continue + if ( + isinstance(event, responses.ResponseCompletedEvent) + and self._compaction is not None + ): + compaction = self._compaction + await append_turn_items_to_conversation( + compaction.client, + compaction.conversation_id, + compaction.latest_round_input, + cast(Sequence[Any], event.response.output), + ) + compaction.original_input_persisted = True + yield event def _replay_buffered_deltas( @@ -179,8 +230,108 @@ class LlamaStackResponsesModel(OpenAIResponsesModel): Overrides the streaming response processing to buffer and replay ``ResponseFunctionCallArgumentsDeltaEvent`` events that Llama Stack emits before the corresponding ``McpCall`` or ``ResponseFunctionToolCall`` item. + + When ``compaction`` is set, completed inference rounds are appended to the + conversation because compacted mode omits the ``conversation`` parameter. """ + def __init__( # pylint: disable=too-many-arguments + self, + model_name: str, + provider: LlamaStackProvider, + profile: ModelProfileSpec | None = None, + settings: ModelSettings | None = None, + compaction: Optional[CompactionTurnContext] = None, + ) -> None: + """Initialize the model. + + Args: + model_name: Model identifier passed to pydantic-ai. + provider: Pydantic AI provider or provider name. + profile: Optional model profile override. + settings: Optional pydantic-ai model settings. + compaction: Compaction state when turns must be stored manually. + """ + super().__init__( + model_name, + provider=provider, + profile=profile, + settings=settings, + ) + self.compaction = compaction + + @overload + async def _responses_create( + self, + messages: list[ModelRequest | ModelResponse], + stream: Literal[False], + model_settings: OpenAIResponsesModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> responses.Response: ... + + @overload + async def _responses_create( + self, + messages: list[ModelRequest | ModelResponse], + stream: Literal[True], + model_settings: OpenAIResponsesModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> AsyncStream[responses.ResponseStreamEvent]: ... + + async def _responses_create( + self, + messages: list[ModelRequest | ModelResponse], + stream: bool, + model_settings: OpenAIResponsesModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> ( + responses.Response | AsyncStream[responses.ResponseStreamEvent] | ModelResponse + ): + """Create a Responses API request with compacted turn persistence. + + After the first compacted round is persisted, records pydantic-ai input + for follow-up tool-loop rounds. Non-streaming responses are appended + immediately; streaming persistence is handled in :meth:`request_stream`. + """ + compaction = self.compaction + if compaction is not None and compaction.original_input_persisted: + request_params = await self._build_responses_request_params( + messages, + model_settings, + model_request_parameters, + self.profile, + ) + compaction.latest_round_input = cast(ResponseInput, request_params.input) + + result: ( + responses.Response + | AsyncStream[responses.ResponseStreamEvent] + | ModelResponse + ) + if stream: + result = await super()._responses_create( + messages, True, model_settings, model_request_parameters + ) + else: + result = await super()._responses_create( + messages, False, model_settings, model_request_parameters + ) + + if ( + compaction is not None + and not stream + and isinstance(result, responses.Response) + ): + await append_turn_items_to_conversation( + compaction.client, + compaction.conversation_id, + compaction.latest_round_input, + cast(Sequence[Any], result.output), + ) + compaction.original_input_persisted = True + + return result + async def request( # pylint: disable=unused-argument self, messages: list[ModelMessage], @@ -274,7 +425,7 @@ async def request_stream( # pylint: disable=unused-argument messages, True, model_settings_cast, model_request_parameters ) - filtered_stream = _FilteredResponseStream(response) + filtered_stream = _FilteredResponseStream(response, self.compaction) async with response: peekable: PeekableAsyncStream[ diff --git a/src/utils/agents/query.py b/src/utils/agents/query.py index c0f5ad958..4fe7f305b 100644 --- a/src/utils/agents/query.py +++ b/src/utils/agents/query.py @@ -282,7 +282,7 @@ async def retrieve_agent_response( responses_params: ResponsesApiParams, moderation_result: ShieldModerationResult, endpoint_path: str, - _original_input: Optional[ResponseInput] = None, + original_input: Optional[ResponseInput] = None, ) -> TurnSummary: """Retrieve a turn summary from a blocking agent run. @@ -293,7 +293,7 @@ async def retrieve_agent_response( responses_params: Prepared Responses API parameters. moderation_result: Shield moderation outcome for the turn. endpoint_path: Endpoint path used for metric labeling. - _original_input: Original user input before the explicit-input rewrite. + original_input: Original user input before the explicit-input rewrite. Returns: Turn summary for the completed agent run. @@ -305,7 +305,7 @@ async def retrieve_agent_response( await append_turn_items_to_conversation( client, responses_params.conversation, - responses_params.input, + original_input or responses_params.input, [moderation_result.refusal_response], ) return TurnSummary( @@ -313,7 +313,9 @@ async def retrieve_agent_response( llm_response=moderation_result.message, ) try: - agent = build_agent(client, responses_params, configuration.skills) + agent = build_agent( + client, responses_params, configuration.skills, original_input + ) logger.debug("Starting agent non-streaming response processing") run_result = await agent.run(cast(str, responses_params.input)) except (AgentRunError, APIStatusError, APIConnectionError, RuntimeError) as exc: diff --git a/src/utils/agents/streaming.py b/src/utils/agents/streaming.py index b6afcf8d7..cc1ed606d 100644 --- a/src/utils/agents/streaming.py +++ b/src/utils/agents/streaming.py @@ -85,6 +85,7 @@ async def retrieve_agent_response_generator( responses_params: ResponsesApiParams, context: ResponseGeneratorContext, endpoint_path: str, + original_input: Optional[ResponseInput] = None, ) -> tuple[AsyncIterator[str], TurnSummary]: """Return the SSE generator and mutable turn summary for an agent run. @@ -92,6 +93,9 @@ async def retrieve_agent_response_generator( responses_params: Prepared Responses API parameters. context: Streaming request context and moderation result. endpoint_path: Endpoint path used for metric labeling. + original_input: In compacted mode, the original user input before the + explicit-input rewrite. Used to persist the completed turn with its + structured input (preserving attachments); ``None`` otherwise. Returns: Tuple of SSE async iterator and mutable turn summary. @@ -101,14 +105,12 @@ async def retrieve_agent_response_generator( if context.moderation_result.decision == "blocked": turn_summary.llm_response = context.moderation_result.message turn_summary.id = context.moderation_result.moderation_id - turn_summary.output_items = [context.moderation_result.refusal_response] - if not responses_params.omit_conversation: - await append_turn_items_to_conversation( - context.client, - responses_params.conversation, - responses_params.input, - [context.moderation_result.refusal_response], - ) + await append_turn_items_to_conversation( + context.client, + responses_params.conversation, + original_input or responses_params.input, + [context.moderation_result.refusal_response], + ) media_type = context.query_request.media_type or MEDIA_TYPE_JSON return ( shield_violation_generator( @@ -118,7 +120,9 @@ async def retrieve_agent_response_generator( turn_summary, ) - agent = build_agent(context.client, responses_params, configuration.skills) + agent = build_agent( + context.client, responses_params, configuration.skills, original_input + ) return ( agent_response_generator( diff --git a/src/utils/pydantic_ai.py b/src/utils/pydantic_ai.py index b4c73df6b..e8d364611 100644 --- a/src/utils/pydantic_ai.py +++ b/src/utils/pydantic_ai.py @@ -11,9 +11,11 @@ from pydantic_ai.models.openai import OpenAIResponsesModelSettings from pydantic_ai_skills import SkillsCapability +from models.common.responses import ResponseInput from models.common.responses.responses_api_params import ResponsesApiParams from models.config import SkillsConfiguration from pydantic_ai_lightspeed.llamastack import ( + CompactionTurnContext, LlamaStackProvider, LlamaStackResponsesModel, ) @@ -117,6 +119,7 @@ def build_agent( client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient, responses_params: ResponsesApiParams, skills: Optional[SkillsConfiguration], + original_input: Optional[ResponseInput] = None, ) -> Agent[None, str]: """Build a Pydantic AI agent that mirrors ``responses_params`` on the Llama Stack backend. @@ -129,6 +132,7 @@ def build_agent( client: Initialized Llama Stack client from ``AsyncLlamaStackClientHolder().get_client()``. responses_params: Parameters produced by ``prepare_responses_params`` for this turn. skills: Agent skills configuration from LCS, or None when skills are disabled. + original_input: When set, enables compacted-turn persistence on the model. Returns: ``Agent`` configured for ``await agent.run(...)`` (or streaming) against the same @@ -137,10 +141,21 @@ def build_agent( provider = llama_stack_provider_from_client(client) settings = _model_settings_from_responses_params(responses_params) + compaction = ( + CompactionTurnContext( + client=client, + conversation_id=responses_params.conversation, + latest_round_input=original_input, + ) + if original_input is not None + else None + ) + model = LlamaStackResponsesModel( responses_params.model, provider=provider, settings=settings, + compaction=compaction, ) return Agent( model, diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index fb2d9027b..5f6f7b332 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -709,55 +709,9 @@ async def test_retrieve_response_generator_shield_blocked( assert isinstance(turn_summary, TurnSummary) assert turn_summary.llm_response == "Content blocked" - # Structured refusal captured for compacted-mode persistence (LCORE-1572). - assert turn_summary.output_items == [mock_moderation_result.refusal_response] # Non-compacted: the refusal turn is stored here. mock_append.assert_awaited_once() - @pytest.mark.asyncio - async def test_retrieve_response_generator_shield_blocked_compacted( - self, mocker: MockerFixture - ) -> None: - """In compacted mode the shield refusal is not stored here (no double-store). - - generate_response persists the compacted turn (with the original input), - so storing it again in the shield branch would duplicate it (LCORE-1572). - """ - mock_client = mocker.AsyncMock(spec=AsyncLlamaStackClient) - - mock_responses_params = mocker.Mock(spec=ResponsesApiParams) - mock_responses_params.model = "provider1/model1" - mock_responses_params.input = "explicit input" - mock_responses_params.conversation = "conv_123" - mock_responses_params.omit_conversation = True # compacted - - mock_context = mocker.Mock(spec=ResponseGeneratorContext) - mock_context.client = mock_client - mock_context.vector_store_ids = [] - mock_context.rag_id_mapping = {} - mock_context.inline_rag_context = RAGContext() - mock_context.query_request = QueryRequest( - query="test", media_type=MEDIA_TYPE_TEXT - ) # pyright: ignore[reportCallIssue] - - mock_moderation_result = mocker.Mock() - mock_moderation_result.decision = "blocked" - mock_moderation_result.message = "Content blocked" - mock_moderation_result.moderation_id = "mod_123" - mock_moderation_result.refusal_response = mocker.Mock() - mock_context.moderation_result = mock_moderation_result - mock_append = mocker.patch( - "app.endpoints.streaming_query.append_turn_items_to_conversation", - new=mocker.AsyncMock(), - ) - - _generator, turn_summary = await retrieve_response_generator( - mock_responses_params, mock_context, endpoint_path="" - ) - - assert turn_summary.output_items == [mock_moderation_result.refusal_response] - mock_append.assert_not_awaited() # compacted: generate_response stores it - @pytest.mark.asyncio async def test_retrieve_response_generator_connection_error( self, mocker: MockerFixture @@ -1022,71 +976,6 @@ async def mock_generator() -> AsyncIterator[str]: assert any("start" in item for item in result) assert any("end" in item for item in result) - @pytest.mark.asyncio - async def test_generate_response_compacted_persists_structured_turn( - self, mocker: MockerFixture - ) -> None: - """Compacted mode persists the turn via store_compacted_turn with the - original input and structured output items, not flattened strings - (LCORE-1572).""" - - async def mock_generator() -> AsyncIterator[str]: - yield "data: token\n\n" - - conv_id = "123e4567-e89b-12d3-a456-426614174000" - mock_context = mocker.Mock(spec=ResponseGeneratorContext) - mock_context.conversation_id = conv_id - mock_context.user_id = "user_123" - mock_context.query_request = QueryRequest( - query="test", conversation_id=conv_id - ) # pyright: ignore[reportCallIssue] - mock_context.started_at = "2024-01-01T00:00:00Z" - mock_context.skip_userid_check = False - mock_context.request_id = "223e4567-e89b-12d3-a456-426614174000" - mock_context.client = mocker.AsyncMock(spec=AsyncLlamaStackClient) - - mock_responses_params = mocker.Mock(spec=ResponsesApiParams) - mock_responses_params.model = "provider1/model1" - mock_responses_params.conversation = conv_id - - turn_summary = TurnSummary() - turn_summary.token_usage = TokenCounter(input_tokens=10, output_tokens=5) - output_item = mocker.Mock() - turn_summary.output_items = [output_item] - - mock_config = mocker.Mock() - mock_config.quota_limiters = [] - mocker.patch("app.endpoints.streaming_query.configuration", mock_config) - mocker.patch("app.endpoints.streaming_query.consume_query_tokens") - mocker.patch( - "app.endpoints.streaming_query.get_available_quotas", return_value={} - ) - mocker.patch("app.endpoints.streaming_query.store_query_results") - store_mock = mocker.patch( - "app.endpoints.streaming_query.store_compacted_turn", - new_callable=mocker.AsyncMock, - ) - - result = [ - item - async for item in generate_response( - mock_generator(), - mock_context, - mock_responses_params, - turn_summary, - compacted=True, - original_input="the original query", - ) - ] - - assert any("end" in item for item in result) - store_mock.assert_awaited_once_with( - mock_context.client, - conv_id, - "the original query", - [output_item], - ) - @pytest.mark.asyncio async def test_generate_response_with_topic_summary( self, mocker: MockerFixture @@ -2663,45 +2552,3 @@ async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: # Should have both tool call and result (fallback behavior) assert len(mock_turn_summary.tool_calls) == 1 assert len(mock_turn_summary.tool_results) == 1 - - -@pytest.mark.asyncio -async def test_response_generator_failed_captures_output_items( - mocker: MockerFixture, -) -> None: - """A failed terminal captures output_items for compacted persistence (LCORE-1572).""" - out_item = mocker.Mock() - - async def mock_turn_response() -> AsyncIterator[OpenAIResponseObjectStream]: - chunk = mocker.Mock(spec=FailedChunk) - chunk.type = "response.failed" - mock_response = mocker.Mock() - mock_response.output = [out_item] - mock_response.error = mocker.Mock(message="boom") - chunk.response = mock_response - yield chunk - - mock_context = mocker.Mock(spec=ResponseGeneratorContext) - mock_context.query_request = QueryRequest( - query="test", media_type=MEDIA_TYPE_JSON - ) # pyright: ignore[reportCallIssue] - mock_context.model_id = "provider1/model1" - mock_context.vector_store_ids = [] - mock_context.rag_id_mapping = {} - mock_context.inline_rag_context = RAGContext() - - turn_summary = TurnSummary() - mocker.patch( - "app.endpoints.streaming_query.extract_token_usage", - return_value=TokenCounter(input_tokens=0, output_tokens=0), - ) - mocker.patch( - "app.endpoints.streaming_query.parse_referenced_documents", return_value=[] - ) - - async for _ in response_generator( - mock_turn_response(), mock_context, turn_summary, endpoint_path="" - ): - pass - - assert turn_summary.output_items == [out_item] diff --git a/tests/unit/utils/agents/test_streaming.py b/tests/unit/utils/agents/test_streaming.py index 6846ebdc4..fe9e42d09 100644 --- a/tests/unit/utils/agents/test_streaming.py +++ b/tests/unit/utils/agents/test_streaming.py @@ -530,16 +530,20 @@ async def test_blocked_moderation_returns_shield_generator( assert turn_summary.id == blocked_moderation.moderation_id @pytest.mark.asyncio - async def test_blocked_moderation_skips_append_when_omit_conversation( + async def test_blocked_moderation_compacted_appends_with_original_input( self, mocker: MockerFixture, make_generator_context: Callable[..., ResponseGeneratorContext], make_responses_params: Callable[..., ResponsesApiParams], blocked_moderation: ShieldModerationBlocked, ) -> None: - """Test compacted mode does not append blocked turn to conversation.""" + """Test compacted blocked turns persist with original_input, not explicit input.""" context = make_generator_context(moderation_result=blocked_moderation) - responses_params = make_responses_params(omit_conversation=True) + responses_params = make_responses_params( + omit_conversation=True, + input_text="explicit summaries-plus-recent input", + ) + original_input = "the real user query" mocker.patch( "utils.agents.streaming.shield_violation_generator", return_value=_async_iter([]), @@ -553,9 +557,15 @@ async def test_blocked_moderation_skips_append_when_omit_conversation( responses_params, context, ENDPOINT_PATH_STREAMING_QUERY, + original_input=original_input, ) - mock_append.assert_not_awaited() + mock_append.assert_awaited_once_with( + context.client, + responses_params.conversation, + original_input, + [blocked_moderation.refusal_response], + ) @pytest.mark.asyncio async def test_success_returns_agent_generator(