Skip to content

Commit dc01f9e

Browse files
devwhodevsclaude
andcommitted
feat: pluggable model layer with ModelBackend trait
Extract embedding into a ModelBackend trait. Existing ONNX embedder implements the trait. Users can configure models in vault.toml and manage them via 'engraph models list/info'. Registry ships with known-good models. Prepare for future GGUF adapter. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent e367204 commit dc01f9e

4 files changed

Lines changed: 211 additions & 0 deletions

File tree

src/embedder.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,28 @@ impl Embedder {
157157
}
158158
}
159159

160+
impl crate::model::ModelBackend for Embedder {
161+
fn embed_batch(&mut self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
162+
self.embed_batch(texts)
163+
}
164+
165+
fn embed_one(&mut self, text: &str) -> Result<Vec<f32>> {
166+
self.embed_one(text)
167+
}
168+
169+
fn token_count(&self, text: &str) -> usize {
170+
self.token_count(text)
171+
}
172+
173+
fn dim(&self) -> usize {
174+
EMBEDDING_DIM
175+
}
176+
177+
fn name(&self) -> &str {
178+
"onnx:all-MiniLM-L6-v2"
179+
}
180+
}
181+
160182
/// L2-normalize a vector. Returns a zero vector if input norm is zero.
161183
fn normalize_vector(v: &[f32]) -> Vec<f32> {
162184
let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pub mod fts;
66
pub mod fusion;
77
pub mod hnsw;
88
pub mod indexer;
9+
pub mod model;
910
pub mod profile;
1011
pub mod search;
1112
pub mod store;

src/main.rs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use engraph::config;
22
use engraph::indexer;
3+
use engraph::model;
34
use engraph::profile;
45
use engraph::search;
56
use engraph::store;
@@ -74,6 +75,20 @@ enum Command {
7475

7576
/// Interactively configure vault profile.
7677
Configure,
78+
79+
/// Manage embedding models.
80+
Models {
81+
#[command(subcommand)]
82+
action: ModelsAction,
83+
},
84+
}
85+
86+
#[derive(Subcommand, Debug)]
87+
enum ModelsAction {
88+
/// List available models.
89+
List,
90+
/// Show info about a model.
91+
Info { name: String },
7792
}
7893

7994
/// Check whether an index has been built by looking for engraph.db in data_dir.
@@ -296,6 +311,38 @@ fn main() -> Result<()> {
296311
"Interactive configuration not yet implemented. Run 'engraph init' for auto-detection."
297312
);
298313
}
314+
315+
Command::Models { action } => {
316+
let registry = model::ModelRegistry::default();
317+
match action {
318+
ModelsAction::List => {
319+
println!("{:<30} {:>5} {}", "NAME", "DIM", "DESCRIPTION");
320+
println!("{}", "-".repeat(70));
321+
for entry in &registry.entries {
322+
println!(
323+
"{:<30} {:>5} {}",
324+
entry.name, entry.dim, entry.description
325+
);
326+
}
327+
}
328+
ModelsAction::Info { name } => {
329+
if let Some(entry) = registry.get(&name) {
330+
println!("Name: {}", entry.name);
331+
println!("Format: {:?}", entry.format);
332+
println!("Dimensions: {}", entry.dim);
333+
println!("SHA-256: {}", entry.sha256);
334+
println!("URL: {}", entry.url);
335+
println!("Description: {}", entry.description);
336+
} else {
337+
eprintln!("Unknown model: {name}");
338+
eprintln!(
339+
"Run 'engraph models list' to see available models."
340+
);
341+
std::process::exit(1);
342+
}
343+
}
344+
}
345+
}
299346
}
300347

301348
Ok(())

