|
| 1 | +use pgvector::Vector; |
| 2 | +use postgres::binary_copy::BinaryCopyInWriter; |
| 3 | +use postgres::types::{Kind, Type}; |
| 4 | +use postgres::{Client, NoTls}; |
| 5 | +use rand::Rng; |
| 6 | +use std::error::Error; |
| 7 | + |
| 8 | +fn main() -> Result<(), Box<dyn Error>> { |
| 9 | + // generate random data |
| 10 | + let rows = 100000; |
| 11 | + let dimensions = 128; |
| 12 | + let mut rng = rand::thread_rng(); |
| 13 | + let embeddings: Vec<Vec<f32>> = (0..rows) |
| 14 | + .map(|_| (0..dimensions).map(|_| rng.gen()).collect()) |
| 15 | + .collect(); |
| 16 | + let categories: Vec<i64> = (0..rows).map(|_| rng.gen_range(1..=100)).collect(); |
| 17 | + let queries: Vec<Vec<f32>> = (0..10) |
| 18 | + .map(|_| (0..dimensions).map(|_| rng.gen()).collect()) |
| 19 | + .collect(); |
| 20 | + |
| 21 | + // enable extensions |
| 22 | + let mut client = Client::configure() |
| 23 | + .host("localhost") |
| 24 | + .dbname("pgvector_citus") |
| 25 | + .user(std::env::var("USER")?.as_str()) |
| 26 | + .connect(NoTls)?; |
| 27 | + client.execute("CREATE EXTENSION IF NOT EXISTS citus", &[])?; |
| 28 | + client.execute("CREATE EXTENSION IF NOT EXISTS vector", &[])?; |
| 29 | + |
| 30 | + // GUC variables set on the session do not propagate to Citus workers |
| 31 | + // https://github.com/citusdata/citus/issues/462 |
| 32 | + // you can either: |
| 33 | + // 1. set them on the system, user, or database and reconnect |
| 34 | + // 2. set them for a transaction with SET LOCAL |
| 35 | + client.execute( |
| 36 | + "ALTER DATABASE pgvector_citus SET maintenance_work_mem = '512MB'", |
| 37 | + &[], |
| 38 | + )?; |
| 39 | + client.execute("ALTER DATABASE pgvector_citus SET hnsw.ef_search = 20", &[])?; |
| 40 | + client.close()?; |
| 41 | + |
| 42 | + // reconnect for updated GUC variables to take effect |
| 43 | + let mut client = Client::configure() |
| 44 | + .host("localhost") |
| 45 | + .dbname("pgvector_citus") |
| 46 | + .user(std::env::var("USER")?.as_str()) |
| 47 | + .connect(NoTls)?; |
| 48 | + |
| 49 | + println!("Creating distributed table"); |
| 50 | + client.execute("DROP TABLE IF EXISTS items", &[])?; |
| 51 | + client.execute( |
| 52 | + &format!("CREATE TABLE items (id bigserial, embedding vector({dimensions}), category_id bigint, PRIMARY KEY (id, category_id))"), |
| 53 | + &[], |
| 54 | + )?; |
| 55 | + client.execute("SET citus.shard_count = 4", &[])?; |
| 56 | + client.execute( |
| 57 | + "SELECT create_distributed_table('items', 'category_id')", |
| 58 | + &[], |
| 59 | + )?; |
| 60 | + |
| 61 | + println!("Loading data in parallel"); |
| 62 | + let vector_type = get_type(&mut client, "vector")?; |
| 63 | + let writer = |
| 64 | + client.copy_in("COPY items (embedding, category_id) FROM STDIN WITH (FORMAT BINARY)")?; |
| 65 | + let mut writer = BinaryCopyInWriter::new(writer, &[vector_type, Type::INT8]); |
| 66 | + for (embedding, category) in embeddings.into_iter().zip(categories) { |
| 67 | + writer.write(&[&Vector::from(embedding), &category])?; |
| 68 | + } |
| 69 | + writer.finish()?; |
| 70 | + |
| 71 | + println!("Creating index in parallel"); |
| 72 | + client.execute( |
| 73 | + "CREATE INDEX ON items USING hnsw (embedding vector_l2_ops)", |
| 74 | + &[], |
| 75 | + )?; |
| 76 | + |
| 77 | + println!("Running distributed queries"); |
| 78 | + for query in queries { |
| 79 | + let rows = client.query( |
| 80 | + "SELECT id FROM items ORDER BY embedding <-> $1 LIMIT 10", |
| 81 | + &[&Vector::from(query)], |
| 82 | + )?; |
| 83 | + let ids: Vec<i64> = rows.into_iter().map(|row| row.get(0)).collect(); |
| 84 | + println!("{:?}", ids); |
| 85 | + } |
| 86 | + |
| 87 | + Ok(()) |
| 88 | +} |
| 89 | + |
| 90 | +fn get_type(client: &mut Client, name: &str) -> Result<Type, Box<dyn Error>> { |
| 91 | + let row = client.query_one("SELECT pg_type.oid, nspname AS schema FROM pg_type INNER JOIN pg_namespace ON pg_namespace.oid = pg_type.typnamespace WHERE typname = $1", &[&name])?; |
| 92 | + Ok(Type::new( |
| 93 | + name.into(), |
| 94 | + row.get("oid"), |
| 95 | + Kind::Simple, |
| 96 | + row.get("schema"), |
| 97 | + )) |
| 98 | +} |
0 commit comments