Skip to content

Commit c8ea95d

Browse files
ankanetkwlsrl
andcommitted
Added map constructor for PGsparsevec - #10
Co-authored-by: tkwlsrl <chan@ggaman.com>
1 parent 6352966 commit c8ea95d

2 files changed

Lines changed: 46 additions & 0 deletions

File tree

src/main/java/com/pgvector/PGsparsevec.java

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import java.sql.Connection;
55
import java.sql.SQLException;
66
import java.util.Arrays;
7+
import java.util.ArrayList;
78
import java.util.List;
9+
import java.util.Map;
810
import java.util.Objects;
911
import org.postgresql.PGConnection;
1012
import org.postgresql.util.ByteConverter;
@@ -92,6 +94,36 @@ public <T extends Number> PGsparsevec(List<T> v) {
9294
}
9395
}
9496

97+
/**
98+
* Constructor
99+
*
100+
* @param map <Integer, T> map of non-zero elements
101+
* @param dimensions number of dimensions
102+
*/
103+
public <T extends Number> PGsparsevec(Map<Integer, T> map, int dimensions) {
104+
this();
105+
106+
ArrayList<Map.Entry<Integer, T>> elements = new ArrayList<Map.Entry<Integer, T>>();
107+
if (!Objects.isNull(map)) {
108+
elements.addAll(map.entrySet());
109+
}
110+
elements.removeIf((e) -> e.getValue().floatValue() == 0);
111+
elements.sort((a, b) -> Integer.compare(a.getKey(), b.getKey()));
112+
113+
int nnz = elements.size();
114+
indices = new int[nnz];
115+
values = new float[nnz];
116+
117+
int i = 0;
118+
for (Map.Entry<Integer, T> e : elements) {
119+
indices[i] = e.getKey().intValue();
120+
values[i] = e.getValue().floatValue();
121+
i++;
122+
}
123+
124+
this.dimensions = dimensions;
125+
}
126+
95127
/**
96128
* Constructor
97129
*

src/test/java/com/pgvector/PGsparsevecTest.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import java.sql.SQLException;
44
import java.util.Arrays;
5+
import java.util.HashMap;
6+
import java.util.Map;
57
import com.pgvector.PGsparsevec;
68
import org.junit.jupiter.api.Test;
79

@@ -35,6 +37,18 @@ void testDoubleListConstructor() {
3537
assertArrayEquals(new float[] {1, 2, 3}, vec.toArray());
3638
}
3739

40+
@Test
41+
void testMapConstructor() {
42+
Map<Integer, Float> map = new HashMap<Integer, Float>();
43+
map.put(Integer.valueOf(2), Float.valueOf(2));
44+
map.put(Integer.valueOf(4), Float.valueOf(3));
45+
map.put(Integer.valueOf(0), Float.valueOf(1));
46+
map.put(Integer.valueOf(3), Float.valueOf(0));
47+
PGsparsevec vec = new PGsparsevec(map, 6);
48+
assertArrayEquals(new float[] {1, 0, 2, 0, 3, 0}, vec.toArray());
49+
assertArrayEquals(new int[] {0, 2, 4}, vec.getIndices());
50+
}
51+
3852
@Test
3953
void testGetValue() {
4054
PGsparsevec vec = new PGsparsevec(new float[] {1, 0, 2, 0, 3, 0});

0 commit comments

Comments
 (0)