Skip to content

Commit 4a95ca0

Browse files
committed
Updated example [skip ci]
1 parent 6a7d69a commit 4a95ca0

2 files changed

Lines changed: 7 additions & 8 deletions

File tree

examples/hybrid_search/Cargo.toml

Lines changed: 5 additions & 5 deletions
Original file line number · Diff line number · Diff line change
@@ -5,11 +5,11 @@ edition = "2021"
55
publish = false
66

77
[dependencies]
8-
candle-core = "0.6"
9-
candle-nn = "0.6"
10-
candle-transformers = "0.6"
11-
hf-hub = "0.3"
8+
candle-core = "0.8"
9+
candle-nn = "0.8"
10+
candle-transformers = "0.8"
11+
hf-hub = "0.4"
1212
pgvector = { path = "../..", features = ["postgres"] }
1313
postgres = "0.19"
1414
serde_json = "1"
15-
tokenizers = "0.19"
15+
tokenizers = "0.21"

examples/hybrid_search/src/main.rs

Lines changed: 2 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -113,13 +113,12 @@ impl EmbeddingModel {
113113
Ok(Self { tokenizer, model })
114114
}
115115

116-
// embed one at a time since BertModel does not support attention mask
117-
// https://github.com/huggingface/candle/issues/1798
116+
// TODO support multiple texts
118117
fn embed(&self, text: &str) -> Result<Vec<f32>, Box<dyn Error + Send + Sync>> {
119118
let tokens = self.tokenizer.encode(text, true)?;
120119
let token_ids = Tensor::new(vec![tokens.get_ids().to_vec()], &self.model.device)?;
121120
let token_type_ids = token_ids.zeros_like()?;
122-
let embeddings = self.model.forward(&token_ids, &token_type_ids)?;
121+
let embeddings = self.model.forward(&token_ids, &token_type_ids, None)?;
123122
let embeddings = (embeddings.sum(1)? / (embeddings.dim(1)? as f64))?;
124123
let embeddings = embeddings.broadcast_div(&embeddings.sqr()?.sum_keepdim(1)?.sqrt()?)?;
125124
Ok(embeddings.squeeze(0)?.to_vec1::<f32>()?)

0 commit comments

Comments (0)