src/model.rs

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
use anyhow::Result;
2+
use serde::{Deserialize, Serialize};
3+
4+
/// Trait for embedding backends. Any model that can embed text implements this.
5+
pub trait ModelBackend {
6+
fn embed_batch(&mut self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;
7+
fn embed_one(&mut self, text: &str) -> Result<Vec<f32>>;
8+
fn token_count(&self, text: &str) -> usize;
9+
fn dim(&self) -> usize;
10+
fn name(&self) -> &str;
11+
}
12+
13+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
14+
pub enum ModelFormat {
15+
Onnx,
16+
Gguf,
17+
File,
18+
}
19+
20+
#[derive(Debug, Clone)]
21+
pub struct ModelSpec {
22+
pub format: ModelFormat,
23+
pub name: String,
24+
pub path: String,
25+
}
26+
27+
#[derive(Debug, Clone, Serialize, Deserialize)]
28+
pub struct ModelRegistryEntry {
29+
pub name: String,
30+
pub format: ModelFormat,
31+
pub url: String,
32+
pub sha256: String,
33+
pub dim: usize,
34+
pub description: String,
35+
}
36+
37+
pub struct ModelRegistry {
38+
pub entries: Vec<ModelRegistryEntry>,
39+
}
40+
41+
impl Default for ModelRegistry {
42+
fn default() -> Self {
43+
Self {
44+
entries: vec![ModelRegistryEntry {
45+
name: "onnx:all-MiniLM-L6-v2".to_string(),
46+
format: ModelFormat::Onnx,
47+
url: "https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/model.onnx".to_string(),
48+
sha256: "6fd5d72fe4589f189f8ebc006442dbb529bb7ce38f8082112682524616046452".to_string(),
49+
dim: 384,
50+
description: "Lightweight general-purpose sentence embeddings".to_string(),
51+
}],
52+
}
53+
}
54+
}
55+
56+
impl ModelRegistry {
57+
pub fn get(&self, name: &str) -> Option<&ModelRegistryEntry> {
58+
self.entries.iter().find(|e| e.name == name)
59+
}
60+
}
61+
62+
pub fn parse_model_spec(spec: &str) -> ModelSpec {
63+
if let Some(path) = spec.strip_prefix("file:") {
64+
return ModelSpec {
65+
format: ModelFormat::File,
66+
name: spec.to_string(),
67+
path: path.to_string(),
68+
};
69+
}
70+
if let Some((format_str, name)) = spec.split_once(':') {
71+
let format = match format_str {
72+
"onnx" => ModelFormat::Onnx,
73+
"gguf" => ModelFormat::Gguf,
74+
_ => ModelFormat::Onnx,
75+
};
76+
ModelSpec {
77+
format,
78+
name: name.to_string(),
79+
path: String::new(),
80+
}
81+
} else {
82+
ModelSpec {
83+
format: ModelFormat::Onnx,
84+
name: spec.to_string(),
85+
path: String::new(),
86+
}
87+
}
88+
}
89+
90+
#[cfg(test)]
91+
mod tests {
92+
use super::*;
93+
94+
#[test]
95+
fn test_model_registry_default() {
96+
let registry = ModelRegistry::default();
97+
assert_eq!(registry.entries.len(), 1);
98+
let entry = &registry.entries[0];
99+
assert_eq!(entry.name, "onnx:all-MiniLM-L6-v2");
100+
assert_eq!(entry.dim, 384);
101+
assert_eq!(entry.format, ModelFormat::Onnx);
102+
}
103+
104+
#[test]
105+
fn test_parse_model_spec_onnx() {
106+
let spec = parse_model_spec("onnx:all-MiniLM-L6-v2");
107+
assert_eq!(spec.format, ModelFormat::Onnx);
108+
assert_eq!(spec.name, "all-MiniLM-L6-v2");
109+
assert!(spec.path.is_empty());
110+
}
111+
112+
#[test]
113+
fn test_parse_model_spec_file() {
114+
let spec = parse_model_spec("file:/path/to/model.onnx");
115+
assert_eq!(spec.format, ModelFormat::File);
116+
assert_eq!(spec.name, "file:/path/to/model.onnx");
117+
assert_eq!(spec.path, "/path/to/model.onnx");
118+
}
119+
120+
#[test]
121+
fn test_parse_model_spec_bare() {
122+
let spec = parse_model_spec("my-custom-model");
123+
assert_eq!(spec.format, ModelFormat::Onnx);
124+
assert_eq!(spec.name, "my-custom-model");
125+
assert!(spec.path.is_empty());
126+
}
127+
128+
#[test]
129+
fn test_registry_get_existing() {
130+
let registry = ModelRegistry::default();
131+
let entry = registry.get("onnx:all-MiniLM-L6-v2");
132+
assert!(entry.is_some());
133+
assert_eq!(entry.unwrap().dim, 384);
134+
}
135+
136+
#[test]
137+
fn test_registry_get_missing() {
138+
let registry = ModelRegistry::default();
139+
assert!(registry.get("nonexistent-model").is_none());
140+
}
141+
}

0 commit comments

Comments
 (0)