-
Notifications
You must be signed in to change notification settings - Fork 957
Expand file tree
/
Copy pathapi_models.py
More file actions
127 lines (86 loc) · 3.12 KB
/
api_models.py
File metadata and controls
127 lines (86 loc) · 3.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from enum import Enum
from typing import Any, Optional

from openai.types.responses import ResponseInputItemParam
from pydantic import BaseModel, Field, field_validator
class AIChatRoles(str, Enum):
    # Closed set of chat-message roles. The `str` mix-in makes every member
    # an actual string, so it compares equal to its plain value.
    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
class Message(BaseModel):
    # A single chat message; when no role is given it is treated as a
    # user message.
    content: str
    role: AIChatRoles = AIChatRoles.USER
class RetrievalMode(str, Enum):
    # Which search strategy to use. Members are `str` instances, so each one
    # compares equal to its lowercase string value.
    TEXT = "text"
    VECTORS = "vectors"
    HYBRID = "hybrid"
class ChatRequestOverrides(BaseModel):
    # Caller-tunable knobs for a chat request. Every field has a default,
    # so an empty overrides object is valid.
    top: int = 3  # presumably the number of search results to keep — confirm
    temperature: float = 0.3
    retrieval_mode: RetrievalMode = RetrievalMode.HYBRID
    use_advanced_flow: bool = True

    # Optional settings; None means the caller did not supply a value.
    prompt_template: Optional[str] = None
    seed: Optional[int] = None
class ChatRequestContext(BaseModel):
    # Per-request context wrapper; currently it only carries the overrides.
    overrides: ChatRequestOverrides
class ChatRequest(BaseModel):
    # Incoming chat payload: OpenAI Responses-style input items plus the
    # request context.
    messages: list[ResponseInputItemParam]
    context: ChatRequestContext

    # camelCase departs from the snake_case used elsewhere; presumably it
    # mirrors the client's JSON field name, so keep it unchanged.
    sessionState: Optional[Any] = None
class ItemPublic(BaseModel):
    # Item fields exposed in responses.
    id: int
    type: str
    brand: str
    name: str
    description: str
    price: float

    def to_str_for_rag(self) -> str:
        """Render the item as one line of space-separated "Key:value" pairs.

        Fields appear in the order Name, Description, Price, Brand, Type.
        """
        return (
            f"Name:{self.name} "
            f"Description:{self.description} "
            f"Price:{self.price} "
            f"Brand:{self.brand} "
            f"Type:{self.type}"
        )
class ItemWithDistance(ItemPublic):
    # An ItemPublic plus a `distance` value, rounded to 2 decimal places.
    # (Presumably a vector-search distance — confirm with the query code.)
    distance: float

    @field_validator("distance")
    @classmethod
    def round_distance(cls, value: float) -> float:
        """Round ``distance`` to 2 decimal places during validation.

        The previous ``__init__`` override ran only for direct construction.
        ``model_validate`` and other validation paths skip ``__init__``, so
        they produced unrounded distances. As an after-validator, this runs on
        every validating construction path. Keyword construction
        ``ItemWithDistance(**data)`` is unchanged.
        """
        return round(value, 2)
class ThoughtStep(BaseModel):
    # One titled entry in RAGContext.thoughts.
    title: str
    description: Any
    # Pydantic copies mutable defaults per instance, so this `{}` is not
    # shared between ThoughtStep objects.
    props: dict = {}
class RAGContext(BaseModel):
    # Supporting material attached to a response.
    data_points: dict[int, ItemPublic]  # int keys — presumably item ids; confirm
    thoughts: list[ThoughtStep]
    followup_questions: Optional[list[str]] = None
class ErrorResponse(BaseModel):
    # Error payload carrying a single message string.
    error: str
class RetrievalResponse(BaseModel):
    # Full response: the reply message and its context are both required
    # (contrast RetrievalResponseDelta, where every field is optional).
    message: Message
    context: RAGContext
    sessionState: Optional[Any] = None
class RetrievalResponseDelta(BaseModel):
    # Partial response: every field defaults to None, so one chunk may carry
    # any subset of them (presumably used for streaming — confirm).
    delta: Optional[Message] = None
    context: Optional[RAGContext] = None
    sessionState: Optional[Any] = None
class ChatParams(ChatRequestOverrides):
    # Resolved parameters for one chat turn. Inherits every override field,
    # and tightens `prompt_template` from Optional[str] to a required str.
    prompt_template: str
    response_token_limit: int = 1024

    # No defaults below: these must be supplied explicitly.
    enable_text_search: bool
    enable_vector_search: bool
    original_user_query: str
    past_messages: list[ResponseInputItemParam]
class Filter(BaseModel):
    # Generic filter: a column name, a comparison operator and a value.
    # Subclasses narrow the field types and attach schema descriptions.
    column: str
    comparison_operator: str
    value: Any
class PriceFilter(Filter):
    # Filter on the price column with a float value. Note that `column` only
    # defaults to "price"; nothing here prevents another value from being set.
    column: str = Field(
        default="price",
        description="The column to filter on (always 'price' for this filter)",
    )
    comparison_operator: str = Field(
        description="The operator for price comparison ('>', '<', '>=', '<=', '=')",
    )
    value: float = Field(
        description="The price value to compare against (e.g., 30.00)",
    )
class BrandFilter(Filter):
    # Filter on the brand column with a string value. As with the price
    # filter, `column` is a default rather than a fixed value.
    column: str = Field(
        default="brand",
        description="The column to filter on (always 'brand' for this filter)",
    )
    comparison_operator: str = Field(
        description="The operator for brand comparison ('=' or '!=')",
    )
    value: str = Field(
        description="The brand name to compare against (e.g., 'AirStrider')",
    )
class SearchResults(BaseModel):
    # Result bundle for a single search.
    query: str  # the original search query
    items: list[ItemPublic]  # items that match the search query and filters
    filters: list[Filter]  # filters applied to the search results