Skip to content

Commit bf70078

Browse files
committed
Fix scoring: per-keyword min-cosine replaces merged vector averaging
1 parent e81a796 commit bf70078

File tree

1 file changed

+99
-43
lines changed

1 file changed

+99
-43
lines changed

src/store/store.c

Lines changed: 99 additions & 43 deletions
Original file line number | Diff line number | Diff line change
@@ -4606,81 +4606,84 @@ int cbm_store_vector_search(cbm_store_t *s, const char *project, const char **ke
46064606
return CBM_STORE_ERR;
46074607
}
46084608

4609-
/* Build merged query vector from enriched token vectors stored in
4610-
* token_vectors table. Falls back to base RI vectors if no enriched
4611-
* vectors are available. */
4612-
enum { VEC_DIM = 768, SPARSE_NNZE = 8, RI_SEED = 0x52494E44, IDF_SCALE = 1000 };
4613-
float query_f[VEC_DIM];
4614-
memset(query_f, 0, sizeof(query_f));
4615-
4616-
for (int k = 0; k < keyword_count; k++) {
4609+
/* Per-keyword scoring: score each keyword independently against each
4610+
* node vector, then combine using min(cosine_k) across keywords.
4611+
* This ensures ALL keywords must be relevant, not just the average.
4612+
*
4613+
* Step 1: Build per-keyword int8 query vectors.
4614+
* Step 2: For each node, compute cosine per keyword, take min.
4615+
* Step 3: Rank by min-score. */
4616+
enum { VEC_DIM = 768, SPARSE_NNZE = 8, RI_SEED = 0x52494E44, IDF_SCALE = 1000,
4617+
MAX_KW = 32, SCAN_BUF_INIT = 1024 };
4618+
4619+
/* Build per-keyword vectors */
4620+
int actual_kw = 0;
4621+
int8_t kw_vecs[MAX_KW][VEC_DIM];
4622+
4623+
for (int k = 0; k < keyword_count && actual_kw < MAX_KW; k++) {
46174624
if (!keywords[k] || !keywords[k][0]) {
46184625
continue;
46194626
}
4620-
/* Try enriched vector from token_vectors table first */
4627+
float kw_f[VEC_DIM];
4628+
memset(kw_f, 0, sizeof(kw_f));
4629+
bool found = false;
4630+
4631+
/* Try enriched vector from token_vectors table */
46214632
sqlite3_stmt *tv_stmt = NULL;
46224633
const char *tv_sql = "SELECT vector, idf FROM token_vectors"
46234634
" WHERE project = ?1 AND token = ?2 LIMIT 1";
4624-
bool found = false;
46254635
if (sqlite3_prepare_v2(s->db, tv_sql, -1, &tv_stmt, NULL) == SQLITE_OK) {
46264636
sqlite3_bind_text(tv_stmt, SKIP_ONE, project, -1, SQLITE_STATIC);
46274637
sqlite3_bind_text(tv_stmt, ST_COL_2, keywords[k], -1, SQLITE_STATIC);
46284638
if (sqlite3_step(tv_stmt) == SQLITE_ROW) {
46294639
const int8_t *vec = (const int8_t *)sqlite3_value_blob(
46304640
sqlite3_column_value(tv_stmt, 0));
46314641
int vec_len = sqlite3_column_bytes(tv_stmt, 0);
4632-
float idf = (float)sqlite3_column_int(tv_stmt, SKIP_ONE) / IDF_SCALE;
46334642
if (vec && vec_len == VEC_DIM) {
46344643
for (int d = 0; d < VEC_DIM; d++) {
4635-
query_f[d] += idf * ((float)vec[d] / 127.0f);
4644+
kw_f[d] = (float)vec[d] / 127.0f;
46364645
}
46374646
found = true;
46384647
}
46394648
}
46404649
sqlite3_finalize(tv_stmt);
46414650
}
4642-
/* Fallback: base sparse random vector if no enriched vector found */
46434651
if (!found) {
46444652
uint64_t seed = XXH3_64bits(keywords[k], strlen(keywords[k]));
46454653
for (int i = 0; i < SPARSE_NNZE; i++) {
46464654
uint64_t h = XXH3_64bits_withSeed(&i, sizeof(i), seed + RI_SEED);
46474655
int pos = (int)(h % VEC_DIM);
46484656
float sign = (h & SKIP_ONE) ? 1.0f : -1.0f;
4649-
query_f[pos] += sign;
4657+
kw_f[pos] += sign;
46504658
}
46514659
}
4652-
}
4653-
4654-
/* Normalize */
4655-
float mag = 0.0f;
4656-
for (int i = 0; i < VEC_DIM; i++) {
4657-
mag += query_f[i] * query_f[i];
4658-
}
4659-
mag = sqrtf(mag);
4660-
if (mag < 1e-10f) {
4661-
return CBM_STORE_OK; /* zero vector — no results */
4662-
}
4663-
float inv = 1.0f / mag;
4664-
4665-
/* Int8 quantize */
4666-
int8_t query_i8[VEC_DIM];
4667-
for (int i = 0; i < VEC_DIM; i++) {
4668-
float v = query_f[i] * inv * 127.0f;
4669-
if (v > 127.0f) {
4670-
v = 127.0f;
4660+
/* Normalize + quantize */
4661+
float mag = 0.0f;
4662+
for (int d = 0; d < VEC_DIM; d++) {
4663+
mag += kw_f[d] * kw_f[d];
4664+
}
4665+
mag = sqrtf(mag);
4666+
if (mag < 1e-10f) {
4667+
continue;
46714668
}
4672-
if (v < -127.0f) {
4673-
v = -127.0f;
4669+
float inv = 1.0f / mag;
4670+
for (int d = 0; d < VEC_DIM; d++) {
4671+
float v = kw_f[d] * inv * 127.0f;
4672+
kw_vecs[actual_kw][d] = (int8_t)(v > 127.0f ? 127.0f : (v < -127.0f ? -127.0f : v));
46744673
}
4675-
query_i8[i] = (int8_t)v;
4674+
actual_kw++;
46764675
}
46774676

4678-
/* Query: cosine similarity JOIN with nodes for metadata.
4679-
* We use a subquery to compute the score first, then join for metadata
4680-
* — this avoids SQLite computing the cosine for rows we won't return. */
4677+
if (actual_kw == 0) {
4678+
return CBM_STORE_OK;
4679+
}
4680+
4681+
/* Scan all node vectors, compute per-keyword cosine, take min.
4682+
* We use the FIRST keyword as the SQL sort (for top-K pre-filter),
4683+
* then re-score with min across all keywords in C. */
46814684
const char *sql =
46824685
"SELECT n.id, n.name, n.qualified_name, n.file_path, n.label,"
4683-
" cbm_cosine_i8(v.vector, ?1) as score"
4686+
" cbm_cosine_i8(v.vector, ?1) as score, v.vector"
46844687
" FROM node_vectors v"
46854688
" INNER JOIN nodes n ON n.id = v.node_id"
46864689
" WHERE v.project = ?2"
@@ -4695,9 +4698,11 @@ int cbm_store_vector_search(cbm_store_t *s, const char *project, const char **ke
46954698
return CBM_STORE_ERR;
46964699
}
46974700

4698-
sqlite3_bind_blob(stmt, SKIP_ONE, query_i8, VEC_DIM, SQLITE_STATIC);
4701+
/* Use first keyword for SQL pre-filter, fetch more candidates for re-ranking */
4702+
int fetch_limit = (limit > 0 ? limit : CBM_SZ_16) * ST_COL_5;
4703+
sqlite3_bind_blob(stmt, SKIP_ONE, kw_vecs[0], VEC_DIM, SQLITE_STATIC);
46994704
sqlite3_bind_text(stmt, ST_COL_2, project, -1, SQLITE_STATIC);
4700-
sqlite3_bind_int(stmt, ST_COL_3, limit > 0 ? limit : CBM_SZ_16);
4705+
sqlite3_bind_int(stmt, ST_COL_3, fetch_limit);
47014706

47024707
cbm_vector_result_t *results = NULL;
47034708
int count = 0;
@@ -4722,11 +4727,62 @@ int cbm_store_vector_search(cbm_store_t *s, const char *project, const char **ke
47224727
results[count].qualified_name = qn ? strdup(qn) : strdup("");
47234728
results[count].file_path = fp ? strdup(fp) : strdup("");
47244729
results[count].label = label ? strdup(label) : strdup("");
4725-
results[count].score = sqlite3_column_double(stmt, ST_COL_5);
4730+
4731+
/* Compute per-keyword min-score for this node.
4732+
* The SQL pre-filtered by first keyword; now re-score with ALL keywords. */
4733+
const void *node_vec = sqlite3_column_blob(stmt, ST_COL_6);
4734+
int node_vec_len = sqlite3_column_bytes(stmt, ST_COL_6);
4735+
double min_score = 1.0;
4736+
if (node_vec && node_vec_len == VEC_DIM) {
4737+
const int8_t *nv = (const int8_t *)node_vec;
4738+
for (int k = 0; k < actual_kw; k++) {
4739+
/* Inline int8 cosine for speed */
4740+
int32_t dot = 0;
4741+
int32_t ma = 0;
4742+
int32_t mb = 0;
4743+
for (int d = 0; d < VEC_DIM; d++) {
4744+
dot += (int32_t)kw_vecs[k][d] * (int32_t)nv[d];
4745+
ma += (int32_t)kw_vecs[k][d] * (int32_t)kw_vecs[k][d];
4746+
mb += (int32_t)nv[d] * (int32_t)nv[d];
4747+
}
4748+
double denom = sqrt((double)ma) * sqrt((double)mb);
4749+
double cos_k = denom > 1e-10 ? (double)dot / denom : 0.0;
4750+
if (cos_k < min_score) {
4751+
min_score = cos_k;
4752+
}
4753+
}
4754+
} else {
4755+
min_score = 0.0;
4756+
}
4757+
results[count].score = min_score;
47264758
count++;
47274759
}
47284760

47294761
sqlite3_finalize(stmt);
4762+
4763+
/* Re-sort by min-score (SQL sorted by first keyword only) */
4764+
for (int i = 0; i < count - SKIP_ONE; i++) {
4765+
for (int j = i + SKIP_ONE; j < count; j++) {
4766+
if (results[j].score > results[i].score) {
4767+
cbm_vector_result_t tmp = results[i];
4768+
results[i] = results[j];
4769+
results[j] = tmp;
4770+
}
4771+
}
4772+
}
4773+
4774+
/* Trim to requested limit */
4775+
int final_limit = limit > 0 ? limit : CBM_SZ_16;
4776+
if (count > final_limit) {
4777+
for (int i = final_limit; i < count; i++) {
4778+
free(results[i].name);
4779+
free(results[i].qualified_name);
4780+
free(results[i].file_path);
4781+
free(results[i].label);
4782+
}
4783+
count = final_limit;
4784+
}
4785+
47304786
*out = results;
47314787
*out_count = count;
47324788
return CBM_STORE_OK;

0 commit comments

Comments
 (0)