Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
QuestionValidityConfig,
)
from pydantic_ai_lightspeed.llamastack import LlamaStackResponsesModel
from utils.pydantic_ai import llama_stack_provider_from_client

logger = get_logger(__name__)

Expand Down Expand Up @@ -55,21 +54,6 @@ def _extract_message_str_from_user_content(user_content: Sequence[UserContent])
return "\n".join(str_arr)


def _create_model_from_llama_stack_client(model_id: str) -> LlamaStackResponsesModel:
    """Build a ``LlamaStackResponsesModel`` on top of the held Llama Stack client.

    The client is obtained from ``AsyncLlamaStackClientHolder``, wrapped in a
    provider via ``llama_stack_provider_from_client``, and the model is created
    with ``openai_store=False`` settings.

    Parameters:
        model_id: The model identifier to use for the responses model.

    Returns:
        A configured LlamaStackResponsesModel instance.
    """
    held_client = AsyncLlamaStackClientHolder().get_client()
    return LlamaStackResponsesModel(
        model_id,
        provider=llama_stack_provider_from_client(held_client),
        settings=OpenAIResponsesModelSettings(openai_store=False),
    )


@dataclass
class QuestionValidity(AbstractCapability[None]):
"""Block or modify user input based on a guardrail check.
Expand All @@ -91,7 +75,13 @@ class QuestionValidity(AbstractCapability[None]):

def __post_init__(self) -> None:
    """Initialize the model instance from the configured model ID.

    Fetches the client from ``AsyncLlamaStackClientHolder`` and builds the
    model through ``LlamaStackResponsesModel.from_llama_stack_client`` using
    ``self.config.model_id`` and ``openai_store=False`` settings. The result
    is stored on ``self._model``.
    """
    llama_stack_client = AsyncLlamaStackClientHolder().get_client()

    # Build the model exactly once. A preceding assignment to ``self._model``
    # would be overwritten here, so its construction work would be wasted.
    self._model = LlamaStackResponsesModel.from_llama_stack_client(
        self.config.model_id,
        llama_stack_client,
        model_settings=OpenAIResponsesModelSettings(openai_store=False),
    )

def _build_prompt(self, message: str | Sequence[UserContent] | None) -> str:
"""Build the classification prompt from the user message.
Expand Down
98 changes: 97 additions & 1 deletion src/pydantic_ai_lightspeed/llamastack/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
from collections import defaultdict
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from typing import Any, cast
from typing import Any, Final, cast

from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
from llama_stack_client import AsyncLlamaStackClient
from openai import AsyncStream
from openai.types import responses
from pydantic_ai import UnexpectedModelBehavior
Expand All @@ -38,12 +40,56 @@
OpenAIResponsesStreamedResponse,
_map_api_errors,
)
from pydantic_ai.profiles import ModelProfileSpec
from pydantic_ai.settings import ModelSettings

from log import get_logger
from models.common.responses.responses_api_params import ResponsesApiParams
from pydantic_ai_lightspeed.llamastack._provider import LlamaStackProvider

logger = get_logger(__name__)

# Names of ``ResponsesApiParams`` fields that ``_model_settings_from_responses_params``
# copies (when not ``None``) from the dumped payload into the ``extra_body`` entry
# of the model settings, rather than mapping them to a dedicated settings key.
_LLS_RESPONSES_EXTRA_FIELDS: Final[frozenset[str]] = frozenset(
{
"conversation",
"max_infer_iters",
"tools",
"tool_choice",
"include",
"text",
"reasoning",
"prompt",
"metadata",
"max_tool_calls",
"safety_identifier",
}
)


