|
| 1 | +/// Reciprocal Rank Fusion (RRF) engine. |
| 2 | +/// |
| 3 | +/// Merges ranked results from multiple search lanes (e.g. semantic HNSW |
| 4 | +/// and FTS5 keyword search) into a single ranked list using the RRF formula: |
| 5 | +/// |
| 6 | +/// rrf_score = sum( weight_i / (k + rank_i) ) |
| 7 | +
|
| 8 | +/// A ranked result from a single search lane. |
| 9 | +pub struct RankedResult { |
| 10 | + pub file_path: String, |
| 11 | + pub file_id: i64, |
| 12 | + pub score: f64, |
| 13 | + pub heading: Option<String>, |
| 14 | + pub snippet: String, |
| 15 | + pub docid: Option<String>, |
| 16 | +} |
| 17 | + |
| 18 | +/// A fused result after RRF merging across lanes. |
| 19 | +pub struct FusedResult { |
| 20 | + pub file_path: String, |
| 21 | + pub file_id: i64, |
| 22 | + pub rrf_score: f64, |
| 23 | + pub heading: Option<String>, |
| 24 | + pub snippet: String, |
| 25 | + pub docid: Option<String>, |
| 26 | + pub lane_contributions: Vec<LaneContribution>, |
| 27 | +} |
| 28 | + |
| 29 | +/// Per-lane contribution details for --explain output. |
| 30 | +pub struct LaneContribution { |
| 31 | + pub lane_name: String, |
| 32 | + pub rank: usize, |
| 33 | + pub raw_score: f64, |
| 34 | + pub weighted_contribution: f64, |
| 35 | +} |
| 36 | + |
| 37 | +use std::collections::HashMap; |
| 38 | + |
| 39 | +/// Fuse ranked results from multiple search lanes using Reciprocal Rank Fusion. |
| 40 | +/// |
| 41 | +/// Each lane is a tuple of `(lane_name, results, weight)`. |
| 42 | +/// Results are grouped by `file_path` (file-level deduplication). |
| 43 | +/// The best snippet/heading per file is kept from the highest-ranked lane. |
| 44 | +/// |
| 45 | +/// `k` is the RRF constant (typically 60). |
| 46 | +pub fn rrf_fuse( |
| 47 | + lanes: &[(&str, &[RankedResult], f64)], |
| 48 | + k: usize, |
| 49 | +) -> Vec<FusedResult> { |
| 50 | + // Track per-file: rrf_score, best snippet info, lane contributions |
| 51 | + struct Accumulator { |
| 52 | + file_path: String, |
| 53 | + file_id: i64, |
| 54 | + rrf_score: f64, |
| 55 | + heading: Option<String>, |
| 56 | + snippet: String, |
| 57 | + docid: Option<String>, |
| 58 | + best_rank: usize, // lowest rank seen (for picking best snippet) |
| 59 | + lane_contributions: Vec<LaneContribution>, |
| 60 | + } |
| 61 | + |
| 62 | + let mut acc_map: HashMap<String, Accumulator> = HashMap::new(); |
| 63 | + |
| 64 | + for &(lane_name, results, weight) in lanes { |
| 65 | + for (idx, r) in results.iter().enumerate() { |
| 66 | + let rank = idx + 1; // 1-based |
| 67 | + let contribution = weight / (k as f64 + rank as f64); |
| 68 | + |
| 69 | + let acc = acc_map.entry(r.file_path.clone()).or_insert_with(|| Accumulator { |
| 70 | + file_path: r.file_path.clone(), |
| 71 | + file_id: r.file_id, |
| 72 | + rrf_score: 0.0, |
| 73 | + heading: r.heading.clone(), |
| 74 | + snippet: r.snippet.clone(), |
| 75 | + docid: r.docid.clone(), |
| 76 | + best_rank: rank, |
| 77 | + lane_contributions: Vec::new(), |
| 78 | + }); |
| 79 | + |
| 80 | + acc.rrf_score += contribution; |
| 81 | + |
| 82 | + // Keep snippet from the best-ranked appearance |
| 83 | + if rank < acc.best_rank { |
| 84 | + acc.best_rank = rank; |
| 85 | + acc.heading = r.heading.clone(); |
| 86 | + acc.snippet = r.snippet.clone(); |
| 87 | + if r.docid.is_some() { |
| 88 | + acc.docid = r.docid.clone(); |
| 89 | + } |
| 90 | + } |
| 91 | + |
| 92 | + acc.lane_contributions.push(LaneContribution { |
| 93 | + lane_name: lane_name.to_string(), |
| 94 | + rank, |
| 95 | + raw_score: r.score, |
| 96 | + weighted_contribution: contribution, |
| 97 | + }); |
| 98 | + } |
| 99 | + } |
| 100 | + |
| 101 | + let mut results: Vec<FusedResult> = acc_map |
| 102 | + .into_values() |
| 103 | + .map(|a| FusedResult { |
| 104 | + file_path: a.file_path, |
| 105 | + file_id: a.file_id, |
| 106 | + rrf_score: a.rrf_score, |
| 107 | + heading: a.heading, |
| 108 | + snippet: a.snippet, |
| 109 | + docid: a.docid, |
| 110 | + lane_contributions: a.lane_contributions, |
| 111 | + }) |
| 112 | + .collect(); |
| 113 | + |
| 114 | + // Sort by rrf_score descending |
| 115 | + results.sort_by(|a, b| b.rrf_score.partial_cmp(&a.rrf_score).unwrap_or(std::cmp::Ordering::Equal)); |
| 116 | + |
| 117 | + results |
| 118 | +} |
| 119 | + |
| 120 | +/// Format explain output for a single fused result. |
| 121 | +pub fn format_explain(result: &FusedResult) -> String { |
| 122 | + let mut out = format!(" RRF: {:.4}\n", result.rrf_score); |
| 123 | + for lc in &result.lane_contributions { |
| 124 | + out.push_str(&format!( |
| 125 | + " {}: rank #{}, raw {:.2}, +{:.4}\n", |
| 126 | + lc.lane_name, lc.rank, lc.raw_score, lc.weighted_contribution, |
| 127 | + )); |
| 128 | + } |
| 129 | + out |
| 130 | +} |
| 131 | + |
| 132 | +#[cfg(test)] |
| 133 | +mod tests { |
| 134 | + use super::*; |
| 135 | + |
| 136 | + fn make_result(file_path: &str, score: f64) -> RankedResult { |
| 137 | + RankedResult { |
| 138 | + file_path: file_path.to_string(), |
| 139 | + file_id: 0, |
| 140 | + score, |
| 141 | + heading: Some(format!("heading for {}", file_path)), |
| 142 | + snippet: format!("snippet for {}", file_path), |
| 143 | + docid: None, |
| 144 | + } |
| 145 | + } |
| 146 | + |
| 147 | + #[test] |
| 148 | + fn test_rrf_basic() { |
| 149 | + // Item appearing in both lanes should rank highest |
| 150 | + let semantic = vec![ |
| 151 | + make_result("both.md", 0.87), |
| 152 | + make_result("sem_only.md", 0.75), |
| 153 | + ]; |
| 154 | + let fts = vec![ |
| 155 | + make_result("fts_only.md", 5.0), |
| 156 | + make_result("both.md", 3.2), |
| 157 | + ]; |
| 158 | + |
| 159 | + let fused = rrf_fuse(&[("semantic", &semantic, 1.0), ("fts", &fts, 1.0)], 60); |
| 160 | + |
| 161 | + assert_eq!(fused.len(), 3); |
| 162 | + // "both.md" should be first because it appears in both lanes |
| 163 | + assert_eq!(fused[0].file_path, "both.md"); |
| 164 | + |
| 165 | + // Verify the RRF score for "both.md": |
| 166 | + // semantic rank 1: 1.0 / (60 + 1) = 0.01639... |
| 167 | + // fts rank 2: 1.0 / (60 + 2) = 0.01613... |
| 168 | + // total = 0.03252... |
| 169 | + let expected = 1.0 / 61.0 + 1.0 / 62.0; |
| 170 | + assert!((fused[0].rrf_score - expected).abs() < 1e-10); |
| 171 | + |
| 172 | + // Both single-lane items should have lower scores |
| 173 | + assert!(fused[0].rrf_score > fused[1].rrf_score); |
| 174 | + assert!(fused[0].rrf_score > fused[2].rrf_score); |
| 175 | + |
| 176 | + // "both.md" should have 2 lane contributions |
| 177 | + assert_eq!(fused[0].lane_contributions.len(), 2); |
| 178 | + } |
| 179 | + |
| 180 | + #[test] |
| 181 | + fn test_rrf_weighted() { |
| 182 | + // FTS weighted 3x should make FTS-only item win over semantic-only item |
| 183 | + let semantic = vec![ |
| 184 | + make_result("sem.md", 0.95), |
| 185 | + ]; |
| 186 | + let fts = vec![ |
| 187 | + make_result("fts.md", 8.0), |
| 188 | + ]; |
| 189 | + |
| 190 | + let fused = rrf_fuse(&[("semantic", &semantic, 1.0), ("fts", &fts, 3.0)], 60); |
| 191 | + |
| 192 | + assert_eq!(fused.len(), 2); |
| 193 | + // FTS item at rank 1 with weight 3.0: 3.0 / 61 = 0.04918... |
| 194 | + // Semantic item at rank 1 with weight 1.0: 1.0 / 61 = 0.01639... |
| 195 | + assert_eq!(fused[0].file_path, "fts.md"); |
| 196 | + assert_eq!(fused[1].file_path, "sem.md"); |
| 197 | + |
| 198 | + let fts_expected = 3.0 / 61.0; |
| 199 | + let sem_expected = 1.0 / 61.0; |
| 200 | + assert!((fused[0].rrf_score - fts_expected).abs() < 1e-10); |
| 201 | + assert!((fused[1].rrf_score - sem_expected).abs() < 1e-10); |
| 202 | + } |
| 203 | + |
| 204 | + #[test] |
| 205 | + fn test_rrf_single_lane() { |
| 206 | + let semantic = vec![ |
| 207 | + make_result("a.md", 0.9), |
| 208 | + make_result("b.md", 0.8), |
| 209 | + make_result("c.md", 0.7), |
| 210 | + ]; |
| 211 | + |
| 212 | + let fused = rrf_fuse(&[("semantic", &semantic, 1.0)], 60); |
| 213 | + |
| 214 | + assert_eq!(fused.len(), 3); |
| 215 | + assert_eq!(fused[0].file_path, "a.md"); |
| 216 | + assert_eq!(fused[1].file_path, "b.md"); |
| 217 | + assert_eq!(fused[2].file_path, "c.md"); |
| 218 | + |
| 219 | + // Each should have exactly 1 lane contribution |
| 220 | + for f in &fused { |
| 221 | + assert_eq!(f.lane_contributions.len(), 1); |
| 222 | + assert_eq!(f.lane_contributions[0].lane_name, "semantic"); |
| 223 | + } |
| 224 | + } |
| 225 | + |
| 226 | + #[test] |
| 227 | + fn test_format_explain() { |
| 228 | + let result = FusedResult { |
| 229 | + file_path: "test.md".to_string(), |
| 230 | + file_id: 1, |
| 231 | + rrf_score: 0.0328, |
| 232 | + heading: None, |
| 233 | + snippet: "test".to_string(), |
| 234 | + docid: None, |
| 235 | + lane_contributions: vec![ |
| 236 | + LaneContribution { |
| 237 | + lane_name: "semantic".to_string(), |
| 238 | + rank: 1, |
| 239 | + raw_score: 0.87, |
| 240 | + weighted_contribution: 0.0164, |
| 241 | + }, |
| 242 | + LaneContribution { |
| 243 | + lane_name: "fts".to_string(), |
| 244 | + rank: 3, |
| 245 | + raw_score: 5.23, |
| 246 | + weighted_contribution: 0.0159, |
| 247 | + }, |
| 248 | + ], |
| 249 | + }; |
| 250 | + |
| 251 | + let output = format_explain(&result); |
| 252 | + assert!(output.contains("RRF: 0.0328")); |
| 253 | + assert!(output.contains("semantic: rank #1, raw 0.87, +0.0164")); |
| 254 | + assert!(output.contains("fts: rank #3, raw 5.23, +0.0159")); |
| 255 | + } |
| 256 | + |
| 257 | + #[test] |
| 258 | + fn test_rrf_empty_lanes() { |
| 259 | + let fused = rrf_fuse(&[], 60); |
| 260 | + assert!(fused.is_empty()); |
| 261 | + } |
| 262 | + |
| 263 | + #[test] |
| 264 | + fn test_rrf_empty_results() { |
| 265 | + let empty: Vec<RankedResult> = vec![]; |
| 266 | + let fused = rrf_fuse(&[("semantic", &empty, 1.0), ("fts", &empty, 1.0)], 60); |
| 267 | + assert!(fused.is_empty()); |
| 268 | + } |
| 269 | +} |
0 commit comments