Skip to content

Commit 09253f9

Browse files
committed
Added Citus example [skip ci]
1 parent 999a79a commit 09253f9

3 files changed

Lines changed: 112 additions & 0 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Or check out some examples:
2121
- [Sentence embeddings](https://github.com/pgvector/pgvector-rust/blob/master/examples/candle/src/main.rs) with Candle
2222
- [Hybrid search](https://github.com/pgvector/pgvector-rust/blob/master/examples/hybrid_search/src/main.rs) with Candle (Reciprocal Rank Fusion)
2323
- [Recommendations](https://github.com/pgvector/pgvector-rust/blob/master/examples/disco/src/main.rs) with Disco
24+
- [Horizontal scaling](https://github.com/pgvector/pgvector-rust/blob/master/examples/citus/src/main.rs) with Citus
2425
- [Bulk loading](https://github.com/pgvector/pgvector-rust/blob/master/examples/loading/src/main.rs) with `COPY`
2526

2627
## Rust-Postgres

examples/citus/Cargo.toml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
[package]
2+
name = "example"
3+
version = "0.1.0"
4+
edition = "2021"
5+
publish = false
6+
7+
[dependencies]
8+
pgvector = { path = "../..", features = ["postgres"] }
9+
postgres = "0.19"
10+
rand = "0.8"
11+
12+
[profile.dev]
13+
opt-level = 1

examples/citus/src/main.rs

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
use pgvector::Vector;
2+
use postgres::binary_copy::BinaryCopyInWriter;
3+
use postgres::types::{Kind, Type};
4+
use postgres::{Client, NoTls};
5+
use rand::Rng;
6+
use std::error::Error;
7+
8+
fn main() -> Result<(), Box<dyn Error>> {
9+
// generate random data
10+
let rows = 100000;
11+
let dimensions = 128;
12+
let mut rng = rand::thread_rng();
13+
let embeddings: Vec<Vec<f32>> = (0..rows)
14+
.map(|_| (0..dimensions).map(|_| rng.gen()).collect())
15+
.collect();
16+
let categories: Vec<i64> = (0..rows).map(|_| rng.gen_range(1..=100)).collect();
17+
let queries: Vec<Vec<f32>> = (0..10)
18+
.map(|_| (0..dimensions).map(|_| rng.gen()).collect())
19+
.collect();
20+
21+
// enable extensions
22+
let mut client = Client::configure()
23+
.host("localhost")
24+
.dbname("pgvector_citus")
25+
.user(std::env::var("USER")?.as_str())
26+
.connect(NoTls)?;
27+
client.execute("CREATE EXTENSION IF NOT EXISTS citus", &[])?;
28+
client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?;
29+
30+
// GUC variables set on the session do not propagate to Citus workers
31+
// https://github.com/citusdata/citus/issues/462
32+
// you can either:
33+
// 1. set them on the system, user, or database and reconnect
34+
// 2. set them for a transaction with SET LOCAL
35+
client.execute(
36+
"ALTER DATABASE pgvector_citus SET maintenance_work_mem = '512MB'",
37+
&[],
38+
)?;
39+
client.execute("ALTER DATABASE pgvector_citus SET hnsw.ef_search = 20", &[])?;
40+
client.close()?;
41+
42+
// reconnect for updated GUC variables to take effect
43+
let mut client = Client::configure()
44+
.host("localhost")
45+
.dbname("pgvector_citus")
46+
.user(std::env::var("USER")?.as_str())
47+
.connect(NoTls)?;
48+
49+
println!("Creating distributed table");
50+
client.execute("DROP TABLE IF EXISTS items", &[])?;
51+
client.execute(
52+
&format!("CREATE TABLE items (id bigserial, embedding vector({dimensions}), category_id bigint, PRIMARY KEY (id, category_id))"),
53+
&[],
54+
)?;
55+
client.execute("SET citus.shard_count = 4", &[])?;
56+
client.execute(
57+
"SELECT create_distributed_table('items', 'category_id')",
58+
&[],
59+
)?;
60+
61+
println!("Loading data in parallel");
62+
let vector_type = get_type(&mut client, "vector")?;
63+
let writer =
64+
client.copy_in("COPY items (embedding, category_id) FROM STDIN WITH (FORMAT BINARY)")?;
65+
let mut writer = BinaryCopyInWriter::new(writer, &[vector_type, Type::INT8]);
66+
for (embedding, category) in embeddings.into_iter().zip(categories) {
67+
writer.write(&[&Vector::from(embedding), &category])?;
68+
}
69+
writer.finish()?;
70+
71+
println!("Creating index in parallel");
72+
client.execute(
73+
"CREATE INDEX ON items USING hnsw (embedding vector_l2_ops)",
74+
&[],
75+
)?;
76+
77+
println!("Running distributed queries");
78+
for query in queries {
79+
let rows = client.query(
80+
"SELECT id FROM items ORDER BY embedding <-> $1 LIMIT 10",
81+
&[&Vector::from(query)],
82+
)?;
83+
let ids: Vec<i64> = rows.into_iter().map(|row| row.get(0)).collect();
84+
println!("{:?}", ids);
85+
}
86+
87+
Ok(())
88+
}
89+
90+
fn get_type(client: &mut Client, name: &str) -> Result<Type, Box<dyn Error>> {
91+
let row = client.query_one("SELECT pg_type.oid, nspname AS schema FROM pg_type INNER JOIN pg_namespace ON pg_namespace.oid = pg_type.typnamespace WHERE typname = $1", &[&name])?;
92+
Ok(Type::new(
93+
name.into(),
94+
row.get("oid"),
95+
Kind::Simple,
96+
row.get("schema"),
97+
))
98+
}

0 commit comments

Comments
 (0)