Skip to content

Commit b7bb910

Browse files
committed
Added support for bit type
1 parent 5e18949 commit b7bb910

4 files changed

Lines changed: 202 additions & 1 deletion

File tree

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
## 0.1.5 (unreleased)
22

3-
- Added support for `halfvec` and `sparsevec` types
3+
- Added support for `halfvec`, `bit`, and `sparsevec` types
44

55
## 0.1.4 (2023-12-08)
66

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package com.pgvector;
2+
3+
import java.io.Serializable;
4+
import java.sql.Connection;
5+
import java.sql.SQLException;
6+
import java.util.Arrays;
7+
import java.util.List;
8+
import java.util.Objects;
9+
import org.postgresql.PGConnection;
10+
import org.postgresql.util.ByteConverter;
11+
import org.postgresql.util.PGBinaryObject;
12+
import org.postgresql.util.PGobject;
13+
14+
/**
15+
* PGbit class
16+
*/
17+
public class PGbit extends PGobject implements PGBinaryObject, Serializable, Cloneable {
18+
private int length;
19+
private byte[] data;
20+
21+
/**
22+
* Constructor
23+
*/
24+
public PGbit() {
25+
type = "bit";
26+
}
27+
28+
/**
29+
* Constructor
30+
*
31+
* @param v boolean array
32+
*/
33+
public PGbit(boolean[] v) {
34+
this();
35+
length = v.length;
36+
data = new byte[(length + 7) / 8];
37+
for (int i = 0; i < length; i++) {
38+
data[i / 8] |= (v[i] ? 1 : 0) << (7 - (i % 8));
39+
}
40+
}
41+
42+
/**
43+
* Constructor
44+
*
45+
* @param s text representation of a bit string
46+
* @throws SQLException exception
47+
*/
48+
public PGbit(String s) throws SQLException {
49+
this();
50+
setValue(s);
51+
}
52+
53+
/**
54+
* Sets the value from a text representation of a bit string
55+
*/
56+
public void setValue(String s) throws SQLException {
57+
if (s == null) {
58+
data = null;
59+
} else {
60+
length = s.length();
61+
data = new byte[(length + 7) / 8];
62+
for (int i = 0; i < length; i++) {
63+
data[i / 8] |= (s.charAt(i) != '0' ? 1 : 0) << (7 - (i % 8));
64+
}
65+
}
66+
}
67+
68+
/**
69+
* Returns the text representation of a bit string
70+
*/
71+
public String getValue() {
72+
if (data == null) {
73+
return null;
74+
} else {
75+
StringBuilder sb = new StringBuilder(length);
76+
for (int i = 0; i < length; i++) {
77+
sb.append(((data[i / 8] >> (7 - (i % 8))) & 1) == 1 ? '1' : '0');
78+
}
79+
return sb.toString();
80+
}
81+
}
82+
83+
/**
84+
* Returns the number of bytes for the binary representation
85+
*/
86+
public int lengthInBytes() {
87+
return data == null ? 0 : 4 + data.length;
88+
}
89+
90+
/**
91+
* Sets the value from a binary representation of a bit string
92+
*/
93+
public void setByteValue(byte[] value, int offset) throws SQLException {
94+
length = ByteConverter.int4(value, offset);
95+
data = new byte[(length + 7) / 8];
96+
for (int i = 0; i < data.length; i++) {
97+
data[i] = value[offset + 4 + i];
98+
}
99+
}
100+
101+
/**
102+
* Writes the binary representation of a bit string
103+
*/
104+
public void toBytes(byte[] bytes, int offset) {
105+
if (data == null) {
106+
return;
107+
}
108+
109+
ByteConverter.int4(bytes, offset, length);
110+
for (int i = 0; i < data.length; i++) {
111+
bytes[offset + 4 + i] = data[i];
112+
}
113+
}
114+
115+
/**
116+
* Registers the bit type
117+
*
118+
* @param conn connection
119+
* @throws SQLException exception
120+
*/
121+
public static void addBitType(Connection conn) throws SQLException {
122+
conn.unwrap(PGConnection.class).addDataType("bit", PGbit.class);
123+
}
124+
}

src/test/java/com/pgvector/JDBCJavaTest.java

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import org.junit.jupiter.api.Test;
1414