def _model_settings_from_responses_params(
    responses_params: ResponsesApiParams,
) -> OpenAIResponsesModelSettings:
    """Translate ``ResponsesApiParams`` into Pydantic AI OpenAI Responses settings.

    Non-``None`` fields whose names appear in ``_LLS_RESPONSES_EXTRA_FIELDS``
    are grouped under ``extra_body``. ``max_output_tokens``, ``temperature``,
    ``parallel_tool_calls``, ``extra_headers`` and ``previous_response_id``
    are mapped to their dedicated settings keys only when set, while
    ``openai_store`` is always populated from ``store``.

    Args:
        responses_params: The request parameters to convert.

    Returns:
        The resulting settings mapping, cast to ``OpenAIResponsesModelSettings``.
    """
    dumped = responses_params.model_dump(exclude_none=True)
    forwarded = {
        name: value
        for name, value in dumped.items()
        if name in _LLS_RESPONSES_EXTRA_FIELDS
    }

    result: dict[str, Any] = {}
    if forwarded:
        result["extra_body"] = forwarded

    # Scalars that are only emitted when the caller actually set them.
    for key, value in (
        ("max_tokens", responses_params.max_output_tokens),
        ("temperature", responses_params.temperature),
        ("parallel_tool_calls", responses_params.parallel_tool_calls),
    ):
        if value is not None:
            result[key] = value

    headers = responses_params.extra_headers
    if headers:
        result["extra_headers"] = dict(headers)

    # ``store`` is forwarded unconditionally, whatever its value.
    result["openai_store"] = responses_params.store

    previous_id = responses_params.previous_response_id
    if previous_id is not None:
        result["openai_previous_response_id"] = previous_id

    return cast(OpenAIResponsesModelSettings, result)


class _FilteredResponseStream:
"""Wraps an OpenAI AsyncStream to reorder spurious events from Llama Stack.
Expand Down Expand Up @@ -307,3 +353,53 @@ async def request_stream( # pylint: disable=unused-argument
else None
),
)

@staticmethod
def from_llama_stack_client(
model_name: str,
client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient,
*,
responses_params: ResponsesApiParams | None = None,
model_settings: ModelSettings | None = None,
profile: ModelProfileSpec | None = None,
) -> LlamaStackResponsesModel:
"""Create a ``LlamaStackResponsesModel`` from a Llama Stack client.

Mirrors ``OpenAIResponsesModel.__init__`` parameters, but accepts a
Llama Stack client instead of a provider. Exactly one of
``responses_params`` or ``model_settings`` may be provided.

Args:
model_name: The model name/ID to use.
client: Llama Stack client to build the provider from.
responses_params: Optional ``ResponsesApiParams``, converted to
``OpenAIResponsesModelSettings`` internally. Mutually
exclusive with ``model_settings``.
model_settings: Optional raw ``ModelSettings`` passed through
directly. Mutually exclusive with ``responses_params``.
profile: Optional model profile specification.

Raises:
ValueError: If both ``responses_params`` and ``model_settings``
are provided.

