|
68 | 68 | "import asyncpg\n", |
69 | 69 | "import uuid\n", |
70 | 70 | "from pgvector.asyncpg import register_vector\n", |
71 | | - "from typing import (List, Optional, Union, Dict, Tuple)\n", |
72 | | - "import json " |
| 71 | + "from typing import (List, Optional, Union, Dict, Tuple, Any)\n", |
| 72 | + "import json\n", |
| 73 | + "import numpy as np " |
73 | 74 | ] |
74 | 75 | }, |
75 | 76 | { |
|
192 | 193 | " return (query, [id])\n", |
193 | 194 | "\n", |
194 | 195 | " def delete_by_metadata_query (self, filter: Union[Dict[str, str], List[Dict[str, str]]]) -> Tuple[str, List]:\n", |
195 | | - " params = []\n", |
| 196 | + " params: List[Any] = []\n", |
196 | 197 | " (where, params) = self._where_clause_for_filter(params, filter)\n", |
197 | 198 | " query = \"DELETE FROM {table_name} WHERE {where};\".format(table_name=self._quote_ident(self.table_name), where=where)\n", |
198 | 199 | " return (query, params) \n", |
|
247 | 248 | "\n", |
248 | 249 | " return (where, params) \n", |
249 | 250 | "\n", |
250 | | - " def search_query(self, query_embedding: Optional[List[float]], limit: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None) -> Tuple[str, List]:\n", |
| 251 | + " def search_query(self, query_embedding: Optional[Union[List[float], np.ndarray]], limit: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None) -> Tuple[str, List]:\n", |
251 | 252 | " \"\"\"\n", |
252 | 253 | " Generates a similarity query.\n", |
253 | 254 | "\n", |
|
259 | 260 | " Returns:\n", |
260 | 261 | " Tuple[str, List]: A tuple containing the query and parameters.\n", |
261 | 262 | " \"\"\"\n", |
262 | | - " params = []\n", |
| 263 | + " params: List[Any] = []\n", |
263 | 264 | " if query_embedding is not None:\n", |
264 | 265 | " distance = \"embedding {op} ${index}\".format(op=self.distance_type, index=len(params)+1)\n", |
265 | 266 | " params = params + [query_embedding]\n", |
|
816 | 817 | "source": [ |
817 | 818 | "#| export\n", |
818 | 819 | "class Sync:\n", |
819 | | - " translated_queries = {}\n", |
| 820 | + " translated_queries: Dict[str, str] = {}\n", |
820 | 821 | " \n", |
821 | 822 | " def __init__(\n", |
822 | 823 | " self,\n", |
|
1033 | 1034 | " List: List of similar records.\n", |
1034 | 1035 | " \"\"\"\n", |
1035 | 1036 | " if query_embedding is not None:\n", |
1036 | | - " query_embedding = np.array(query_embedding)\n", |
| 1037 | + " query_embedding_np = np.array(query_embedding)\n", |
| 1038 | + " else:\n", |
| 1039 | + " query_embedding_np = None \n", |
1037 | 1040 | " \n", |
1038 | | - " (query, params) = self.builder.search_query(query_embedding, limit, filter)\n", |
| 1041 | + " (query, params) = self.builder.search_query(query_embedding_np, limit, filter)\n", |
1039 | 1042 | " query, params = self._translate_to_pyformat(query, params)\n", |
1040 | 1043 | " with self.connect() as conn:\n", |
1041 | 1044 | " with conn.cursor() as cur:\n", |
|
0 commit comments