Skip to content

Commit c54902e

Browse files
committed
feat: Vector retrieval matches tables
1 parent cf96de4 commit c54902e

4 files changed

Lines changed: 9 additions & 10 deletions

File tree

backend/alembic/versions/047_table_embedding.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,6 @@
99
import sqlalchemy as sa
1010
import sqlmodel.sql.sqltypes
1111
from sqlalchemy.dialects import postgresql
12-
import pgvector
1312

1413
# revision identifiers, used by Alembic.
1514
revision = 'c1b794a961ce'
@@ -20,7 +19,7 @@
2019

2120
def upgrade():
2221
# ### commands auto generated by Alembic - please adjust! ###
23-
op.add_column('core_table', sa.Column('embedding', pgvector.sqlalchemy.vector.VECTOR(), nullable=True))
22+
op.add_column('core_table', sa.Column('embedding', sa.Text(), nullable=True))
2423
# ### end Alembic commands ###
2524

2625

backend/apps/datasource/crud/table.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def run_fill_empty_table_embedding(session: Session):
3838
SQLBotLogUtil.info('get tables')
3939
stmt = select(CoreTable.id).where(and_(CoreTable.embedding.is_(None)))
4040
results = session.execute(stmt).scalars().all()
41-
SQLBotLogUtil.info(results)
41+
SQLBotLogUtil.info(json.dumps(results))
4242

4343
save_table_embedding(session, results)
4444

@@ -53,8 +53,8 @@ def save_table_embedding(session: Session, ids: List[int]):
5353
SQLBotLogUtil.info('start table embedding')
5454
start_time = time.time()
5555
model = EmbeddingModelCache.get_model()
56-
for id in ids:
57-
table = session.query(CoreTable).filter(CoreTable.id == id).first()
56+
for _id in ids:
57+
table = session.query(CoreTable).filter(CoreTable.id == _id).first()
5858
fields = session.query(CoreField).filter(CoreField.table_id == table.id).all()
5959

6060
schema_table = ''
@@ -82,7 +82,7 @@ def save_table_embedding(session: Session, ids: List[int]):
8282
# table_schema.append(schema_table)
8383
emb = model.embed_query(schema_table)
8484

85-
stmt = update(CoreTable).where(and_(CoreTable.id == id)).values(embedding=emb)
85+
stmt = update(CoreTable).where(and_(CoreTable.id == id)).values(embedding=json.dumps(emb))
8686
session.execute(stmt)
8787
session.commit()
8888

backend/apps/datasource/embedding/table_embedding.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,13 @@ def calc_table_embedding(tables: list[dict], question: str):
5656
# results = model.embed_documents(text)
5757
# end_time = time.time()
5858
# SQLBotLogUtil.info(str(end_time - start_time))
59-
results = [item.get('embedding') if item.get('embedding') else ' ' for item in _list]
59+
results = [item.get('embedding') for item in _list]
6060

6161
q_embedding = model.embed_query(question)
6262
for index in range(len(results)):
6363
item = results[index]
64-
_list[index]['cosine_similarity'] = cosine_similarity(q_embedding, item)
64+
if item:
65+
_list[index]['cosine_similarity'] = cosine_similarity(q_embedding, json.loads(item))
6566

6667
_list.sort(key=lambda x: x['cosine_similarity'], reverse=True)
6768
_list = _list[:settings.TABLE_EMBEDDING_COUNT]

backend/apps/datasource/models/datasource.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from datetime import datetime
22
from typing import List, Optional
33

4-
from pgvector.sqlalchemy import VECTOR
54
from pydantic import BaseModel
65
from sqlalchemy import Column, Text, BigInteger, DateTime, Identity
76
from sqlalchemy.dialects.postgresql import JSONB
@@ -32,7 +31,7 @@ class CoreTable(SQLModel, table=True):
3231
table_name: str = Field(sa_column=Column(Text))
3332
table_comment: str = Field(sa_column=Column(Text))
3433
custom_comment: str = Field(sa_column=Column(Text))
35-
embedding: Optional[List[float]] = Field(sa_column=Column(VECTOR(), nullable=True))
34+
embedding: str = Field(sa_column=Column(Text, nullable=True))
3635

3736

3837
class CoreField(SQLModel, table=True):

0 commit comments

Comments
 (0)