diff --git a/src/pydantic_ai_lightspeed/capabilities/question_validity/_capability.py b/src/pydantic_ai_lightspeed/capabilities/question_validity/_capability.py index 8a09d273b..a5b28fa97 100644 --- a/src/pydantic_ai_lightspeed/capabilities/question_validity/_capability.py +++ b/src/pydantic_ai_lightspeed/capabilities/question_validity/_capability.py @@ -27,7 +27,6 @@ QuestionValidityConfig, ) from pydantic_ai_lightspeed.llamastack import LlamaStackResponsesModel -from utils.pydantic_ai import llama_stack_provider_from_client logger = get_logger(__name__) @@ -55,21 +54,6 @@ def _extract_message_str_from_user_content(user_content: Sequence[UserContent]) return "\n".join(str_arr) -def _create_model_from_llama_stack_client(model_id: str) -> LlamaStackResponsesModel: - """Create a LlamaStackResponsesModel from the shared Llama Stack client. - - Parameters: - model_id: The model identifier to use for the responses model. - - Returns: - A configured LlamaStackResponsesModel instance. - """ - client = AsyncLlamaStackClientHolder().get_client() - provider = llama_stack_provider_from_client(client) - settings = OpenAIResponsesModelSettings(openai_store=False) - return LlamaStackResponsesModel(model_id, provider=provider, settings=settings) - - @dataclass class QuestionValidity(AbstractCapability[None]): """Block or modify user input based on a guardrail check. @@ -91,7 +75,13 @@ class QuestionValidity(AbstractCapability[None]): def __post_init__(self) -> None: """Initialize the model instance from the configured model ID.""" - self._model = _create_model_from_llama_stack_client(self.config.model_id) + llama_stack_client = AsyncLlamaStackClientHolder().get_client() + + self._model = LlamaStackResponsesModel.from_llama_stack_client( + self.config.model_id, + llama_stack_client, + model_settings=OpenAIResponsesModelSettings(openai_store=False), + ) def _build_prompt(self, message: str | Sequence[UserContent] | None) -> str: """Build the classification prompt from the user message. diff --git a/src/pydantic_ai_lightspeed/llamastack/_model.py b/src/pydantic_ai_lightspeed/llamastack/_model.py index 2338d241e..99a80902b 100644 --- a/src/pydantic_ai_lightspeed/llamastack/_model.py +++ b/src/pydantic_ai_lightspeed/llamastack/_model.py @@ -19,8 +19,10 @@ from collections import defaultdict from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any, cast +from typing import Any, Final, cast +from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient +from llama_stack_client import AsyncLlamaStackClient from openai import AsyncStream from openai.types import responses from pydantic_ai import UnexpectedModelBehavior @@ -38,12 +40,56 @@ OpenAIResponsesStreamedResponse, _map_api_errors, ) +from pydantic_ai.profiles import ModelProfileSpec from pydantic_ai.settings import ModelSettings from log import get_logger +from models.common.responses.responses_api_params import ResponsesApiParams +from pydantic_ai_lightspeed.llamastack._provider import LlamaStackProvider logger = get_logger(__name__) +_LLS_RESPONSES_EXTRA_FIELDS: Final[frozenset[str]] = frozenset( + { + "conversation", + "max_infer_iters", + "tools", + "tool_choice", + "include", + "text", + "reasoning", + "prompt", + "metadata", + "max_tool_calls", + "safety_identifier", + } +) + + +def _model_settings_from_responses_params( + responses_params: ResponsesApiParams, +) -> OpenAIResponsesModelSettings: + """Map ``ResponsesApiParams`` into Pydantic AI OpenAI Responses model settings.""" + payload = responses_params.model_dump(exclude_none=True) + extra_body = {k: v for k, v in payload.items() if k in _LLS_RESPONSES_EXTRA_FIELDS} + settings_dict: dict[str, Any] = {} + if extra_body: + settings_dict["extra_body"] = extra_body + if responses_params.max_output_tokens is not None: + settings_dict["max_tokens"] = responses_params.max_output_tokens + if responses_params.temperature is not None: + settings_dict["temperature"] = responses_params.temperature + if responses_params.parallel_tool_calls is not None: + settings_dict["parallel_tool_calls"] = responses_params.parallel_tool_calls + if responses_params.extra_headers: + settings_dict["extra_headers"] = dict(responses_params.extra_headers) + settings_dict["openai_store"] = responses_params.store + if responses_params.previous_response_id is not None: + settings_dict["openai_previous_response_id"] = ( + responses_params.previous_response_id + ) + return cast(OpenAIResponsesModelSettings, settings_dict) + class _FilteredResponseStream: """Wraps an OpenAI AsyncStream to reorder spurious events from Llama Stack. @@ -307,3 +353,53 @@ async def request_stream( # pylint: disable=unused-argument else None ), ) + + @staticmethod + def from_llama_stack_client( + model_name: str, + client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient, + *, + responses_params: ResponsesApiParams | None = None, + model_settings: ModelSettings | None = None, + profile: ModelProfileSpec | None = None, + ) -> LlamaStackResponsesModel: + """Create a ``LlamaStackResponsesModel`` from a Llama Stack client. + + Mirrors ``OpenAIResponsesModel.__init__`` parameters, but accepts a + Llama Stack client instead of a provider. Exactly one of + ``responses_params`` or ``model_settings`` may be provided. + + Args: + model_name: The model name/ID to use. + client: Llama Stack client to build the provider from. + responses_params: Optional ``ResponsesApiParams``, converted to + ``OpenAIResponsesModelSettings`` internally. Mutually + exclusive with ``model_settings``. + model_settings: Optional raw ``ModelSettings`` passed through + directly. Mutually exclusive with ``responses_params``. + profile: Optional model profile specification. + + Raises: + ValueError: If both ``responses_params`` and ``model_settings`` + are provided. + + Returns: + Configured ``LlamaStackResponsesModel`` instance. + """ + provider = LlamaStackProvider.from_llama_stack_client(client) + + if responses_params is not None and model_settings is not None: + raise ValueError( + "You can only pass either ResponsesApiParams or ModelSetting not both." + ) + + _settings: OpenAIResponsesModelSettings | ModelSettings | None = None + + if responses_params is not None: + _settings = _model_settings_from_responses_params(responses_params) + elif model_settings is not None: + _settings = model_settings + + return LlamaStackResponsesModel( + model_name, provider=provider, profile=profile, settings=_settings + ) diff --git a/src/pydantic_ai_lightspeed/llamastack/_provider.py b/src/pydantic_ai_lightspeed/llamastack/_provider.py index dee7311fa..0e66f6d3c 100644 --- a/src/pydantic_ai_lightspeed/llamastack/_provider.py +++ b/src/pydantic_ai_lightspeed/llamastack/_provider.py @@ -5,6 +5,8 @@ from typing import TYPE_CHECKING, Optional import httpx +from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient +from llama_stack_client import AsyncLlamaStackClient from openai import AsyncOpenAI from pydantic_ai import ModelProfile from pydantic_ai.models import create_async_http_client @@ -14,7 +16,9 @@ from pydantic_ai_lightspeed.llamastack._transport import LlamaStackLibraryTransport if TYPE_CHECKING: - from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient + from llama_stack.core.library_client import ( # pylint: disable=reimported + AsyncLlamaStackAsLibraryClient, + ) DEFAULT_BASE_URL = "http://localhost:8321/v1" @@ -48,6 +52,33 @@ def model_profile(model_name: str) -> Optional[ModelProfile]: """Return the model profile for the named model, if available.""" return openai_model_profile(model_name) + @staticmethod + def from_llama_stack_client( + client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient, + ) -> LlamaStackProvider: + """Create a ``LlamaStackProvider`` from a Llama Stack client. + + For an ``AsyncLlamaStackAsLibraryClient``, delegates to library mode. + For an ``AsyncLlamaStackClient``, extracts the base URL, API key, and + underlying HTTP client to create a server-mode provider. + + Args: + client: A Llama Stack client (server or library variant). + + Returns: + Configured ``LlamaStackProvider`` instance. + """ + if isinstance(client, AsyncLlamaStackAsLibraryClient): + return LlamaStackProvider(library_client=client) + api_key = client.api_key or "not-needed" + base = str(client.base_url).rstrip("/") + base_url = base if base.endswith("/v1") else f"{base}/v1" + return LlamaStackProvider( + base_url=base_url, + api_key=api_key, + http_client=client._client, # pylint: disable=protected-access + ) + def __init__( self, *, diff --git a/src/utils/pydantic_ai.py b/src/utils/pydantic_ai.py index a56e72f58..38bf71e92 100644 --- a/src/utils/pydantic_ai.py +++ b/src/utils/pydantic_ai.py @@ -3,19 +3,17 @@ from __future__ import annotations import re -from typing import Any, Final, Optional, cast +from typing import Any, Final, Optional from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient from llama_stack_client import AsyncLlamaStackClient from pydantic_ai.agent import Agent from pydantic_ai.capabilities import AbstractCapability, AgentCapability -from pydantic_ai.models.openai import OpenAIResponsesModelSettings from pydantic_ai_skills import SkillsCapability from models.common.responses.responses_api_params import ResponsesApiParams from models.config import SkillsConfiguration from pydantic_ai_lightspeed.llamastack import ( - LlamaStackProvider, LlamaStackResponsesModel, ) @@ -24,64 +22,6 @@ _BUILTIN_CAPABILITY_SERVER_SOURCE: Final[str] = "builtin" _CAPABILITY_TOOL_TYPE: Final[str] = "tool" -_LLS_RESPONSES_EXTRA_FIELDS: Final[frozenset[str]] = frozenset( - { - "conversation", - "max_infer_iters", - "tool_choice", - "include", - "text", - "reasoning", - "prompt", - "metadata", - "max_tool_calls", - "safety_identifier", - } -) - - -def llama_stack_provider_from_client( - client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient, -) -> LlamaStackProvider: - """Construct a Pydantic AI Llama Stack provider backed by the same client as ``/query``.""" - if isinstance(client, AsyncLlamaStackAsLibraryClient): - return LlamaStackProvider(library_client=client) - api_key = client.api_key or "not-needed" - base = str(client.base_url).rstrip("/") - base_url = base if base.endswith("/v1") else f"{base}/v1" - return LlamaStackProvider( - base_url=base_url, - api_key=api_key, - http_client=client._client, # pylint: disable=protected-access - ) - - -def _model_settings_from_responses_params( - responses_params: ResponsesApiParams, -) -> OpenAIResponsesModelSettings: - """Map ``ResponsesApiParams`` into Pydantic AI OpenAI Responses model settings.""" - payload = responses_params.model_dump(exclude_none=True) - extra_body = {k: v for k, v in payload.items() if k in _LLS_RESPONSES_EXTRA_FIELDS} - settings_dict: dict[str, Any] = {} - if extra_body: - settings_dict["extra_body"] = extra_body - if responses_params.max_output_tokens is not None: - settings_dict["max_tokens"] = responses_params.max_output_tokens - if responses_params.temperature is not None: - settings_dict["temperature"] = responses_params.temperature - if responses_params.parallel_tool_calls is not None: - settings_dict["parallel_tool_calls"] = responses_params.parallel_tool_calls - if responses_params.extra_headers: - settings_dict["extra_headers"] = dict(responses_params.extra_headers) - settings_dict["openai_store"] = responses_params.store - if responses_params.tools is not None: - settings_dict["openai_native_tools"] = responses_params.tools - if responses_params.previous_response_id is not None: - settings_dict["openai_previous_response_id"] = ( - responses_params.previous_response_id - ) - return cast(OpenAIResponsesModelSettings, settings_dict) - def _skills_capability( skills_config: Optional[SkillsConfiguration], @@ -239,15 +179,12 @@ def build_agent( ``Agent`` configured for ``await agent.run(...)`` (or streaming) against the same stack configuration as ``client.responses.create(**responses_params.model_dump())``. """ - provider = llama_stack_provider_from_client(client) - settings = _model_settings_from_responses_params(responses_params) capabilities = _agent_capabilities(skills, no_tools=no_tools) - model = LlamaStackResponsesModel( - responses_params.model, - provider=provider, - settings=settings, + model = LlamaStackResponsesModel.from_llama_stack_client( + responses_params.model, client, responses_params=responses_params ) + return Agent( model, instructions=responses_params.instructions, diff --git a/tests/unit/pydantic_ai_lightspeed/capabilities/question_validity/test_capability.py b/tests/unit/pydantic_ai_lightspeed/capabilities/question_validity/test_capability.py index 93cc9bab7..2e6037fc9 100644 --- a/tests/unit/pydantic_ai_lightspeed/capabilities/question_validity/test_capability.py +++ b/tests/unit/pydantic_ai_lightspeed/capabilities/question_validity/test_capability.py @@ -6,6 +6,7 @@ from pydantic import ValidationError from pydantic_ai import AgentRunResult, RunContext from pydantic_ai.messages import ImageUrl, ModelResponse, TextContent, TextPart +from pydantic_ai.models.openai import OpenAIResponsesModelSettings from pydantic_ai.usage import RequestUsage, RunUsage from pytest_mock import MockerFixture, MockType @@ -20,10 +21,11 @@ SUBJECT_ALLOWED, SUBJECT_REJECTED, QuestionValidity, - _create_model_from_llama_stack_client, _extract_message_str_from_user_content, ) +_MODULE = "pydantic_ai_lightspeed.capabilities.question_validity._capability" + class TestExtractMessageStrFromUserContent: """Tests for _extract_message_str_from_user_content helper.""" @@ -104,72 +106,35 @@ def test_unknown_fields_rejected(self) -> None: QuestionValidityConfig(model_id="test", unknown_field="value") # type: ignore[call-arg] -class TestCreateModelFromLlamaStackClient: - """Tests for _create_model_from_llama_stack_client factory function.""" - - _MODULE = "pydantic_ai_lightspeed.capabilities.question_validity._capability" - - def test_creates_model_with_correct_wiring(self, mocker: MockerFixture) -> None: - """Test that the factory wires client, provider, and model correctly.""" - mock_client = mocker.Mock() - mock_holder = mocker.patch(f"{self._MODULE}.AsyncLlamaStackClientHolder") - mock_holder.return_value.get_client.return_value = mock_client - - mock_provider = mocker.Mock() - mocker.patch( - f"{self._MODULE}.llama_stack_provider_from_client", - return_value=mock_provider, - ) - - mock_model_cls = mocker.patch(f"{self._MODULE}.LlamaStackResponsesModel") - - result = _create_model_from_llama_stack_client("test-model") +class TestQuestionValidityInit: + """Tests for QuestionValidity dataclass initialization.""" - mock_holder.return_value.get_client.assert_called_once() - mock_model_cls.assert_called_once() - call_args = mock_model_cls.call_args - assert call_args.args[0] == "test-model" - assert call_args.kwargs["provider"] is mock_provider - settings = call_args.kwargs["settings"] - assert settings == {"openai_store": False} - assert result is mock_model_cls.return_value - - def test_passes_client_to_provider_factory(self, mocker: MockerFixture) -> None: - """Test that the client from the holder is passed to the provider factory.""" + def test_post_init_wires_client_and_model(self, mocker: MockerFixture) -> None: + """Test that __post_init__ obtains the client and passes it to from_llama_stack_client.""" mock_client = mocker.Mock() - mock_holder = mocker.patch(f"{self._MODULE}.AsyncLlamaStackClientHolder") + mock_holder = mocker.patch(f"{_MODULE}.AsyncLlamaStackClientHolder") mock_holder.return_value.get_client.return_value = mock_client mock_from_client = mocker.patch( - f"{self._MODULE}.llama_stack_provider_from_client", + f"{_MODULE}.LlamaStackResponsesModel.from_llama_stack_client", ) - _create_model_from_llama_stack_client("any-model") - - mock_from_client.assert_called_once_with(mock_client) - - -class TestQuestionValidityInit: - """Tests for QuestionValidity dataclass initialization.""" - - _MODULE = "pydantic_ai_lightspeed.capabilities.question_validity._capability" - - def test_post_init_calls_create_model(self, mocker: MockerFixture) -> None: - """Test that __post_init__ delegates to _create_model_from_llama_stack_client.""" - mock_create = mocker.patch( - f"{self._MODULE}._create_model_from_llama_stack_client", - ) - config = QuestionValidityConfig(model_id="my-model") - + config = QuestionValidityConfig(model_id="test-model") QuestionValidity(config=config) - mock_create.assert_called_once_with("my-model") + mock_holder.return_value.get_client.assert_called_once() + mock_from_client.assert_called_once_with( + "test-model", + mock_client, + model_settings=OpenAIResponsesModelSettings(openai_store=False), + ) def test_model_is_assigned_from_factory(self, mocker: MockerFixture) -> None: - """Test that the model returned by the factory is stored on the instance.""" + """Test that the model returned by from_llama_stack_client is stored.""" mock_model = mocker.Mock() + mocker.patch(f"{_MODULE}.AsyncLlamaStackClientHolder") mocker.patch( - f"{self._MODULE}._create_model_from_llama_stack_client", + f"{_MODULE}.LlamaStackResponsesModel.from_llama_stack_client", return_value=mock_model, ) config = QuestionValidityConfig(model_id="test") @@ -182,12 +147,11 @@ def test_model_is_assigned_from_factory(self, mocker: MockerFixture) -> None: class TestBuildPrompt: """Tests for QuestionValidity._build_prompt method.""" - _MODULE = "pydantic_ai_lightspeed.capabilities.question_validity._capability" - @pytest.fixture(autouse=True) def _mock_create_model(self, mocker: MockerFixture) -> None: - """Mock _create_model_from_llama_stack_client for all tests.""" - mocker.patch(f"{self._MODULE}._create_model_from_llama_stack_client") + """Mock model creation for all tests.""" + mocker.patch(f"{_MODULE}.AsyncLlamaStackClientHolder") + mocker.patch(f"{_MODULE}.LlamaStackResponsesModel.from_llama_stack_client") @pytest.fixture(name="question_validity") def question_validity_fixture(self) -> QuestionValidity: @@ -245,12 +209,11 @@ def test_custom_prompt_template(self) -> None: class TestWrapRun: """Tests for QuestionValidity.wrap_run method.""" - _MODULE = "pydantic_ai_lightspeed.capabilities.question_validity._capability" - @pytest.fixture(autouse=True) def _mock_create_model(self, mocker: MockerFixture) -> None: - """Mock _create_model_from_llama_stack_client for all tests.""" - mocker.patch(f"{self._MODULE}._create_model_from_llama_stack_client") + """Mock model creation for all tests.""" + mocker.patch(f"{_MODULE}.AsyncLlamaStackClientHolder") + mocker.patch(f"{_MODULE}.LlamaStackResponsesModel.from_llama_stack_client") @pytest.fixture(name="mock_ctx") def mock_ctx_fixture(self, mocker: MockerFixture) -> RunContext: diff --git a/tests/unit/pydantic_ai_lightspeed/llamastack/test_model.py b/tests/unit/pydantic_ai_lightspeed/llamastack/test_model.py new file mode 100644 index 000000000..271a00fbe --- /dev/null +++ b/tests/unit/pydantic_ai_lightspeed/llamastack/test_model.py @@ -0,0 +1,759 @@ +"""Unit tests for pydantic_ai_lightspeed.llamastack._model module.""" + +# pylint: disable=protected-access,too-few-public-methods + +import pytest +from openai.types import responses +from pydantic_ai import ModelMessage, UnexpectedModelBehavior +from pydantic_ai.messages import ModelResponse +from pydantic_ai.models.openai import ( + OpenAIResponsesModel, + OpenAIResponsesModelSettings, + OpenAIResponsesStreamedResponse, +) +from pydantic_ai.settings import ModelSettings +from pytest_mock import MockerFixture + +from models.common.responses.responses_api_params import ResponsesApiParams +from pydantic_ai_lightspeed.llamastack._model import ( + _LLS_RESPONSES_EXTRA_FIELDS, + LlamaStackResponsesModel, + _FilteredResponseStream, + _model_settings_from_responses_params, +) + +_REQUIRED_PARAMS = { + "input": "hello", + "model": "provider/model", + "conversation": "conv-1", + "store": True, + "stream": True, +} + + +def _make_params(**overrides: object) -> ResponsesApiParams: + """Build a ``ResponsesApiParams`` with required fields plus overrides.""" + return ResponsesApiParams(**{**_REQUIRED_PARAMS, **overrides}) + + +class TestModelSettingsFromResponsesParams: + """Tests for _model_settings_from_responses_params field mapping.""" + + def test_store_maps_to_openai_store(self) -> None: + """Test that store maps to openai_store.""" + params = _make_params(store=True) + settings = _model_settings_from_responses_params(params) + assert "openai_store" in settings + assert settings["openai_store"] is True + + def test_max_output_tokens_maps_to_max_tokens(self) -> None: + """Test that max_output_tokens maps to max_tokens.""" + params = _make_params(max_output_tokens=512) + settings = _model_settings_from_responses_params(params) + assert "max_tokens" in settings + assert settings["max_tokens"] == 512 + + def test_temperature(self) -> None: + """Test that temperature is passed through.""" + params = _make_params(temperature=0.7) + settings = _model_settings_from_responses_params(params) + assert "temperature" in settings + assert settings["temperature"] == 0.7 + + def test_parallel_tool_calls(self) -> None: + """Test that parallel_tool_calls is passed through.""" + params = _make_params(parallel_tool_calls=True) + settings = _model_settings_from_responses_params(params) + assert "parallel_tool_calls" in settings + assert settings["parallel_tool_calls"] is True + + def test_extra_headers(self) -> None: + """Test that extra_headers is converted to a dict.""" + params = _make_params(extra_headers={"X-Custom": "value"}) + settings = _model_settings_from_responses_params(params) + assert "extra_headers" in settings + assert settings["extra_headers"] == {"X-Custom": "value"} + + def test_previous_response_id_maps_to_openai_previous_response_id(self) -> None: + """Test that previous_response_id maps to openai_previous_response_id.""" + params = _make_params(previous_response_id="resp-42") + settings = _model_settings_from_responses_params(params) + assert "openai_previous_response_id" in settings + assert settings["openai_previous_response_id"] == "resp-42" + + def test_extra_body_fields(self) -> None: + """Test that fields in _LLS_RESPONSES_EXTRA_FIELDS land in extra_body.""" + params = _make_params( + max_infer_iters=5, + max_tool_calls=10, + tools=[{"type": "function", "name": "test-function", "parameters": {}}], + ) + settings = _model_settings_from_responses_params(params) + + assert "extra_body" in settings + assert isinstance(settings["extra_body"], dict) + assert settings["extra_body"]["max_infer_iters"] == 5 + assert settings["extra_body"]["max_tool_calls"] == 10 + assert settings["extra_body"]["conversation"] == "conv-1" + assert settings["extra_body"]["tools"] == [ + {"type": "function", "name": "test-function", "parameters": {}} + ] + + def test_none_fields_excluded(self) -> None: + """Test that None optional fields do not appear in the result.""" + params = _make_params() + settings = _model_settings_from_responses_params(params) + assert "max_tokens" not in settings + assert "temperature" not in settings + assert "parallel_tool_calls" not in settings + assert "extra_headers" not in settings + assert "openai_previous_response_id" not in settings + + +class TestFromLlamaStackClient: + """Tests for LlamaStackResponsesModel.from_llama_stack_client factory.""" + + def test_with_responses_params(self, mocker: MockerFixture) -> None: + """Test that responses_params is converted and forwarded.""" + mock_provider = mocker.Mock() + mocker.patch( + "pydantic_ai_lightspeed.llamastack._model.LlamaStackProvider" + ".from_llama_stack_client", + return_value=mock_provider, + ) + mock_init = mocker.patch.object( + LlamaStackResponsesModel, "__init__", return_value=None + ) + + params = _make_params(temperature=0.5) + client = mocker.Mock() + result = LlamaStackResponsesModel.from_llama_stack_client( + "test-model", client, responses_params=params + ) + assert isinstance(result, LlamaStackResponsesModel) + args, kwargs = mock_init.call_args + assert kwargs["settings"]["temperature"] == 0.5 + assert kwargs["provider"] is mock_provider + assert kwargs["profile"] is None + assert args[0] == "test-model" + + def test_with_model_settings(self, mocker: MockerFixture) -> None: + """Test that model_settings is forwarded directly.""" + mock_provider = mocker.Mock() + mocker.patch( + "pydantic_ai_lightspeed.llamastack._model.LlamaStackProvider" + ".from_llama_stack_client", + return_value=mock_provider, + ) + mock_init = mocker.patch.object( + LlamaStackResponsesModel, "__init__", return_value=None + ) + + settings: ModelSettings = {"temperature": 0.9} + client = mocker.Mock() + result = LlamaStackResponsesModel.from_llama_stack_client( + "test-model", client, model_settings=settings + ) + + assert isinstance(result, LlamaStackResponsesModel) + args, kwargs = mock_init.call_args + assert kwargs["settings"] is settings + assert kwargs["provider"] is mock_provider + assert kwargs["profile"] is None + assert args[0] == "test-model" + + def test_with_neither(self, mocker: MockerFixture) -> None: + """Test that settings is None when neither param is provided.""" + mock_provider = mocker.Mock() + mocker.patch( + "pydantic_ai_lightspeed.llamastack._model.LlamaStackProvider" + ".from_llama_stack_client", + return_value=mock_provider, + ) + mock_init = mocker.patch.object( + LlamaStackResponsesModel, "__init__", return_value=None + ) + + client = mocker.Mock() + result = LlamaStackResponsesModel.from_llama_stack_client("test-model", client) + + assert isinstance(result, LlamaStackResponsesModel) + args, kwargs = mock_init.call_args + assert kwargs["settings"] is None + assert kwargs["provider"] is mock_provider + assert kwargs["profile"] is None + assert args[0] == "test-model" + + def test_both_raises_value_error(self, mocker: MockerFixture) -> None: + """Test that providing both raises ValueError.""" + mocker.patch( + "pydantic_ai_lightspeed.llamastack._model.LlamaStackProvider" + ".from_llama_stack_client", + return_value=mocker.Mock(), + ) + + params = _make_params() + settings: ModelSettings = {"temperature": 0.5} + client = mocker.Mock() + + with pytest.raises(ValueError, match="ResponsesApiParams or ModelSetting"): + LlamaStackResponsesModel.from_llama_stack_client( + "test-model", + client, + responses_params=params, + model_settings=settings, + ) + + +class TestPrepareConversationContinuation: + """Tests for LlamaStackResponsesModel._prepare_conversation_continuation.""" + + @pytest.fixture(name="model") + def model_fixture(self, mocker: MockerFixture) -> LlamaStackResponsesModel: + """Create a LlamaStackResponsesModel with mocked __init__.""" + mocker.patch.object(LlamaStackResponsesModel, "__init__", return_value=None) + return LlamaStackResponsesModel("test-model") + + def test_none_settings_returns_unchanged( + self, model: LlamaStackResponsesModel, mocker: MockerFixture + ) -> None: + """Test that None model_settings returns messages and settings unchanged.""" + messages = [mocker.Mock()] + result_msgs, result_settings = model._prepare_conversation_continuation( + messages, None + ) + assert result_msgs is messages + assert result_settings is None + + def test_empty_settings_returns_unchanged( + self, model: LlamaStackResponsesModel + ) -> None: + """Test that empty dict model_settings returns unchanged.""" + messages: list = [] + settings: ModelSettings = {} + result_msgs, result_settings = model._prepare_conversation_continuation( + messages, settings + ) + assert result_msgs is messages + assert result_settings is settings + + def test_no_extra_body_returns_unchanged( + self, model: LlamaStackResponsesModel, mocker: MockerFixture + ) -> None: + """Test that settings without extra_body returns unchanged.""" + messages = [mocker.Mock()] + settings: ModelSettings = {"temperature": 0.5} + result_msgs, result_settings = model._prepare_conversation_continuation( + messages, settings + ) + assert result_msgs is messages + assert result_settings is settings + + def test_extra_body_without_conversation_returns_unchanged( + self, model: LlamaStackResponsesModel, mocker: MockerFixture + ) -> None: + """Test that extra_body without 'conversation' key returns unchanged.""" + messages = [mocker.Mock()] + settings: ModelSettings = {"extra_body": {"max_infer_iters": 5}} + result_msgs, result_settings = model._prepare_conversation_continuation( + messages, settings + ) + assert result_msgs is messages + assert result_settings is settings + + def test_no_model_response_returns_unchanged( + self, model: LlamaStackResponsesModel, mocker: MockerFixture + ) -> None: + """Test that messages without ModelResponse returns unchanged.""" + messages = [mocker.Mock(), mocker.Mock()] + settings: ModelSettings = {"extra_body": {"conversation": "conv-1"}} + result_msgs, result_settings = model._prepare_conversation_continuation( + messages, settings + ) + assert result_msgs is messages + assert result_settings is settings + + def test_model_response_without_provider_id_returns_unchanged( + self, model: LlamaStackResponsesModel + ) -> None: + """Test that ModelResponse without provider_response_id is ignored.""" + response_msg = ModelResponse(parts=[], provider_response_id=None) + messages: list[ModelMessage] = [response_msg] + settings: ModelSettings = {"extra_body": {"conversation": "conv-1"}} + result_msgs, result_settings = model._prepare_conversation_continuation( + messages, settings + ) + assert result_msgs is messages + assert result_settings is settings + + def test_trims_messages_and_strips_previous_response_id( + self, model: LlamaStackResponsesModel, mocker: MockerFixture + ) -> None: + """Test that messages are trimmed and previous_response_id is removed.""" + msg_before = mocker.Mock() + response_msg = ModelResponse(parts=[], provider_response_id="resp-1") + msg_after = mocker.Mock() + messages = [msg_before, response_msg, msg_after] + settings: OpenAIResponsesModelSettings = { + "extra_body": {"conversation": "conv-1"}, + "openai_previous_response_id": "resp-1", + } + result_msgs, result_settings = model._prepare_conversation_continuation( + messages, settings + ) + assert result_msgs == [msg_after] + assert result_settings is not None + assert "openai_previous_response_id" not in result_settings + assert "extra_body" in result_settings + assert result_settings["extra_body"] == {"conversation": "conv-1"} + + def test_trims_without_previous_response_id_in_settings( + self, model: LlamaStackResponsesModel, mocker: MockerFixture + ) -> None: + """Test trimming works when settings lacks previous_response_id.""" + response_msg = ModelResponse(parts=[], provider_response_id="resp-1") + msg_after = mocker.Mock() + messages = [response_msg, msg_after] + settings: ModelSettings = {"extra_body": {"conversation": "conv-1"}} + result_msgs, result_settings = model._prepare_conversation_continuation( + messages, settings + ) + assert result_msgs == [msg_after] + assert result_settings is not None + assert "openai_previous_response_id" not in result_settings + + def test_uses_last_model_response_when_multiple( + self, model: LlamaStackResponsesModel, mocker: MockerFixture + ) -> None: + """Test that the last ModelResponse with provider_response_id is used.""" + msg1 = mocker.Mock() + resp1 = ModelResponse(parts=[], provider_response_id="resp-1") + msg2 = mocker.Mock() + resp2 = ModelResponse(parts=[], provider_response_id="resp-2") + msg3 = mocker.Mock() + messages = [msg1, resp1, msg2, resp2, msg3] + settings: OpenAIResponsesModelSettings = { + "extra_body": {"conversation": "conv-1"}, + "openai_previous_response_id": "resp-2", + } + result_msgs, result_settings = model._prepare_conversation_continuation( + messages, settings + ) + assert result_msgs == [msg3] + assert result_settings is not None + assert "openai_previous_response_id" not in result_settings + + def test_only_skip_model_response_with_provider_response_id( + self, model: LlamaStackResponsesModel, mocker: MockerFixture + ) -> None: + """Test that the last ModelResponse with provider_response_id is used.""" + msg1 = mocker.Mock() + resp1 = ModelResponse(parts=[], provider_response_id="resp-1") + msg2 = mocker.Mock() + resp2 = ModelResponse(parts=[]) + msg3 = mocker.Mock() + messages = [msg1, resp1, msg2, resp2, msg3] + settings: OpenAIResponsesModelSettings = { + "extra_body": {"conversation": "conv-1"}, + "openai_previous_response_id": "resp-2", + } + result_msgs, result_settings = model._prepare_conversation_continuation( + messages, settings + ) + assert result_msgs == [msg2, resp2, msg3] + assert result_settings is not None + assert "openai_previous_response_id" not in result_settings + + def test_does_not_mutate_original_settings( + self, model: LlamaStackResponsesModel, mocker: MockerFixture + ) -> None: + """Test that the original settings dict is not modified.""" + response_msg = ModelResponse(parts=[], provider_response_id="resp-1") + messages = [response_msg, mocker.Mock()] + settings: OpenAIResponsesModelSettings = { + "extra_body": {"conversation": "conv-1"}, + "openai_previous_response_id": "resp-1", + } + model._prepare_conversation_continuation(messages, settings) + assert "openai_previous_response_id" in settings + + +class TestRequest: + """Tests for LlamaStackResponsesModel.request.""" + + @pytest.mark.asyncio + async def test_calls_prepare_and_delegates_to_super( + self, mocker: MockerFixture + ) -> None: + """Test that request calls _prepare_conversation_continuation and delegates.""" + mocker.patch.object(LlamaStackResponsesModel, "__init__", return_value=None) + model = LlamaStackResponsesModel("test-model") + + original_msgs = [mocker.Mock()] + original_settings: OpenAIResponsesModelSettings = { + "temperature": 0.3, + "openai_previous_response_id": "resp-1", + } + prepared_msgs = [mocker.Mock()] + prepared_settings: OpenAIResponsesModelSettings = {"temperature": 0.3} + + mock_prepare = mocker.patch.object( + model, + "_prepare_conversation_continuation", + return_value=(prepared_msgs, prepared_settings), + ) + + expected_result = mocker.Mock() + mock_super_request = mocker.patch.object( + OpenAIResponsesModel, + "request", + new_callable=mocker.AsyncMock, + return_value=expected_result, + ) + + mock_params = mocker.Mock() + result = await model.request(original_msgs, original_settings, mock_params) + + mock_prepare.assert_called_once_with(original_msgs, original_settings) + mock_super_request.assert_called_once() + args, _ = mock_super_request.call_args + assert args[0] is prepared_msgs + assert args[1] is prepared_settings + assert args[2] is mock_params + assert result is expected_result + + +def _make_mock_response_stream(mocker: MockerFixture, events: list): + """Create a mock AsyncStream that yields given events and supports async with.""" + mock = mocker.Mock() + mock.__aiter__ = lambda _: _async_iter(events) + mock.__aenter__ = mocker.AsyncMock(return_value=mock) + mock.__aexit__ = mocker.AsyncMock(return_value=False) + mock.close = mocker.AsyncMock() + return mock + + +def _make_response_created_event() -> responses.ResponseCreatedEvent: + """Build a ResponseCreatedEvent for stream tests.""" + return responses.ResponseCreatedEvent( + response=responses.Response( + id="resp-1", + created_at=0, + model="test", + object="response", + output=[], + parallel_tool_calls=False, + tool_choice="auto", + tools=[], + status="completed", + ), + sequence_number=0, + type="response.created", + ) + + +class TestRequestStream: + """Tests for LlamaStackResponsesModel.request_stream.""" + + @pytest.fixture(name="model") + def model_fixture(self, mocker: MockerFixture) -> LlamaStackResponsesModel: + """Create a LlamaStackResponsesModel with stream-related attributes set.""" + mocker.patch.object(LlamaStackResponsesModel, "__init__", return_value=None) + model = LlamaStackResponsesModel("test-model") + mocker.patch.object( + type(model), + "model_name", + new_callable=mocker.PropertyMock, + return_value="test-model", + ) + model._provider = mocker.Mock() + model._provider.name = "test-provider" + model._provider.base_url = "http://localhost" + mocker.patch( + "pydantic_ai_lightspeed.llamastack._model.check_allow_model_requests" + ) + return model + + @pytest.mark.asyncio + async def test_calls_prepare_continuation( + self, model: LlamaStackResponsesModel, mocker: MockerFixture + ) -> None: + """Test that request_stream calls _prepare_conversation_continuation.""" + original_msgs = [mocker.Mock()] + original_settings: ModelSettings = {"temperature": 0.5} + + mock_prepare = mocker.patch.object( + model, + "_prepare_conversation_continuation", + return_value=(original_msgs, original_settings), + ) + + mock_stream = _make_mock_response_stream( + mocker, [_make_response_created_event()] + ) + model._responses_create = mocker.AsyncMock(return_value=mock_stream) + async with model.request_stream( + original_msgs, original_settings, mocker.Mock() + ): + pass + + mock_prepare.assert_called_once_with(original_msgs, original_settings) + + @pytest.mark.asyncio + async def test_empty_stream_raises( + self, model: LlamaStackResponsesModel, mocker: MockerFixture + ) -> None: + """Test that an empty stream raises UnexpectedModelBehavior.""" + mocker.patch.object( + model, + "_prepare_conversation_continuation", + return_value=([mocker.Mock()], {}), + ) + + mock_stream = _make_mock_response_stream(mocker, []) + model._responses_create = mocker.AsyncMock(return_value=mock_stream) + with pytest.raises(UnexpectedModelBehavior, match="ended without content"): + async with model.request_stream([mocker.Mock()], {}, mocker.Mock()): + pass + + @pytest.mark.asyncio + async def test_wrong_first_event_raises( + self, model: LlamaStackResponsesModel, mocker: MockerFixture + ) -> None: + """Test that a non-ResponseCreatedEvent first event raises.""" + mocker.patch.object( + model, + "_prepare_conversation_continuation", + return_value=([mocker.Mock()], {}), + ) + + wrong_event = responses.ResponseCompletedEvent( + response=responses.Response( + id="resp-1", + created_at=0, + model="test", + object="response", + output=[], + parallel_tool_calls=False, + tool_choice="auto", + tools=[], + status="completed", + ), + sequence_number=0, + type="response.completed", + ) + mock_stream = _make_mock_response_stream(mocker, [wrong_event]) + model._responses_create = mocker.AsyncMock(return_value=mock_stream) + with pytest.raises( + UnexpectedModelBehavior, match="Expected ResponseCreatedEvent" + ): + async with model.request_stream([mocker.Mock()], {}, mocker.Mock()): + pass + + @pytest.mark.asyncio + async def test_happy_path_yields_streamed_response( + self, model: LlamaStackResponsesModel, mocker: MockerFixture + ) -> None: + """Test that a valid stream yields an OpenAIResponsesStreamedResponse.""" + mocker.patch.object( + model, + "_prepare_conversation_continuation", + return_value=([mocker.Mock()], {}), + ) + + mock_stream = _make_mock_response_stream( + mocker, [_make_response_created_event()] + ) + model._responses_create = mocker.AsyncMock(return_value=mock_stream) + async with model.request_stream([mocker.Mock()], {}, mocker.Mock()) as streamed: + assert isinstance(streamed, OpenAIResponsesStreamedResponse) + + +class TestLlsResponsesExtraFields: + """Tests for the _LLS_RESPONSES_EXTRA_FIELDS constant.""" + + def test_is_frozenset(self) -> None: + """Test that _LLS_RESPONSES_EXTRA_FIELDS is a frozenset.""" + assert isinstance(_LLS_RESPONSES_EXTRA_FIELDS, frozenset) + + def test_contains_expected_fields(self) -> None: + """Test that key fields are present.""" + expected = { + "conversation", + "max_infer_iters", + "tools", + "tool_choice", + "include", + "text", + "reasoning", + "prompt", + "metadata", + "max_tool_calls", + "safety_identifier", + } + assert expected == _LLS_RESPONSES_EXTRA_FIELDS + + +def _make_delta( + item_id: str, delta: str, output_index: int = 0, seq: int = 1 +) -> responses.ResponseFunctionCallArgumentsDeltaEvent: + """Build a ResponseFunctionCallArgumentsDeltaEvent.""" + return responses.ResponseFunctionCallArgumentsDeltaEvent( + delta=delta, + item_id=item_id, + output_index=output_index, + sequence_number=seq, + type="response.function_call_arguments.delta", + ) + + +def _make_function_tool_call_added( + item_id: str, +) -> responses.ResponseOutputItemAddedEvent: + """Build a ResponseOutputItemAddedEvent for a ResponseFunctionToolCall.""" + item = responses.ResponseFunctionToolCall( + id=item_id, + call_id="call-1", + name="my_tool", + arguments="", + type="function_call", + ) + return responses.ResponseOutputItemAddedEvent( + item=item, + output_index=0, + sequence_number=0, + type="response.output_item.added", + ) + + +def _make_mcp_call_added(item_id: str) -> responses.ResponseOutputItemAddedEvent: + """Build a ResponseOutputItemAddedEvent for an McpCall.""" + item = responses.response_output_item.McpCall( + id=item_id, + arguments="", + name="mcp_tool", + server_label="server", + type="mcp_call", + ) + return responses.ResponseOutputItemAddedEvent( + item=item, + output_index=0, + sequence_number=0, + type="response.output_item.added", + ) + + +async def _async_iter(events): + """Turn a list of events into an async iterator.""" + for e in events: + yield e + + +class TestFilteredResponseStream: + """Tests for _FilteredResponseStream event reordering.""" + + @pytest.mark.asyncio + async def test_passthrough_normal_events(self, mocker: MockerFixture) -> None: + """Test that non-tool events pass through unchanged.""" + event = responses.ResponseCompletedEvent( + response=responses.Response( + id="resp-1", + created_at=0, + model="test", + object="response", + output=[], + parallel_tool_calls=False, + tool_choice="auto", + tools=[], + status="completed", + ), + sequence_number=0, + type="response.completed", + ) + source = mocker.Mock() + source.__aiter__ = lambda _: _async_iter([event]) + + stream = _FilteredResponseStream(source) + result = [e async for e in stream] + + assert result == [event] + + @pytest.mark.asyncio + async def test_buffers_early_argument_delta(self, mocker: MockerFixture) -> None: + """Test that a delta before its announcement is buffered and not yielded.""" + delta = _make_delta("item-1", '{"key":') + source = mocker.Mock() + source.__aiter__ = lambda _: _async_iter([delta]) + + stream = _FilteredResponseStream(source) + result = [e async for e in stream] + + assert result == [] + + @pytest.mark.asyncio + async def test_replays_buffered_deltas_for_function_tool_call( + self, mocker: MockerFixture + ) -> None: + """Test that buffered deltas replay after a FunctionToolCall announcement.""" + delta1 = _make_delta("item-1", '{"key":', seq=1) + delta2 = _make_delta("item-1", '"val"}', seq=2) + announcement = _make_function_tool_call_added("item-1") + + source = mocker.Mock() + source.__aiter__ = lambda _: _async_iter([delta1, delta2, announcement]) + + stream = _FilteredResponseStream(source) + result = [e async for e in stream] + + assert result[0] is announcement + assert result[1] is delta1 + assert result[2] is delta2 + + @pytest.mark.asyncio + async def test_replays_mcp_buffered_deltas_with_suffixed_id( + self, mocker: MockerFixture + ) -> None: + """Test that MCP deltas are combined with -call suffix on item_id.""" + delta1 = _make_delta("mcp-1", '{"arg":', seq=1) + delta2 = _make_delta("mcp-1", '"v"}', seq=2) + announcement = _make_mcp_call_added("mcp-1") + + source = mocker.Mock() + source.__aiter__ = lambda _: _async_iter([delta1, delta2, announcement]) + + stream = _FilteredResponseStream(source) + result = [e async for e in stream] + + assert result[0] is announcement + replayed = result[1] + assert isinstance(replayed, responses.ResponseFunctionCallArgumentsDeltaEvent) + assert replayed.item_id == "mcp-1-call" + assert replayed.delta == '{"arg":"v"}}' + + @pytest.mark.asyncio + async def test_no_buffered_deltas(self, mocker: MockerFixture) -> None: + """Test that an announcement with no prior deltas yields only itself.""" + announcement = _make_function_tool_call_added("item-1") + source = mocker.Mock() + source.__aiter__ = lambda _: _async_iter([announcement]) + + stream = _FilteredResponseStream(source) + result = [e async for e in stream] + + assert result == [announcement] + + @pytest.mark.asyncio + async def test_delta_after_announcement_passes_through( + self, mocker: MockerFixture + ) -> None: + """Test that a delta arriving after its announcement passes through.""" + announcement = _make_function_tool_call_added("item-1") + delta = _make_delta("item-1", '{"key":"val"}') + + source = mocker.Mock() + source.__aiter__ = lambda _: _async_iter([announcement, delta]) + + stream = _FilteredResponseStream(source) + result = [e async for e in stream] + + assert result == [announcement, delta] diff --git a/tests/unit/pydantic_ai_lightspeed/llamastack/test_provider.py b/tests/unit/pydantic_ai_lightspeed/llamastack/test_provider.py index 8773fef3d..dd49739af 100644 --- a/tests/unit/pydantic_ai_lightspeed/llamastack/test_provider.py +++ b/tests/unit/pydantic_ai_lightspeed/llamastack/test_provider.py @@ -4,6 +4,8 @@ import httpx import pytest +from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient +from llama_stack_client import AsyncLlamaStackClient from openai import AsyncOpenAI from pytest_mock import MockerFixture @@ -143,6 +145,97 @@ def test_library_client_and_http_client_raises(self, mocker: MockerFixture) -> N ) +class TestFromLlamaStackClient: + """Tests for LlamaStackProvider.from_llama_stack_client.""" + + def test_library_client_dispatches_to_library_mode( + self, mocker: MockerFixture + ) -> None: + """Test that an AsyncLlamaStackAsLibraryClient creates a library-mode provider.""" + mock_lib_client = mocker.Mock(spec=AsyncLlamaStackAsLibraryClient) + mock_lib_client.provider_data = None + + provider = LlamaStackProvider.from_llama_stack_client(mock_lib_client) + + assert provider._library_client is mock_lib_client + assert "llama-stack-library" in provider.base_url + + def test_server_client_extracts_base_url_with_v1( + self, mocker: MockerFixture + ) -> None: + """Test that a server client whose base_url already ends with /v1 is used as-is.""" + mock_client = mocker.Mock(spec=AsyncLlamaStackClient) + mock_client.base_url = "http://my-server:8321/v1" + mock_client.api_key = "test-key" + mock_client._client = mocker.Mock(spec=httpx.AsyncClient) + + provider = LlamaStackProvider.from_llama_stack_client(mock_client) + + assert "my-server:8321/v1" in provider.base_url + assert provider.base_url.count("/v1") == 1 + + def test_server_client_appends_v1_when_missing(self, mocker: MockerFixture) -> None: + """Test that /v1 is appended when the server client's base_url lacks it.""" + mock_client = mocker.Mock(spec=AsyncLlamaStackClient) + mock_client.base_url = "http://my-server:8321" + mock_client.api_key = "test-key" + mock_client._client = mocker.Mock(spec=httpx.AsyncClient) + + provider = LlamaStackProvider.from_llama_stack_client(mock_client) + + assert provider.base_url.rstrip("/").endswith("/v1") + + def test_server_client_strips_trailing_slash_before_appending_v1( + self, mocker: MockerFixture + ) -> None: + """Test that a trailing slash is stripped before appending /v1.""" + mock_client = mocker.Mock(spec=AsyncLlamaStackClient) + mock_client.base_url = "http://my-server:8321/" + mock_client.api_key = "test-key" + mock_client._client = mocker.Mock(spec=httpx.AsyncClient) + + provider = LlamaStackProvider.from_llama_stack_client(mock_client) + + assert "//v1" not in provider.base_url + assert provider.base_url.rstrip("/").endswith("/v1") + + def test_server_client_uses_provided_api_key(self, mocker: MockerFixture) -> None: + """Test that the server client's api_key is forwarded to the provider.""" + mock_client = mocker.Mock(spec=AsyncLlamaStackClient) + mock_client.base_url = "http://my-server:8321/v1" + mock_client.api_key = "my-secret" + mock_client._client = mocker.Mock(spec=httpx.AsyncClient) + + provider = LlamaStackProvider.from_llama_stack_client(mock_client) + + assert provider.client.api_key == "my-secret" + + def test_server_client_defaults_api_key_when_none( + self, mocker: MockerFixture + ) -> None: + """Test that a None api_key falls back to 'not-needed'.""" + mock_client = mocker.Mock(spec=AsyncLlamaStackClient) + mock_client.base_url = "http://my-server:8321/v1" + mock_client.api_key = None + mock_client._client = mocker.Mock(spec=httpx.AsyncClient) + + provider = LlamaStackProvider.from_llama_stack_client(mock_client) + + assert provider.client.api_key == "not-needed" + + def test_server_client_passes_http_client(self, mocker: MockerFixture) -> None: + """Test that the server client's internal httpx client is reused.""" + mock_client = mocker.Mock(spec=AsyncLlamaStackClient) + mock_client.base_url = "http://my-server:8321/v1" + mock_client.api_key = "test-key" + inner_http = mocker.Mock(spec=httpx.AsyncClient) + mock_client._client = inner_http + + provider = LlamaStackProvider.from_llama_stack_client(mock_client) + + assert provider._client._client is inner_http + + class TestSetHttpClient: # pylint: disable=too-few-public-methods """Tests for LlamaStackProvider._set_http_client.""" diff --git a/tests/unit/utils/test_pydantic_ai.py b/tests/unit/utils/test_pydantic_ai.py index 3948d7c50..b4089f9e3 100644 --- a/tests/unit/utils/test_pydantic_ai.py +++ b/tests/unit/utils/test_pydantic_ai.py @@ -3,7 +3,6 @@ # pylint: disable=protected-access import httpx -import pytest from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient from llama_stack_client import AsyncLlamaStackClient from pydantic_ai_skills import SkillsCapability @@ -12,197 +11,13 @@ from models.common.responses.responses_api_params import ResponsesApiParams from models.config import SkillsConfiguration from utils.pydantic_ai import ( - _LLS_RESPONSES_EXTRA_FIELDS, _agent_capabilities, - _model_settings_from_responses_params, _skills_capability, build_agent, get_agent_capability_tools, - llama_stack_provider_from_client, ) -class TestLlamaStackProviderFromClient: - """Tests for llama_stack_provider_from_client factory.""" - - def test_library_client(self, mocker: MockerFixture) -> None: - """Test that a library client creates a provider with library_client kwarg.""" - mock_lib_client = mocker.Mock(spec=AsyncLlamaStackAsLibraryClient) - mock_lib_client.provider_data = None - - provider = llama_stack_provider_from_client(mock_lib_client) - - assert provider._library_client is mock_lib_client - - def test_remote_client_with_api_key(self, mocker: MockerFixture) -> None: - """Test that a remote client uses its api_key.""" - mock_client = mocker.Mock() - mock_client.base_url = "http://my-server:8321" - mock_client.api_key = "my-secret" - mock_client._client = mocker.Mock(spec=httpx.AsyncClient) - - provider = llama_stack_provider_from_client(mock_client) - - assert provider.client.api_key == "my-secret" - assert "my-server:8321" in provider.base_url - - def test_remote_client_without_api_key(self, mocker: MockerFixture) -> None: - """Test that a remote client without api_key defaults to 'not-needed'.""" - mock_client = mocker.Mock() - mock_client.base_url = "http://my-server:8321" - mock_client.api_key = None - mock_client._client = mocker.Mock(spec=httpx.AsyncClient) - - provider = llama_stack_provider_from_client(mock_client) - - assert provider.client.api_key == "not-needed" - - def test_remote_client_passes_http_client(self, mocker: MockerFixture) -> None: - """Test that a remote client's internal http_client is forwarded.""" - mock_http_client = mocker.Mock(spec=httpx.AsyncClient) - mock_client = mocker.Mock() - mock_client.base_url = "http://my-server:8321" - mock_client.api_key = "key" - mock_client._client = mock_http_client - - provider = llama_stack_provider_from_client(mock_client) - - assert provider._client._client is mock_http_client - - -class TestModelSettingsFromResponsesParams: - """Tests for _model_settings_from_responses_params mapping.""" - - @pytest.fixture(name="minimal_params") - def minimal_params_fixture(self, mocker: MockerFixture) -> object: - """Create minimal ResponsesApiParams mock with required fields only.""" - params = mocker.Mock() - params.model_dump.return_value = {"model": "test/model", "input": "hello"} - params.max_output_tokens = None - params.temperature = None - params.parallel_tool_calls = None - params.extra_headers = None - params.store = False - params.tools = None - params.previous_response_id = None - return params - - def test_minimal_params_returns_store_false(self, minimal_params: object) -> None: - """Test that minimal params produce settings with openai_store=False.""" - settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type] - assert settings["openai_store"] is False - - def test_minimal_params_no_extra_body(self, minimal_params: object) -> None: - """Test that minimal params without extra fields omit extra_body.""" - settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type] - assert "extra_body" not in settings - - def test_max_output_tokens_mapped(self, minimal_params: object) -> None: - """Test that max_output_tokens is mapped to max_tokens.""" - minimal_params.max_output_tokens = 1024 # type: ignore[attr-defined] - settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type] - assert settings["max_tokens"] == 1024 - - def test_temperature_mapped(self, minimal_params: object) -> None: - """Test that temperature is passed through.""" - minimal_params.temperature = 0.7 # type: ignore[attr-defined] - settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type] - assert settings["temperature"] == 0.7 - - def test_parallel_tool_calls_mapped(self, minimal_params: object) -> None: - """Test that parallel_tool_calls is passed through.""" - minimal_params.parallel_tool_calls = True # type: ignore[attr-defined] - settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type] - assert settings["parallel_tool_calls"] is True - - def test_extra_headers_mapped(self, minimal_params: object) -> None: - """Test that extra_headers are converted to a dict.""" - minimal_params.extra_headers = {"x-custom": "value"} # type: ignore[attr-defined] - settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type] - assert settings["extra_headers"] == {"x-custom": "value"} - - def test_store_true_mapped(self, minimal_params: object) -> None: - """Test that store=True is passed as openai_store.""" - minimal_params.store = True # type: ignore[attr-defined] - settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type] - assert settings["openai_store"] is True - - def test_previous_response_id_mapped(self, minimal_params: object) -> None: - """Test that previous_response_id is passed as openai_previous_response_id.""" - minimal_params.previous_response_id = "resp_abc123" # type: ignore[attr-defined] - settings = _model_settings_from_responses_params(minimal_params) # type: ignore[arg-type] - assert settings["openai_previous_response_id"] == "resp_abc123" - - def test_extra_body_from_lls_fields(self, mocker: MockerFixture) -> None: - """Test that LLS-specific fields are placed into extra_body.""" - params = mocker.Mock() - params.model_dump.return_value = { - "model": "test/model", - "conversation": "conv-123", - "max_infer_iters": 5, - "tool_choice": "auto", - } - params.max_output_tokens = None - params.temperature = None - params.parallel_tool_calls = None - params.extra_headers = None - params.store = False - params.previous_response_id = None - params.tools = [{"type": "function"}] - - settings = _model_settings_from_responses_params(params) - - assert "extra_body" in settings - assert settings["extra_body"]["conversation"] == "conv-123" - assert settings["extra_body"]["max_infer_iters"] == 5 - assert settings["extra_body"]["tool_choice"] == "auto" - assert settings["openai_native_tools"] == [{"type": "function"}] - - def test_extra_body_only_includes_known_fields(self, mocker: MockerFixture) -> None: - """Test that extra_body only includes fields in _LLS_RESPONSES_EXTRA_FIELDS.""" - params = mocker.Mock() - params.model_dump.return_value = { - "model": "test/model", - "conversation": "conv-1", - "unknown_field": "should-not-appear", - } - params.max_output_tokens = None - params.temperature = None - params.parallel_tool_calls = None - params.extra_headers = None - params.store = False - params.previous_response_id = None - - settings = _model_settings_from_responses_params(params) - - assert "unknown_field" not in settings.get("extra_body", {}) - assert settings["extra_body"]["conversation"] == "conv-1" - - -class TestLlsResponsesExtraFields: - """Tests for the _LLS_RESPONSES_EXTRA_FIELDS constant.""" - - def test_is_frozenset(self) -> None: - """Test that _LLS_RESPONSES_EXTRA_FIELDS is a frozenset.""" - assert isinstance(_LLS_RESPONSES_EXTRA_FIELDS, frozenset) - - def test_contains_expected_fields(self) -> None: - """Test that key fields are present.""" - expected = { - "conversation", - "max_infer_iters", - "tool_choice", - "include", - "text", - "reasoning", - "prompt", - "metadata", - "max_tool_calls", - "safety_identifier", - } - assert expected == _LLS_RESPONSES_EXTRA_FIELDS - - class TestSkillsCapability: """Tests for _skills_capability."""