Skip to content

Commit 5b3bef1

Browse files
committed
Added Cohere example [skip ci]
1 parent 5354be1 commit 5b3bef1

2 files changed

Lines changed: 102 additions & 0 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ And follow the instructions for your database library:
3636
Or check out an example:
3737

3838
- [Embeddings](src/test/java/com/pgvector/OpenAITest.java) with OpenAI
39+
- [Binary embeddings](src/test/java/com/pgvector/CohereTest.java) with Cohere
3940

4041
## JDBC (Java)
4142

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
package com.pgvector;
2+
3+
import java.io.IOException;
4+
import java.net.URI;
5+
import java.net.http.HttpClient;
6+
import java.net.http.HttpRequest;
7+
import java.net.http.HttpRequest.BodyPublishers;
8+
import java.net.http.HttpResponse;
9+
import java.net.http.HttpResponse.BodyHandlers;
10+
import java.sql.Connection;
11+
import java.sql.DriverManager;
12+
import java.sql.PreparedStatement;
13+
import java.sql.ResultSet;
14+
import java.sql.SQLException;
15+
import java.sql.Statement;
16+
import java.util.ArrayList;
17+
import java.util.List;
18+
import com.fasterxml.jackson.databind.ObjectMapper;
19+
import com.fasterxml.jackson.databind.JsonNode;
20+
import com.fasterxml.jackson.databind.node.ObjectNode;
21+
import com.pgvector.PGvector;
22+
import org.postgresql.PGConnection;
23+
import org.junit.jupiter.api.Test;
24+
25+
public class CohereTest {
26+
@Test
27+
void example() throws IOException, InterruptedException, SQLException {
28+
String apiKey = System.getenv("CO_API_KEY");
29+
if (apiKey == null) {
30+
return;
31+
}
32+
33+
Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_example");
34+
35+
Statement setupStmt = conn.createStatement();
36+
setupStmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
37+
setupStmt.executeUpdate("DROP TABLE IF EXISTS documents");
38+
39+
PGvector.addVectorType(conn);
40+
41+
Statement createStmt = conn.createStatement();
42+
createStmt.executeUpdate("CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding bit(1024))");
43+
44+
String[] input = {
45+
"The dog is barking",
46+
"The cat is purring",
47+
"The bear is growling"
48+
};
49+
List<byte[]> embeddings = fetchEmbeddings(input, "search_document", apiKey);
50+
for (int i = 0; i < input.length; i++) {
51+
PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO documents (content, embedding) VALUES (?, ?)");
52+
insertStmt.setString(1, input[i]);
53+
insertStmt.setObject(2, new PGbit(embeddings.get(i)));
54+
insertStmt.executeUpdate();
55+
}
56+
57+
String query = "forest";
58+
byte[] queryEmbedding = fetchEmbeddings(new String[] {query}, "search_query", apiKey).get(0);
59+
PreparedStatement neighborStmt = conn.prepareStatement("SELECT * FROM documents ORDER BY embedding <~> ? LIMIT 5");
60+
neighborStmt.setObject(1, new PGbit(queryEmbedding));
61+
ResultSet rs = neighborStmt.executeQuery();
62+
while (rs.next()) {
63+
System.out.println(rs.getString("content"));
64+
}
65+
66+
conn.close();
67+
}
68+
69+
// https://docs.cohere.com/reference/embed
70+
private List<byte[]> fetchEmbeddings(String[] texts, String inputType, String apiKey) throws IOException, InterruptedException {
71+
ObjectMapper mapper = new ObjectMapper();
72+
ObjectNode root = mapper.createObjectNode();
73+
for (String v : texts) {
74+
root.withArray("texts").add(v);
75+
}
76+
root.put("model", "embed-english-v3.0");
77+
root.put("input_type", inputType);
78+
root.withArray("embedding_types").add("ubinary");
79+
String json = mapper.writeValueAsString(root);
80+
81+
HttpClient client = HttpClient.newHttpClient();
82+
HttpRequest request = HttpRequest.newBuilder()
83+
.uri(URI.create("https://api.cohere.com/v1/embed"))
84+
.header("Authorization", "Bearer " + apiKey)
85+
.header("Content-Type", "application/json")
86+
.POST(BodyPublishers.ofString(json))
87+
.build();
88+
HttpResponse<String> response = client.send(request, BodyHandlers.ofString());
89+
90+
List<byte[]> embeddings = new ArrayList<>();
91+
for (JsonNode n : mapper.readTree(response.body()).get("embeddings").get("ubinary")) {
92+
byte[] embedding = new byte[n.size()];
93+
int i = 0;
94+
for (JsonNode v : n) {
95+
embedding[i++] = (byte) v.asDouble();
96+
}
97+
embeddings.add(embedding);
98+
}
99+
return embeddings;
100+
}
101+
}

0 commit comments

Comments
 (0)