|
66 | 66 | "import asyncpg\n", |
67 | 67 | "import uuid\n", |
68 | 68 | "from pgvector.asyncpg import register_vector\n", |
69 | | - "from typing import (List, Optional)\n", |
| 69 | + "from typing import (List, Optional, Union, Dict, Tuple)\n", |
70 | 70 | "import json " |
71 | 71 | ] |
72 | 72 | }, |
|
217 | 217 | " return \"CREATE INDEX {index_name} ON {table_name} USING ivfflat ({column_name} {index_method}) WITH (lists = {num_lists});\"\\\n", |
218 | 218 | " .format(index_name=self._get_embedding_index_name(), table_name=self._quote_ident(self.table_name), column_name=self._quote_ident(column_name), index_method=index_method, num_lists=num_lists)\n", |
219 | 219 | "\n", |
220 | | - " def search_query(self, query_embedding: List[float], k: int=10, filter: Optional[dict] = None):\n", |
| 220 | + " def search_query(self, query_embedding: List[float], k: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None) -> Tuple[str, List]:\n", |
221 | 221 | " \"\"\"\n", |
222 | 222 | " Generates a similarity query.\n", |
223 | 223 | "\n", |
|
232 | 232 | " params = []\n", |
233 | 233 | " distance = \"embedding {op} ${index}\".format(op=self.distance_type, index=len(params)+1)\n", |
234 | 234 | " params = params + [query_embedding]\n", |
235 | | - " \n", |
236 | | - " where = \"TRUE\"\n", |
237 | | - " if filter != None:\n", |
| 235 | + "\n", |
| 236 | + " if isinstance(filter, dict):\n", |
238 | 237 | " where = \"metadata @> ${index}\".format(index=len(params)+1)\n", |
239 | 238 | " json_object = json.dumps(filter)\n", |
240 | 239 | " params = params + [json_object]\n", |
| 240 | + " elif isinstance(filter, list):\n", |
| 241 | + " any_params = []\n", |
| 242 | + " for idx, filter_dict in enumerate(filter, start=len(params) + 1):\n", |
| 243 | + " any_params.append(json.dumps(filter_dict))\n", |
| 244 | + " where = \"metadata @> ANY(${index}::jsonb[])\".format(index=len(params) + 1)\n", |
| 245 | + " params = params + [any_params]\n", |
| 246 | + " else:\n", |
| 247 | + " where = \"TRUE\"\n", |
| 248 | + " \n", |
241 | 249 | " query = '''\n", |
242 | 250 | " SELECT\n", |
243 | 251 | " id, metadata, contents, embedding, {distance} as distance\n", |
|
261 | 269 | "text/markdown": [ |
262 | 270 | "---\n", |
263 | 271 | "\n", |
264 | | - "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L79){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 272 | + "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L87){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
265 | 273 | "\n", |
266 | 274 | "### QueryBuilder.get_create_query\n", |
267 | 275 | "\n", |
|
275 | 283 | "text/plain": [ |
276 | 284 | "---\n", |
277 | 285 | "\n", |
278 | | - "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L79){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 286 | + "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L87){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
279 | 287 | "\n", |
280 | 288 | "### QueryBuilder.get_create_query\n", |
281 | 289 | "\n", |
|
443 | 451 | " async def search(self, \n", |
444 | 452 | " query_embedding: List[float], # vector to search for\n", |
445 | 453 | " k: int=10, # The number of nearest neighbors to retrieve. Default is 10.\n", |
446 | | - " filter: Optional[dict] = None): # A filter for metadata. Default is None.\n", |
| 454 | + " filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None): # A filter for metadata. Default is None.\n", |
447 | 455 | " \"\"\"\n", |
448 | 456 | " Retrieves similar records using a similarity query.\n", |
449 | 457 | "\n", |
|
465 | 473 | "text/markdown": [ |
466 | 474 | "---\n", |
467 | 475 | "\n", |
468 | | - "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L229){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 476 | + "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L248){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
469 | 477 | "\n", |
470 | 478 | "### Async.create_tables\n", |
471 | 479 | "\n", |
|
479 | 487 | "text/plain": [ |
480 | 488 | "---\n", |
481 | 489 | "\n", |
482 | | - "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L229){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 490 | + "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L248){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
483 | 491 | "\n", |
484 | 492 | "### Async.create_tables\n", |
485 | 493 | "\n", |
|
510 | 518 | "text/markdown": [ |
511 | 519 | "---\n", |
512 | 520 | "\n", |
513 | | - "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L229){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 521 | + "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L248){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
514 | 522 | "\n", |
515 | 523 | "### Async.create_tables\n", |
516 | 524 | "\n", |
|
524 | 532 | "text/plain": [ |
525 | 533 | "---\n", |
526 | 534 | "\n", |
527 | | - "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L229){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 535 | + "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L248){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
528 | 536 | "\n", |
529 | 537 | "### Async.create_tables\n", |
530 | 538 | "\n", |
|
555 | 563 | "text/markdown": [ |
556 | 564 | "---\n", |
557 | 565 | "\n", |
558 | | - "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L279){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 566 | + "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L311){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
559 | 567 | "\n", |
560 | 568 | "### Async.search\n", |
561 | 569 | "\n", |
562 | 570 | "> Async.search (query_embedding:List[float], k:int=10,\n", |
563 | | - "> filter:Optional[dict]=None)\n", |
| 571 | + "> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=No\n", |
| 572 | + "> ne)\n", |
564 | 573 | "\n", |
565 | 574 | "Retrieves similar records using a similarity query.\n", |
566 | 575 | "\n", |
|
570 | 579 | "text/plain": [ |
571 | 580 | "---\n", |
572 | 581 | "\n", |
573 | | - "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L279){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 582 | + "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L311){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
574 | 583 | "\n", |
575 | 584 | "### Async.search\n", |
576 | 585 | "\n", |
577 | 586 | "> Async.search (query_embedding:List[float], k:int=10,\n", |
578 | | - "> filter:Optional[dict]=None)\n", |
| 587 | + "> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=No\n", |
| 588 | + "> ne)\n", |
579 | 589 | "\n", |
580 | 590 | "Retrieves similar records using a similarity query.\n", |
581 | 591 | "\n", |
|
660 | 670 | "rec = await vec.search([1.0, 2.0], k=4, filter={\"key_1\":\"val_1\", \"key_2\":\"val_3\"})\n", |
661 | 671 | "assert len(rec) == 0\n", |
662 | 672 | "\n", |
| 673 | + "rec = await vec.search([1.0, 2.0], k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n", |
| 674 | + "assert len(rec) == 2\n", |
| 675 | + "\n", |
| 676 | + "rec = await vec.search([1.0, 2.0], k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}, {\"no such key\": \"no such val\"}])\n", |
| 677 | + "assert len(rec) == 2\n", |
| 678 | + "\n", |
663 | 679 | "try:\n", |
664 | 680 | " # can't upsert using both keys and dictionaries\n", |
665 | 681 | " await vec.upsert([ \\\n", |
|
887 | 903 | " with conn.cursor() as cur:\n", |
888 | 904 | " cur.execute(query)\n", |
889 | 905 | "\n", |
890 | | - " def search(self, query_embedding: List[float], k: int=10, filter: Optional[dict] = None):\n", |
| 906 | + " def search(self, query_embedding: List[float], k: int=10, filter: Optional[Union[Dict[str, str], List[Dict[str, str]]]] = None):\n", |
891 | 907 | " \"\"\"\n", |
892 | 908 | " Retrieves similar records using a similarity query.\n", |
893 | 909 | "\n", |
|
917 | 933 | "text/markdown": [ |
918 | 934 | "---\n", |
919 | 935 | "\n", |
920 | | - "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L398){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 936 | + "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L438){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
921 | 937 | "\n", |
922 | 938 | "### Sync.create_tables\n", |
923 | 939 | "\n", |
|
931 | 947 | "text/plain": [ |
932 | 948 | "---\n", |
933 | 949 | "\n", |
934 | | - "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L398){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 950 | + "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L438){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
935 | 951 | "\n", |
936 | 952 | "### Sync.create_tables\n", |
937 | 953 | "\n", |
|
962 | 978 | "text/markdown": [ |
963 | 979 | "---\n", |
964 | 980 | "\n", |
965 | | - "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L382){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 981 | + "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L419){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
966 | 982 | "\n", |
967 | 983 | "### Sync.upsert\n", |
968 | 984 | "\n", |
|
979 | 995 | "text/plain": [ |
980 | 996 | "---\n", |
981 | 997 | "\n", |
982 | | - "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L382){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 998 | + "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L419){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
983 | 999 | "\n", |
984 | 1000 | "### Sync.upsert\n", |
985 | 1001 | "\n", |
|
1013 | 1029 | "text/markdown": [ |
1014 | 1030 | "---\n", |
1015 | 1031 | "\n", |
1016 | | - "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L453){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 1032 | + "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L507){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
1017 | 1033 | "\n", |
1018 | 1034 | "### Sync.search\n", |
1019 | 1035 | "\n", |
1020 | 1036 | "> Sync.search (query_embedding:List[float], k:int=10,\n", |
1021 | | - "> filter:Optional[dict]=None)\n", |
| 1037 | + "> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=Non\n", |
| 1038 | + "> e)\n", |
1022 | 1039 | "\n", |
1023 | 1040 | "Retrieves similar records using a similarity query.\n", |
1024 | 1041 | "\n", |
|
1033 | 1050 | "text/plain": [ |
1034 | 1051 | "---\n", |
1035 | 1052 | "\n", |
1036 | | - "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L453){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
| 1053 | + "[source](https://github.com/timescale/python-vector/blob/main/timescale_vector/client.py#L507){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n", |
1037 | 1054 | "\n", |
1038 | 1055 | "### Sync.search\n", |
1039 | 1056 | "\n", |
1040 | 1057 | "> Sync.search (query_embedding:List[float], k:int=10,\n", |
1041 | | - "> filter:Optional[dict]=None)\n", |
| 1058 | + "> filter:Union[Dict[str,str],List[Dict[str,str]],NoneType]=Non\n", |
| 1059 | + "> e)\n", |
1042 | 1060 | "\n", |
1043 | 1061 | "Retrieves similar records using a similarity query.\n", |
1044 | 1062 | "\n", |
|
1131 | 1149 | "rec = vec.search([1.0, 2.0], k=4, filter={\"key_1\":\"val_1\", \"key_2\":\"val_3\"})\n", |
1132 | 1150 | "assert len(rec) == 0\n", |
1133 | 1151 | "\n", |
| 1152 | + "rec = vec.search([1.0, 2.0], k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}])\n", |
| 1153 | + "assert len(rec) == 2\n", |
| 1154 | + "\n", |
| 1155 | + "rec = vec.search([1.0, 2.0], k=4, filter=[{\"key_1\":\"val_1\"}, {\"key2\":\"val2\"}, {\"no such key\": \"no such val\"}])\n", |
| 1156 | + "assert len(rec) == 2\n", |
| 1157 | + "\n", |
1134 | 1158 | "try:\n", |
1135 | 1159 | " # can't upsert using both keys and dictionaries\n", |
1136 | 1160 | " await vec.upsert([ \\\n", |
|
0 commit comments