Skip to content

Commit 09cb16f

Browse files
committed
Improved examples [skip ci]
1 parent b8458d5 commit 09cb16f

4 files changed

Lines changed: 16 additions & 15 deletions

File tree

examples/cohere/src/main/java/com/example/Example.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ public static void main(String[] args) throws IOException, InterruptedException,
4242
"The cat is purring",
4343
"The bear is growling"
4444
};
45-
List<byte[]> embeddings = fetchEmbeddings(input, "search_document", apiKey);
45+
List<byte[]> embeddings = embed(input, "search_document", apiKey);
46+
4647
for (int i = 0; i < input.length; i++) {
4748
PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO documents (content, embedding) VALUES (?, ?)");
4849
insertStmt.setString(1, input[i]);
@@ -51,7 +52,7 @@ public static void main(String[] args) throws IOException, InterruptedException,
5152
}
5253

5354
String query = "forest";
54-
byte[] queryEmbedding = fetchEmbeddings(new String[] {query}, "search_query", apiKey).get(0);
55+
byte[] queryEmbedding = embed(new String[] {query}, "search_query", apiKey).get(0);
5556
PreparedStatement neighborStmt = conn.prepareStatement("SELECT * FROM documents ORDER BY embedding <~> ? LIMIT 5");
5657
neighborStmt.setObject(1, new PGbit(queryEmbedding));
5758
ResultSet rs = neighborStmt.executeQuery();
@@ -63,7 +64,7 @@ public static void main(String[] args) throws IOException, InterruptedException,
6364
}
6465

6566
// https://docs.cohere.com/reference/embed
66-
private static List<byte[]> fetchEmbeddings(String[] texts, String inputType, String apiKey) throws IOException, InterruptedException {
67+
private static List<byte[]> embed(String[] texts, String inputType, String apiKey) throws IOException, InterruptedException {
6768
ObjectMapper mapper = new ObjectMapper();
6869
ObjectNode root = mapper.createObjectNode();
6970
for (String v : texts) {

examples/hybrid/src/main/java/com/example/Example.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ public static void main(String[] args) throws IOException, ModelException, SQLEx
3737
"The cat is purring",
3838
"The bear is growling"
3939
};
40-
List<float[]> embeddings = generateEmbeddings(model, input);
40+
List<float[]> embeddings = embed(model, input);
4141

4242
for (int i = 0; i < input.length; i++) {
4343
PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO documents (content, embedding) VALUES (?, ?)");
@@ -47,7 +47,7 @@ public static void main(String[] args) throws IOException, ModelException, SQLEx
4747
}
4848

4949
String query = "growling bear";
50-
float[] queryEmbedding = generateEmbeddings(model, new String[] {query}).get(0);
50+
float[] queryEmbedding = embed(model, new String[] {query}).get(0);
5151
double k = 60;
5252

5353
PreparedStatement queryStmt = conn.prepareStatement(HYBRID_SQL);
@@ -74,7 +74,7 @@ private static ZooModel<String, float[]> loadModel(String id) throws IOException
7474
.loadModel();
7575
}
7676

77-
private static List<float[]> generateEmbeddings(ZooModel<String, float[]> model, String[] input) throws TranslateException {
77+
private static List<float[]> embed(ZooModel<String, float[]> model, String[] input) throws TranslateException {
7878
Predictor<String, float[]> predictor = model.newPredictor();
7979
List<float[]> embeddings = new ArrayList<>(input.length);
8080
for (String text : input) {

examples/openai/src/main/java/com/example/Example.java

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ public static void main(String[] args) throws IOException, InterruptedException,
4444
"The cat is purring",
4545
"The bear is growling"
4646
};
47-
List<float[]> embeddings = fetchEmbeddings(input, apiKey);
47+
List<float[]> embeddings = embed(input, apiKey);
4848

4949
for (int i = 0; i < input.length; i++) {
5050
PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO documents (content, embedding) VALUES (?, ?)");
@@ -53,10 +53,10 @@ public static void main(String[] args) throws IOException, InterruptedException,
5353
insertStmt.executeUpdate();
5454
}
5555

56-
long documentId = 2;
57-
PreparedStatement neighborStmt = conn.prepareStatement("SELECT * FROM documents WHERE id != ? ORDER BY embedding <=> (SELECT embedding FROM documents WHERE id = ?) LIMIT 5");
58-
neighborStmt.setObject(1, documentId);
59-
neighborStmt.setObject(2, documentId);
56+
String query = "forest";
57+
float[] queryEmbedding = embed(new String[] {query}, apiKey).get(0);
58+
PreparedStatement neighborStmt = conn.prepareStatement("SELECT content FROM documents ORDER BY embedding <=> ? LIMIT 5");
59+
neighborStmt.setObject(1, new PGvector(queryEmbedding));
6060
ResultSet rs = neighborStmt.executeQuery();
6161
while (rs.next()) {
6262
System.out.println(rs.getString("content"));
@@ -65,7 +65,7 @@ public static void main(String[] args) throws IOException, InterruptedException,
6565
conn.close();
6666
}
6767

68-
private static List<float[]> fetchEmbeddings(String[] input, String apiKey) throws IOException, InterruptedException {
68+
private static List<float[]> embed(String[] input, String apiKey) throws IOException, InterruptedException {
6969
ObjectMapper mapper = new ObjectMapper();
7070
ObjectNode root = mapper.createObjectNode();
7171
for (String v : input) {

examples/sparse/src/main/java/com/example/Example.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public static void main(String[] args) throws IOException, InterruptedException,
4848
"The cat is purring",
4949
"The bear is growling"
5050
};
51-
List<Map<Integer, Float>> embeddings = fetchEmbeddings(input);
51+
List<Map<Integer, Float>> embeddings = embed(input);
5252

5353
for (int i = 0; i < input.length; i++) {
5454
PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO documents (content, embedding) VALUES (?, ?)");
@@ -58,7 +58,7 @@ public static void main(String[] args) throws IOException, InterruptedException,
5858
}
5959

6060
String query = "forest";
61-
Map<Integer, Float> queryEmbedding = fetchEmbeddings(new String[] { query }).get(0);
61+
Map<Integer, Float> queryEmbedding = embed(new String[] { query }).get(0);
6262
PreparedStatement neighborStmt = conn.prepareStatement("SELECT content FROM documents ORDER BY embedding <#> ? LIMIT 5");
6363
neighborStmt.setObject(1, new PGsparsevec(queryEmbedding, 30522));
6464
ResultSet rs = neighborStmt.executeQuery();
@@ -69,7 +69,7 @@ public static void main(String[] args) throws IOException, InterruptedException,
6969
conn.close();
7070
}
7171

72-
private static List<Map<Integer, Float>> fetchEmbeddings(String[] inputs) throws IOException, InterruptedException {
72+
private static List<Map<Integer, Float>> embed(String[] inputs) throws IOException, InterruptedException {
7373
ObjectMapper mapper = new ObjectMapper();
7474
ObjectNode root = mapper.createObjectNode();
7575
for (String v : inputs) {

0 commit comments

Comments
 (0)