Skip to content

Commit 5995107

Browse files
committed
Added sparse search example [skip ci]
1 parent cbe9e7b commit 5995107

3 files changed

Lines changed: 150 additions & 0 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ Or check out some examples:
3939
- [Binary embeddings](examples/cohere/src/main/java/com/example/Example.java) with Cohere
4040
- [Sentence embeddings](examples/djl/src/main/java/com/example/Example.java) with Deep Java Library
4141
- [Hybrid search](examples/hybrid/src/main/java/com/example/Example.java) with Deep Java Library (Reciprocal Rank Fusion)
42+
- [Sparse search](examples/sparse/src/main/java/com/example/Example.java) with Text Embeddings Inference
4243
- [Extended-connectivity fingerprints](examples/cdk/src/main/java/com/example/Example.java) with the Chemistry Development Kit
4344
- [Horizontal scaling](examples/citus/src/main/java/com/example/Example.java) with Citus
4445
- [Bulk loading](examples/loading/src/main/java/com/example/Example.java) with `COPY`

examples/sparse/pom.xml

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
<?xml version='1.0' encoding='UTF-8'?>
2+
<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://maven.apache.org/POM/4.0.0">
3+
<modelVersion>4.0.0</modelVersion>
4+
<groupId>com.example</groupId>
5+
<artifactId>example</artifactId>
6+
<version>1</version>
7+
<properties>
8+
<maven.compiler.release>11</maven.compiler.release>
9+
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
10+
</properties>
11+
<dependencies>
12+
<dependency>
13+
<groupId>org.postgresql</groupId>
14+
<artifactId>postgresql</artifactId>
15+
<version>42.7.3</version>
16+
</dependency>
17+
<dependency>
18+
<groupId>com.pgvector</groupId>
19+
<artifactId>pgvector</artifactId>
20+
<version>0.1.6</version>
21+
</dependency>
22+
<dependency>
23+
<groupId>com.fasterxml.jackson.core</groupId>
24+
<artifactId>jackson-databind</artifactId>
25+
<version>2.16.0</version>
26+
</dependency>
27+
</dependencies>
28+
<build>
29+
<plugins>
30+
<plugin>
31+
<artifactId>maven-assembly-plugin</artifactId>
32+
<version>3.7.1</version>
33+
<configuration>
34+
<descriptorRefs>
35+
<descriptorRef>jar-with-dependencies</descriptorRef>
36+
</descriptorRefs>
37+
<archive>
38+
<manifest>
39+
<mainClass>com.example.Example</mainClass>
40+
</manifest>
41+
</archive>
42+
<finalName>example</finalName>
43+
</configuration>
44+
<executions>
45+
<execution>
46+
<id>make-assembly</id>
47+
<phase>package</phase>
48+
<goals>
49+
<goal>single</goal>
50+
</goals>
51+
</execution>
52+
</executions>
53+
</plugin>
54+
</plugins>
55+
</build>
56+
</project>
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package com.example;
2+
3+
import java.io.IOException;
4+
import java.net.URI;
5+
import java.net.http.HttpClient;
6+
import java.net.http.HttpRequest;
7+
import java.net.http.HttpRequest.BodyPublishers;
8+
import java.net.http.HttpResponse;
9+
import java.net.http.HttpResponse.BodyHandlers;
10+
import java.sql.Connection;
11+
import java.sql.DriverManager;
12+
import java.sql.PreparedStatement;
13+
import java.sql.ResultSet;
14+
import java.sql.SQLException;
15+
import java.sql.Statement;
16+
import java.util.ArrayList;
17+
import java.util.HashMap;
18+
import java.util.List;
19+
import java.util.Map;
20+
import com.fasterxml.jackson.databind.ObjectMapper;
21+
import com.fasterxml.jackson.databind.JsonNode;
22+
import com.fasterxml.jackson.databind.node.ObjectNode;
23+
import com.pgvector.PGsparsevec;
24+
import com.pgvector.PGvector;
25+
26+
public class Example {
27+
public static void main(String[] args) throws IOException, InterruptedException, SQLException {
28+
Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_example");
29+
30+
Statement setupStmt = conn.createStatement();
31+
setupStmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
32+
setupStmt.executeUpdate("DROP TABLE IF EXISTS documents");
33+
34+
PGvector.addVectorType(conn);
35+
36+
Statement createStmt = conn.createStatement();
37+
createStmt.executeUpdate("CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding sparsevec(30522))");
38+
39+
String[] input = {
40+
"The dog is barking",
41+
"The cat is purring",
42+
"The bear is growling"
43+
};
44+
List<Map<Integer, Float>> embeddings = fetchEmbeddings(input);
45+
46+
for (int i = 0; i < input.length; i++) {
47+
PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO documents (content, embedding) VALUES (?, ?)");
48+
insertStmt.setString(1, input[i]);
49+
insertStmt.setObject(2, new PGsparsevec(embeddings.get(i), 30522));
50+
insertStmt.executeUpdate();
51+
}
52+
53+
String query = "forest";
54+
Map<Integer, Float> queryEmbedding = fetchEmbeddings(new String[] { query }).get(0);
55+
PreparedStatement neighborStmt = conn.prepareStatement("SELECT content FROM documents ORDER BY embedding <#> ? LIMIT 5");
56+
neighborStmt.setObject(1, new PGsparsevec(queryEmbedding, 30522));
57+
ResultSet rs = neighborStmt.executeQuery();
58+
while (rs.next()) {
59+
System.out.println(rs.getString("content"));
60+
}
61+
62+
conn.close();
63+
}
64+
65+
private static List<Map<Integer, Float>> fetchEmbeddings(String[] inputs) throws IOException, InterruptedException {
66+
ObjectMapper mapper = new ObjectMapper();
67+
ObjectNode root = mapper.createObjectNode();
68+
for (String v : inputs) {
69+
root.withArray("inputs").add(v);
70+
}
71+
String json = mapper.writeValueAsString(root);
72+
73+
HttpClient client = HttpClient.newHttpClient();
74+
HttpRequest request = HttpRequest.newBuilder()
75+
.uri(URI.create("http://localhost:3000/embed_sparse"))
76+
.header("Content-Type", "application/json")
77+
.POST(BodyPublishers.ofString(json))
78+
.build();
79+
HttpResponse<String> response = client.send(request, BodyHandlers.ofString());
80+
81+
List<Map<Integer, Float>> embeddings = new ArrayList<>();
82+
for (JsonNode n : mapper.readTree(response.body())) {
83+
Map<Integer, Float> embedding = new HashMap<Integer, Float>();
84+
for (JsonNode v : n) {
85+
int index = v.get("index").asInt();
86+
float value = (float) v.get("value").asDouble();
87+
embedding.put(Integer.valueOf(index), Float.valueOf(value));
88+
}
89+
embeddings.add(embedding);
90+
}
91+
return embeddings;
92+
}
93+
}

0 commit comments

Comments
 (0)