Skip to content

Commit 9a2b71f

Browse files
committed
Added Citus example [skip ci]
1 parent 612df9a commit 9a2b71f

2 files changed

Lines changed: 97 additions & 0 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ Or check out some examples:
3737

3838
- [Embeddings](src/test/java/com/pgvector/OpenAITest.java) with OpenAI
3939
- [Binary embeddings](src/test/java/com/pgvector/CohereTest.java) with Cohere
40+
- [Horizontal scaling](src/test/java/com/pgvector/CitusTest.java) with Citus
4041
- [Bulk loading](src/test/java/com/pgvector/LoadingTest.java) with `COPY`
4142

4243
## JDBC (Java)
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
package com.pgvector;
2+
3+
import java.io.UnsupportedEncodingException;
4+
import java.sql.Connection;
5+
import java.sql.DriverManager;
6+
import java.sql.PreparedStatement;
7+
import java.sql.ResultSet;
8+
import java.sql.SQLException;
9+
import java.sql.Statement;
10+
import java.util.ArrayList;
11+
import java.util.Random;
12+
import com.pgvector.PGvector;
13+
import org.postgresql.PGConnection;
14+
import org.postgresql.copy.CopyIn;
15+
import org.postgresql.copy.CopyManager;
16+
import org.postgresql.core.BaseConnection;
17+
import org.junit.jupiter.api.Test;
18+
19+
public class CitusTest {
20+
@Test
21+
void example() throws SQLException {
22+
if (System.getenv("TEST_CITUS") == null) {
23+
return;
24+
}
25+
26+
// generate data
27+
int rows = 1000000;
28+
int dimensions = 128;
29+
ArrayList<float[]> embeddings = new ArrayList<>(rows);
30+
ArrayList<Integer> categories = new ArrayList<>(rows);
31+
Random rnd = new Random();
32+
for (int i = 0; i < rows; i++) {
33+
float[] embedding = new float[dimensions];
34+
for (int j = 0; j < dimensions; j++) {
35+
embedding[j] = (float) Math.random();
36+
}
37+
embeddings.add(embedding);
38+
categories.add(rnd.nextInt(100));
39+
}
40+
41+
// enable extensions
42+
Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_citus");
43+
Statement setupStmt = conn.createStatement();
44+
setupStmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS citus");
45+
setupStmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
46+
47+
// GUC variables set on the session do not propagate to Citus workers
48+
// https://github.com/citusdata/citus/issues/462
49+
// you can either:
50+
// 1. set them on the system, user, or database and reconnect
51+
// 2. set them for a transaction with SET LOCAL
52+
setupStmt.executeUpdate("ALTER DATABASE pgvector_citus SET maintenance_work_mem = '512MB'");
53+
setupStmt.executeUpdate("ALTER DATABASE pgvector_citus SET hnsw.ef_search = 20");
54+
conn.close();
55+
56+
// reconnect for updated GUC variables to take effect
57+
conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_citus");
58+
PGvector.addVectorType(conn);
59+
60+
System.out.println("Creating distributed table");
61+
setupStmt = conn.createStatement();
62+
setupStmt.executeUpdate("DROP TABLE IF EXISTS items");
63+
setupStmt.executeUpdate(String.format("CREATE TABLE items (id bigserial, embedding vector(%d), category_id bigint, PRIMARY KEY (id, category_id))", dimensions));
64+
setupStmt.executeUpdate("SET citus.shard_count = 4");
65+
setupStmt.executeQuery("SELECT create_distributed_table('items', 'category_id')");
66+
67+
System.out.println("Loading data in parallel");
68+
CopyManager copyManager = new CopyManager((BaseConnection) conn);
69+
// TODO use binary format
70+
CopyIn copyIn = copyManager.copyIn("COPY items (embedding, category_id) FROM STDIN");
71+
for (int i = 0; i < rows; i++) {
72+
PGvector embedding = new PGvector(embeddings.get(i));
73+
byte[] bytes = String.format("%s\t%d\n", embedding.getValue(), categories.get(i)).getBytes();
74+
copyIn.writeToCopy(bytes, 0, bytes.length);
75+
}
76+
copyIn.endCopy();
77+
78+
System.out.println("Creating index in parallel");
79+
Statement createIndexStmt = conn.createStatement();
80+
createIndexStmt.executeUpdate("CREATE INDEX ON items USING hnsw (embedding vector_l2_ops)");
81+
82+
System.out.println("Running distributed queries");
83+
for (int i = 0; i < 10; i++) {
84+
PreparedStatement queryStmt = conn.prepareStatement("SELECT id FROM items ORDER BY embedding <-> ? LIMIT 10");
85+
queryStmt.setObject(1, new PGvector(embeddings.get(rnd.nextInt(rows))));
86+
ResultSet rs = queryStmt.executeQuery();
87+
ArrayList<Long> ids = new ArrayList<>();
88+
while (rs.next()) {
89+
ids.add(rs.getLong("id"));
90+
}
91+
System.out.println(ids.toString());
92+
}
93+
94+
conn.close();
95+
}
96+
}

0 commit comments

Comments
 (0)