|
8 | 8 | from sqlmodel import select |
9 | 9 |
|
10 | 10 | from apps.datasource.crud.permission import get_column_permission_fields, get_row_permission_filters, is_normal_user |
11 | | -from apps.datasource.embedding.table_embedding import get_table_embedding |
| 11 | +from apps.datasource.embedding.table_embedding import get_table_embedding, calc_table_embedding |
12 | 12 | from apps.datasource.utils.utils import aes_decrypt |
13 | 13 | from apps.db.constant import DB |
14 | 14 | from apps.db.db import get_tables, get_fields, exec_sql, check_connection |
15 | 15 | from apps.db.engine import get_engine_config, get_engine_conn |
16 | 16 | from common.core.config import settings |
17 | 17 | from common.core.deps import SessionDep, CurrentUser, Trans |
| 18 | +from common.utils.embedding_threads import run_save_table_embeddings |
18 | 19 | from common.utils.utils import deepcopy_ignore_extra |
19 | 20 | from .table import get_tables_by_ds_id |
20 | 21 | from ..crud.field import delete_field_by_ds_id, update_field |
@@ -194,6 +195,9 @@ def sync_table(session: SessionDep, ds: CoreDatasource, tables: List[CoreTable]) |
194 | 195 | session.query(CoreField).filter(CoreField.ds_id == ds.id).delete(synchronize_session=False) |
195 | 196 | session.commit() |
196 | 197 |
|
| 198 | + # do table embedding |
| 199 | + run_save_table_embeddings(id_list) |
| 200 | + |
197 | 201 |
|
198 | 202 | def sync_fields(session: SessionDep, ds: CoreDatasource, table: CoreTable, fields: List[ColumnSchema]): |
199 | 203 | id_list = [] |
@@ -232,14 +236,23 @@ def update_table_and_fields(session: SessionDep, data: TableObj): |
232 | 236 | for field in data.fields: |
233 | 237 | update_field(session, field) |
234 | 238 |
|
| 239 | + # do table embedding |
| 240 | + run_save_table_embeddings([data.table.id]) |
| 241 | + |
235 | 242 |
|
236 | 243 | def updateTable(session: SessionDep, table: CoreTable): |
237 | 244 | update_table(session, table) |
238 | 245 |
|
| 246 | + # do table embedding |
| 247 | + run_save_table_embeddings([table.id]) |
| 248 | + |
239 | 249 |
|
240 | 250 | def updateField(session: SessionDep, field: CoreField): |
241 | 251 | update_field(session, field) |
242 | 252 |
|
| 253 | + # do table embedding |
| 254 | + run_save_table_embeddings([field.table_id]) |
| 255 | + |
243 | 256 |
|
244 | 257 | def preview(session: SessionDep, current_user: CurrentUser, id: int, data: TableObj): |
245 | 258 | ds = session.query(CoreDatasource).filter(CoreDatasource.id == id).first() |
@@ -398,13 +411,13 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat |
398 | 411 | schema_table += ",\n".join(field_list) |
399 | 412 | schema_table += '\n]\n' |
400 | 413 |
|
401 | | - t_obj = {"id": obj.table.id, "schema_table": schema_table} |
| 414 | + t_obj = {"id": obj.table.id, "schema_table": schema_table, "embedding": obj.table.embedding} |
402 | 415 | tables.append(t_obj) |
403 | 416 | all_tables.append(t_obj) |
404 | 417 |
|
405 | 418 | # do table embedding |
406 | 419 | if embedding and tables and settings.TABLE_EMBEDDING_ENABLED: |
407 | | - tables = get_table_embedding(tables, question) |
| 420 | + tables = calc_table_embedding(tables, question) |
408 | 421 | # splice schema |
409 | 422 | if tables: |
410 | 423 | for s in tables: |
|
0 commit comments