1515
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
16+
import static org.junit.jupiter.api.Assertions.assertEquals;
1617
import static org.junit.jupiter.api.Assertions.assertNull;
1718

1819
public class JDBCJavaTest {
@@ -117,6 +118,59 @@ void halfvecExample(boolean readBinary) throws SQLException {
117118
assertNull(embeddings.get(3));
118119
}
119120

121+
@Test
122+
void testBitReadText() throws SQLException {
123+
bitExample(false);
124+
}
125+
126+
@Test
127+
void testBitReadBinary() throws SQLException {
128+
bitExample(true);
129+
}
130+
131+
void bitExample(boolean readBinary) throws SQLException {
132+
Connection conn = DriverManager.getConnection("jdbc:postgresql://localhost:5432/pgvector_java_test");
133+
if (readBinary) {
134+
conn.unwrap(PGConnection.class).setPrepareThreshold(-1);
135+
}
136+
137+
Statement setupStmt = conn.createStatement();
138+
setupStmt.executeUpdate("CREATE EXTENSION IF NOT EXISTS vector");
139+
setupStmt.executeUpdate("DROP TABLE IF EXISTS jdbc_items");
140+
141+
PGbit.addBitType(conn);
142+
143+
Statement createStmt = conn.createStatement();
144+
createStmt.executeUpdate("CREATE TABLE jdbc_items (id bigserial PRIMARY KEY, embedding bit(9))");
145+
146+
PreparedStatement insertStmt = conn.prepareStatement("INSERT INTO jdbc_items (embedding) VALUES (?), (?), (?), (?)");
147+
insertStmt.setObject(1, new PGbit(new boolean[] {false, false, false, false, false, false, false, false, false}));
148+
insertStmt.setObject(2, new PGbit(new boolean[] {false, true, false, true, false, false, false, false, true}));
149+
insertStmt.setObject(3, new PGbit(new boolean[] {false, true, true, true, false, false, false, false, true}));
150+
insertStmt.setObject(4, null);
151+
insertStmt.executeUpdate();
152+
153+
PreparedStatement neighborStmt = conn.prepareStatement("SELECT * FROM jdbc_items ORDER BY embedding <~> ? LIMIT 5");
154+
neighborStmt.setObject(1, new PGbit(new boolean[] {false, true, false, true, false, false, false, false, true}));
155+
ResultSet rs = neighborStmt.executeQuery();
156+
List<Long> ids = new ArrayList<>();
157+
List<PGbit> embeddings = new ArrayList<>();
158+
while (rs.next()) {
159+
ids.add(rs.getLong("id"));
160+
embeddings.add((PGbit) rs.getObject("embedding"));
161+
}
162+
assertArrayEquals(new Long[] {2L, 3L, 1L, 4L}, ids.toArray());
163+
assertEquals("010100001", embeddings.get(0).getValue());
164+
assertEquals("011100001", embeddings.get(1).getValue());
165+
assertEquals("000000000", embeddings.get(2).getValue());
166+
assertNull(embeddings.get(3));
167+
168+
Statement indexStmt = conn.createStatement();
169+
indexStmt.executeUpdate("CREATE INDEX ON jdbc_items USING ivfflat (embedding bit_hamming_ops) WITH (lists = 100)");
170+
171+
conn.close();
172+
}
173+
120174
@Test
121175
void testSparsevecReadText() throws SQLException {
122176
sparsevecExample(false);
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package com.pgvector;
2+
3+
import java.sql.SQLException;
4+
import java.util.Arrays;
5+
import com.pgvector.PGbit;
6+
import org.junit.jupiter.api.Test;
7+
8+
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
9+
import static org.junit.jupiter.api.Assertions.assertEquals;
10+
11+
public class PGbitTest {
12+
@Test
13+
void testArrayConstructor() {
14+
PGbit vec = new PGbit(new boolean[] {true, false, true});
15+
assertEquals("101", vec.getValue());
16+
}
17+
18+
@Test
19+
void testStringConstructor() throws SQLException {
20+
PGbit vec = new PGbit("101");
21+
assertEquals("101", vec.getValue());
22+
}
23+
}

0 commit comments

Comments
 (0)