Returns:
Configured ``LlamaStackResponsesModel`` instance.
"""
provider = LlamaStackProvider.from_llama_stack_client(client)

if responses_params is not None and model_settings is not None:
raise ValueError(
"You can only pass either ResponsesApiParams or ModelSetting not both."
)

_settings: OpenAIResponsesModelSettings | ModelSettings | None = None

if responses_params is not None:
_settings = _model_settings_from_responses_params(responses_params)
elif model_settings is not None:
_settings = model_settings

return LlamaStackResponsesModel(
model_name, provider=provider, profile=profile, settings=_settings
)
Comment on lines +389 to +405

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🩺 Stability & Availability | 🟠 Major | ⚡ Quick win

Validate exclusivity before constructing the provider.

LlamaStackProvider.from_llama_stack_client(client) runs before the responses_params/model_settings check. On the library-client path, provider construction creates a new httpx.AsyncClient, so the error path allocates resources before immediately raising ValueError.

Suggested fix
-        provider = LlamaStackProvider.from_llama_stack_client(client)
-
         if responses_params is not None and model_settings is not None:
             raise ValueError(
                 "You can only pass either ResponsesApiParams or ModelSetting not both."
             )
         elif responses_params is not None:
             _settings = _model_settings_from_responses_params(responses_params)
         elif model_settings is not None:
             _settings = model_settings
         else:
             _settings = None
+
+        provider = LlamaStackProvider.from_llama_stack_client(client)
 
         return LlamaStackResponsesModel(
             model_name, provider=provider, profile=profile, settings=_settings
         )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
provider = LlamaStackProvider.from_llama_stack_client(client)
if responses_params is not None and model_settings is not None:
raise ValueError(
"You can only pass either ResponsesApiParams or ModelSetting not both."
)
elif responses_params is not None:
_settings = _model_settings_from_responses_params(responses_params)
elif model_settings is not None:
_settings = model_settings
else:
_settings = None
return LlamaStackResponsesModel(
model_name, provider=provider, profile=profile, settings=_settings
)
if responses_params is not None and model_settings is not None:
raise ValueError(
"You can only pass either ResponsesApiParams or ModelSetting not both."
)
elif responses_params is not None:
_settings = _model_settings_from_responses_params(responses_params)
elif model_settings is not None:
_settings = model_settings
else:
_settings = None
provider = LlamaStackProvider.from_llama_stack_client(client)
return LlamaStackResponsesModel(
model_name, provider=provider, profile=profile, settings=_settings
)
🧰 Tools
🪛 GitHub Actions: Type checks / 0_mypy.txt

[error] 398-398: mypy error: Incompatible types in assignment (expression has type "ModelSettings", variable has type "OpenAIResponsesModelSettings"). [assignment]

🪛 GitHub Actions: Type checks / mypy

[error] 398-398: mypy failed: Incompatible types in assignment (expression has type "ModelSettings", variable has type "OpenAIResponsesModelSettings"). [assignment]

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@src/pydantic_ai_lightspeed/llamastack/_model.py` around lines 389 - 404, Move
the mutual-exclusivity validation for responses_params and model_settings ahead
of LlamaStackProvider.from_llama_stack_client(client) in the
LlamaStackResponsesModel construction path, so the ValueError is raised before
any provider/client resources are created. Keep the existing logic in the same
method that builds LlamaStackResponsesModel, but check the two inputs first and
only instantiate the provider after the validation passes.

33 changes: 32 additions & 1 deletion src/pydantic_ai_lightspeed/llamastack/_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from typing import TYPE_CHECKING, Optional

import httpx
from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
from llama_stack_client import AsyncLlamaStackClient
from openai import AsyncOpenAI
from pydantic_ai import ModelProfile
from pydantic_ai.models import create_async_http_client
Expand All @@ -14,7 +16,9 @@
from pydantic_ai_lightspeed.llamastack._transport import LlamaStackLibraryTransport

if TYPE_CHECKING:
from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
from llama_stack.core.library_client import ( # pylint: disable=reimported
AsyncLlamaStackAsLibraryClient,
)

DEFAULT_BASE_URL = "http://localhost:8321/v1"

Expand Down Expand Up @@ -48,6 +52,33 @@ def model_profile(model_name: str) -> Optional[ModelProfile]:
"""Return the model profile for the named model, if available."""
return openai_model_profile(model_name)

@staticmethod
def from_llama_stack_client(
    client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient,
) -> LlamaStackProvider:
    """Build a ``LlamaStackProvider`` that reuses an existing Llama Stack client.

    An ``AsyncLlamaStackAsLibraryClient`` is handed to the provider as its
    ``library_client``. Any other client yields a provider configured with
    the client's API key (``"not-needed"`` when it is unset or empty), its
    base URL normalised to end in ``/v1``, and its underlying HTTP client.

    Args:
        client: A Llama Stack client (server or library variant).

    Returns:
        Configured ``LlamaStackProvider`` instance.
    """
    if not isinstance(client, AsyncLlamaStackAsLibraryClient):
        token = client.api_key or "not-needed"
        root = str(client.base_url).rstrip("/")
        # Append the version segment only when the URL does not already end in it.
        if not root.endswith("/v1"):
            root = f"{root}/v1"
        return LlamaStackProvider(
            base_url=root,
            api_key=token,
            http_client=client._client,  # pylint: disable=protected-access
        )
    return LlamaStackProvider(library_client=client)

def __init__(
self,
*,
Expand Down
71 changes: 4 additions & 67 deletions src/utils/pydantic_ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,17 @@
from __future__ import annotations

import re
from typing import Any, Final, Optional, cast
from typing import Any, Final, Optional

from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient
from llama_stack_client import AsyncLlamaStackClient
from pydantic_ai.agent import Agent
from pydantic_ai.capabilities import AbstractCapability, AgentCapability
from pydantic_ai.models.openai import OpenAIResponsesModelSettings
from pydantic_ai_skills import SkillsCapability

from models.common.responses.responses_api_params import ResponsesApiParams
from models.config import SkillsConfiguration
from pydantic_ai_lightspeed.llamastack import (
LlamaStackProvider,
LlamaStackResponsesModel,
)

Expand All @@ -24,64 +22,6 @@
_BUILTIN_CAPABILITY_SERVER_SOURCE: Final[str] = "builtin"
_CAPABILITY_TOOL_TYPE: Final[str] = "tool"

# Names of ``ResponsesApiParams`` fields that ``_model_settings_from_responses_params``
# copies (when not ``None``) from the dumped payload into the ``extra_body`` entry
# of the model settings. ``tools`` is not listed here; that function maps it to
# ``openai_native_tools`` instead.
_LLS_RESPONSES_EXTRA_FIELDS: Final[frozenset[str]] = frozenset(
{
"conversation",
"max_infer_iters",
"tool_choice",
"include",
"text",
"reasoning",
"prompt",
"metadata",
"max_tool_calls",
"safety_identifier",
}
)


def llama_stack_provider_from_client(
    client: AsyncLlamaStackClient | AsyncLlamaStackAsLibraryClient,
) -> LlamaStackProvider:
    """Construct a Pydantic AI Llama Stack provider backed by the same client as ``/query``.

    Library clients are passed through as ``library_client``. For any other
    client the provider is built from its API key (``"not-needed"`` when unset
    or empty), its base URL normalised to end in ``/v1``, and its underlying
    HTTP client.
    """
    if isinstance(client, AsyncLlamaStackAsLibraryClient):
        return LlamaStackProvider(library_client=client)

    key = client.api_key or "not-needed"
    trimmed = str(client.base_url).rstrip("/")
    # Keep an existing version suffix; otherwise add one.
    versioned = trimmed if trimmed.endswith("/v1") else f"{trimmed}/v1"
    server_kwargs = {
        "base_url": versioned,
        "api_key": key,
        "http_client": client._client,  # pylint: disable=protected-access
    }
    return LlamaStackProvider(**server_kwargs)


def _model_settings_from_responses_params(
    responses_params: ResponsesApiParams,
) -> OpenAIResponsesModelSettings:
    """Translate ``ResponsesApiParams`` into Pydantic AI OpenAI Responses settings.

    Non-``None`` fields whose names appear in ``_LLS_RESPONSES_EXTRA_FIELDS``
    are grouped under ``extra_body``. ``max_output_tokens``, ``temperature``,
    ``parallel_tool_calls``, ``extra_headers``, ``tools`` and
    ``previous_response_id`` are mapped to their dedicated settings keys only
    when set, while ``openai_store`` is always populated from ``store``.
    """
    dumped = responses_params.model_dump(exclude_none=True)
    out: dict[str, Any] = {}

    passthrough = {
        field: dumped[field] for field in dumped if field in _LLS_RESPONSES_EXTRA_FIELDS
    }
    if passthrough:
        out["extra_body"] = passthrough

    if (token_limit := responses_params.max_output_tokens) is not None:
        out["max_tokens"] = token_limit
    if (temperature := responses_params.temperature) is not None:
        out["temperature"] = temperature
    if (parallel := responses_params.parallel_tool_calls) is not None:
        out["parallel_tool_calls"] = parallel
    if headers := responses_params.extra_headers:
        out["extra_headers"] = dict(headers)

    # ``store`` is forwarded unconditionally, whatever its value.
    out["openai_store"] = responses_params.store

    if (tools := responses_params.tools) is not None:
        out["openai_native_tools"] = tools
    if (previous_id := responses_params.previous_response_id) is not None:
        out["openai_previous_response_id"] = previous_id

    return cast(OpenAIResponsesModelSettings, out)


def _skills_capability(
skills_config: Optional[SkillsConfiguration],
Expand Down Expand Up @@ -239,15 +179,12 @@ def build_agent(
``Agent`` configured for ``await agent.run(...)`` (or streaming) against the same
stack configuration as ``client.responses.create(**responses_params.model_dump())``.
"""
provider = llama_stack_provider_from_client(client)
settings = _model_settings_from_responses_params(responses_params)
capabilities = _agent_capabilities(skills, no_tools=no_tools)

model = LlamaStackResponsesModel(
responses_params.model,
provider=provider,
settings=settings,
model = LlamaStackResponsesModel.from_llama_stack_client(
responses_params.model, client, responses_params=responses_params
)

return Agent(
model,
instructions=responses_params.instructions,
Expand Down
Loading
Loading