Skip to content

Commit 47bdc68

Browse files
committed
Added Disco example [skip ci]
1 parent 3d5c002 commit 47bdc68

3 files changed

Lines changed: 126 additions & 0 deletions

File tree

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ Or check out some examples:
4141
- [Hybrid search](examples/hybrid/src/main/java/com/example/Example.java) with Deep Java Library (Reciprocal Rank Fusion)
4242
- [Sparse search](examples/sparse/src/main/java/com/example/Example.java) with Text Embeddings Inference
4343
- [Extended-connectivity fingerprints](examples/cdk/src/main/java/com/example/Example.java) with the Chemistry Development Kit
44+
- [Recommendations](examples/disco/src/main/java/com/example/Example.java) with Disco
4445
- [Horizontal scaling](examples/citus/src/main/java/com/example/Example.java) with Citus
4546
- [Bulk loading](examples/loading/src/main/java/com/example/Example.java) with `COPY`
4647

examples/disco/pom.xml

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
<?xml version='1.0' encoding='UTF-8'?>
2+
<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns="http://maven.apache.org/POM/4.0.0">
3+
<modelVersion>4.0.0</modelVersion>
4+
<groupId>com.example</groupId>
5+
<artifactId>example</artifactId>
6+
<version>1</version>
7+
<properties>
8+
<maven.compiler.release>11</maven.compiler.release>
9+
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
10+
</properties>
11+
<dependencies>
12+
<dependency>
13+
<groupId>org.postgresql</groupId>
14+
<artifactId>postgresql</artifactId>
15+
<version>42.7.3</version>
16+
</dependency>
17+
<dependency>
18+
<groupId>com.pgvector</groupId>
19+
<artifactId>pgvector</artifactId>
20+
<version>0.1.6</version>
21+
</dependency>
22+
<dependency>
23+
<groupId>org.ankane</groupId>
24+
<artifactId>disco</artifactId>
25+
<version>0.1.0</version>
26+
</dependency>
27+
</dependencies>
28+
<build>
29+
<plugins>
30+
<plugin>
31+
<artifactId>maven-assembly-plugin</artifactId>
32+
<version>3.7.1</version>
33+
<configuration>
34+
<descriptorRefs>
35+
<descriptorRef>jar-with-dependencies</descriptorRef>
36+
</descriptorRefs>
37+
<archive>
38+
<manifest>
39+
<mainClass>com.example.Example</mainClass>
40+
</manifest>
41+
</archive>
42+
<finalName>example</finalName>
43+
</configuration>
44+
<executions>
45+
<execution>
46+
<id>make-assembly</id>
47+
<phase>package</phase>
48+
<goals>
49+
<goal>single</goal>
50+
</goals>
51+
</execution>
52+
</executions>
53+
</plugin>
54+
</plugins>
55+
</build>
56+
</project>
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
package com.example;
2+
3+
import java.sql.Connection;
4+
import java.sql.DriverManager;
5+
import java.sql.PreparedStatement;
6+
import java.sql.ResultSet;
7+
import java.sql.SQLException;
8+
import java.sql.Statement;
9+
import java.util.ArrayList;
10+
import com.pgvector.PGvector;
11+
import org.ankane.disco.Data;
12+
import org.ankane.disco.Dataset;
13+
import org.ankane.disco.Recommender;
14+
15+
public class Example {
16+
public static void main(String[] args) throws Exception {
17+
Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_example");
18+
Statement setupStmt = conn.createStatement();
19+
setupStmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
20+
PGvector.addVectorType(conn);
21+
22+
Statement createStmt = conn.createStatement();
23+
createStmt.executeUpdate("DROP TABLE IF EXISTS users");
24+
createStmt.executeUpdate("DROP TABLE IF EXISTS movies");
25+
createStmt.executeUpdate("CREATE TABLE users (id integer PRIMARY KEY, factors vector(20))");
26+
createStmt.executeUpdate("CREATE TABLE movies (name text PRIMARY KEY, factors vector(20))");
27+
28+
Dataset<Integer, String> data = Data.loadMovieLens();
29+
Recommender<Integer, String> recommender = Recommender
30+
.builder()
31+
.factors(20)
32+
.fitExplicit(data);
33+
34+
for (Integer userId : recommender.userIds()) {
35+
PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO users (id, factors) VALUES (?, ?)");
36+
insertStmt.setInt(1, userId);
37+
insertStmt.setObject(2, new PGvector(recommender.userFactors(userId).get()));
38+
insertStmt.executeUpdate();
39+
}
40+
41+
for (String itemId : recommender.itemIds()) {
42+
PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO movies (name, factors) VALUES (?, ?)");
43+
insertStmt.setString(1, itemId);
44+
insertStmt.setObject(2, new PGvector(recommender.itemFactors(itemId).get()));
45+
insertStmt.executeUpdate();
46+
}
47+
48+
String movie = "Star Wars (1977)";
49+
System.out.printf("Item-based recommendations for %s\n", movie);
50+
PreparedStatement neighborStmt = conn.prepareStatement("SELECT name FROM movies WHERE name != ? ORDER BY factors <=> (SELECT factors FROM movies WHERE name = ?) LIMIT 5");
51+
neighborStmt.setString(1, movie);
52+
neighborStmt.setString(2, movie);
53+
ResultSet rs = neighborStmt.executeQuery();
54+
while (rs.next()) {
55+
System.out.println("- " + rs.getString("name"));
56+
}
57+
58+
int userId = 123;
59+
System.out.printf("\nUser-based recommendations for user %d\n", userId);
60+
neighborStmt = conn.prepareStatement("SELECT name FROM movies ORDER BY factors <#> (SELECT factors FROM users WHERE id = ?) LIMIT 5");
61+
neighborStmt.setInt(1, userId);
62+
rs = neighborStmt.executeQuery();
63+
while (rs.next()) {
64+
System.out.println("- " + rs.getString("name"));
65+
}
66+
67+
conn.close();
68+
}
69+
}

0 commit comments

Comments
 (0)