From 650b4d2de45393319e738d9a763c4a28bd635d64 Mon Sep 17 00:00:00 2001
From: Michael Marshall
Date: Mon, 25 Nov 2024 11:52:41 -0600
Subject: [PATCH] Store compressed vectors in dense ByteSequence for PQVectors

---
 .../github/jbellis/jvector/pq/PQVectors.java  | 102 ++++++++----
 .../jvector/pq/ProductQuantization.java       |   2 +-
 .../jvector/vector/ArrayByteSequence.java     |  13 ++
 .../vector/ArraySliceByteSequence.java        | 148 ++++++++++++++++++
 .../jvector/vector/types/ByteSequence.java    |   4 +
 .../jbellis/jvector/example/SiftSmall.java    |  10 +-
 .../vector/MemorySegmentByteSequence.java     |  15 ++
 .../vector/NativeVectorUtilSupport.java       |   4 +
 .../jvector/pq/TestCompressedVectors.java     |   1 +
 .../vector/PanamaVectorUtilSupport.java       |   2 +-
 .../jbellis/jvector/vector/SimdOps.java       |  21 +--
 11 files changed, 274 insertions(+), 48 deletions(-)
 create mode 100644 jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArraySliceByteSequence.java

diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java
index b58880ab..9d2a688b 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java
@@ -28,34 +28,34 @@
 import java.io.DataOutput;
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
+import java.util.Arrays;
 import java.util.Objects;
 import java.util.concurrent.atomic.AtomicReference;
 
 public class PQVectors implements CompressedVectors
 {
     private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();
+    private static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 16; // standard Java array size limit with some headroom
     final ProductQuantization pq;
-    private final List<ByteSequence<?>> compressedVectors;
+    private final ByteSequence<?>[] compressedDataChunks;
+    private final int vectorCount;
+    private final int vectorsPerChunk;
 
-    /**
-     * Initialize the PQVectors with an initial List of vectors. This list may be
-     * mutated, but caller is responsible for thread safety issues when doing so.
-     */
-    public PQVectors(ProductQuantization pq, List<ByteSequence<?>> compressedVectors)
+    public PQVectors(ProductQuantization pq, ByteSequence<?>[] compressedDataChunks)
     {
-        this.pq = pq;
-        this.compressedVectors = compressedVectors;
+        this(pq, compressedDataChunks, compressedDataChunks.length, 1);
     }
 
-    public PQVectors(ProductQuantization pq, ByteSequence<?>[] compressedVectors)
+    public PQVectors(ProductQuantization pq, ByteSequence<?>[] compressedDataChunks, int vectorCount, int vectorsPerChunk)
     {
-        this(pq, List.of(compressedVectors));
+        this.pq = pq;
+        this.compressedDataChunks = compressedDataChunks;
+        this.vectorCount = vectorCount;
+        this.vectorsPerChunk = vectorsPerChunk;
     }
 
     @Override
     public int count() {
-        return compressedVectors.size();
+        return vectorCount;
     }
 
     @Override
@@ -65,10 +65,10 @@ public void write(DataOutput out, int version) throws IOException
         pq.write(out, version);
 
         // compressed vectors
-        out.writeInt(compressedVectors.size());
+        out.writeInt(vectorCount);
         out.writeInt(pq.getSubspaceCount());
-        for (var v : compressedVectors) {
-            vectorTypeSupport.writeByteSequence(out, v);
+        for (ByteSequence<?> chunk : compressedDataChunks) {
+            vectorTypeSupport.writeByteSequence(out, chunk);
         }
     }
 
@@ -77,24 +77,33 @@ public static PQVectors load(RandomAccessReader in) throws IOException
     {
         var pq = ProductQuantization.load(in);
 
         // read the vectors
-        int size = in.readInt();
-        if (size < 0) {
-            throw new IOException("Invalid compressed vector count " + size);
+        int vectorCount = in.readInt();
+        if (vectorCount < 0) {
+            throw new IOException("Invalid compressed vector count " + vectorCount);
         }
-        List<ByteSequence<?>> compressedVectors = new ArrayList<>(size);
 
         int compressedDimension = in.readInt();
         if (compressedDimension < 0) {
             throw new IOException("Invalid compressed vector dimension " + compressedDimension);
         }
 
-        for (int i = 0; i < size; i++)
-        {
-            ByteSequence<?> vector = vectorTypeSupport.readByteSequence(in, compressedDimension);
-            compressedVectors.add(vector);
+        // Calculate if we need to split into multiple chunks
+        long totalSize = (long) vectorCount * compressedDimension;
+        int vectorsPerChunk = totalSize <= MAX_CHUNK_SIZE ? Math.max(vectorCount, 1) : MAX_CHUNK_SIZE / compressedDimension;
+
+        int numChunks = Math.max(1, (vectorCount + vectorsPerChunk - 1) / vectorsPerChunk); // round up so trailing vectors still get a chunk
+        ByteSequence<?>[] chunks = new ByteSequence<?>[numChunks];
+
+        for (int i = 0; i < numChunks - 1; i++) {
+            int chunkSize = vectorsPerChunk * compressedDimension;
+            chunks[i] = vectorTypeSupport.readByteSequence(in, chunkSize);
         }
 
-        return new PQVectors(pq, compressedVectors);
+        // Last chunk might be smaller
+        int remainingVectors = vectorCount - (vectorsPerChunk * (numChunks - 1));
+        chunks[numChunks - 1] = vectorTypeSupport.readByteSequence(in, remainingVectors * compressedDimension);
+
+        return new PQVectors(pq, chunks, vectorCount, vectorsPerChunk);
     }
 
     public static PQVectors load(RandomAccessReader in, long offset) throws IOException {
@@ -102,6 +111,12 @@ public static PQVectors load(RandomAccessReader in, long offset) throws IOExcept
         return load(in);
     }
 
+    /**
+     * We consider two PQVectors equal when their PQs are equal and their compressed data is equal. We ignore the
+     * chunking strategy in the comparison since this is an implementation detail.
+     * @param o the object to check for equality
+     * @return true if the objects are equal, false otherwise
+     */
     @Override
     public boolean equals(Object o) {
         if (this == o) return true;
@@ -109,12 +124,29 @@ public boolean equals(Object o) {
         PQVectors that = (PQVectors) o;
 
         if (!Objects.equals(pq, that.pq)) return false;
-        return Objects.equals(compressedVectors, that.compressedVectors);
+        if (this.count() != that.count()) return false;
+        // TODO determine how equality should ultimately be defined. We compare the logical vector data,
+        // elementwise, rather than delegating to ByteSequence#equals: two PQVectors holding the same codes
+        // should be equal whether they are backed by a byte[] or a MemorySegment, and regardless of chunking,
+        // since we may write one layout and materialize another on load. This is deliberately looser than the
+        // individual ByteSequence#equals implementations, which also compare layout; revisit whether that
+        // asymmetry is acceptable. (This comparison is what lets testSaveLoadPQ pass across layouts.)
+        for (int i = 0; i < this.count(); i++) {
+            var thisNode = this.get(i);
+            var thatNode = that.get(i);
+            if (thisNode.length() != thatNode.length()) return false;
+            for (int j = 0; j < thisNode.length(); j++) {
+                if (thisNode.get(j) != thatNode.get(j)) return false;
+            }
+        }
+        return true;
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(pq, compressedVectors);
+        // The hash must not depend on the chunking layout (see equals), so we avoid hashing the array structure.
+        // The count plus the first byte of the first vector gives a cheap, stable hash.
+        return Objects.hash(pq, count(), count() > 0 ? get(0).get(0) : 0);
     }
 
     @Override
@@ -188,7 +220,10 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
     }
 
     public ByteSequence<?> get(int ordinal) {
-        return compressedVectors.get(ordinal);
+        int chunkIndex = ordinal / vectorsPerChunk;
+        int vectorIndexInChunk = ordinal % vectorsPerChunk;
+        int start = vectorIndexInChunk * pq.getSubspaceCount();
+        return compressedDataChunks[chunkIndex].slice(start, pq.getSubspaceCount());
     }
 
     public ProductQuantization getProductQuantization() {
@@ -225,16 +260,19 @@ public long ramBytesUsed() {
         int AH_BYTES = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;
 
         long codebooksSize = pq.ramBytesUsed();
-        long listSize = (long) REF_BYTES * (1 + compressedVectors.size());
-        long dataSize = (long) (OH_BYTES + AH_BYTES + pq.compressedVectorSize()) * compressedVectors.size();
-        return codebooksSize + listSize + dataSize;
+        long chunksArraySize = OH_BYTES + AH_BYTES + (long) compressedDataChunks.length * REF_BYTES;
+        long dataSize = 0;
+        for (ByteSequence<?> chunk : compressedDataChunks) {
+            dataSize += chunk.ramBytesUsed();
+        }
+        return codebooksSize + chunksArraySize + dataSize;
     }
 
     @Override
     public String toString() {
         return "PQVectors{" +
                 "pq=" + pq +
-                ", count=" + compressedVectors.size() +
+                ", count=" + vectorCount +
                 '}';
     }
 }
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java
index 7b1a8a19..3c17443f 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java
@@ -221,7 +221,7 @@ public ProductQuantization refine(RandomAccessVectorValues ravv,
 
     @Override
     public CompressedVectors createCompressedVectors(Object[] compressedVectors) {
-        return new PQVectors(this, (ByteSequence<?>[]) compressedVectors);
+        return new PQVectors(this, (ByteSequence<?>[]) compressedVectors, compressedVectors.length, 1);
     }
 
     /**
diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayByteSequence.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayByteSequence.java
index 47c5413d..c0ee1862 100644
--- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayByteSequence.java
+++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayByteSequence.java
@@ -41,6 +41,11 @@ public byte[] get() {
         return data;
     }
 
+    @Override
+    public int offset() {
+        return 0;
+    }
+
     @Override
     public byte get(int n) {
         return data[n];
@@ -72,6 +77,14 @@ public ArrayByteSequence copy() {
         return new ArrayByteSequence(Arrays.copyOf(data, data.length));
     }
 
+    @Override
+    public ByteSequence<byte[]> slice(int offset, int length) {
+        if (offset == 0 && length == data.length) {
+            return this;
+        }
+        return new ArraySliceByteSequence(data, offset, length);
+    }
+
     @Override
     public long ramBytesUsed() {
         int OH_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER;
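Illustrative sketch (hypothetical values, not from the patch): the slice returned by
ArrayByteSequence.slice above is a zero-copy view over the shared backing array, assuming
the public ArrayByteSequence(byte[]) constructor:

    byte[] data = {1, 2, 3, 4, 5, 6};
    var whole = new ArrayByteSequence(data);
    ByteSequence<byte[]> slice = whole.slice(2, 3);    // view of bytes {3, 4, 5}
    assert slice.offset() == 2 && slice.length() == 3;
    assert slice.get(0) == 3;                          // reads data[2]
    slice.set(0, (byte) 9);                            // writes through to data[2]; the array is shared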
+ */ +public class ArraySliceByteSequence implements ByteSequence { + private final byte[] data; + private final int offset; + private final int length; + + public ArraySliceByteSequence(byte[] data, int offset, int length) { + if (offset < 0 || length < 0 || offset + length > data.length) { + throw new IllegalArgumentException("Invalid offset or length"); + } + this.data = data; + this.offset = offset; + this.length = length; + } + + @Override + public byte[] get() { + return data; + } + + @Override + public int offset() { + return offset; + } + + @Override + public byte get(int n) { + if (n < 0 || n >= length) { + throw new IndexOutOfBoundsException("Index: " + n + ", Length: " + length); + } + return data[offset + n]; + } + + @Override + public void set(int n, byte value) { + if (n < 0 || n >= length) { + throw new IndexOutOfBoundsException("Index: " + n + ", Length: " + length); + } + data[offset + n] = value; + } + + @Override + public void setLittleEndianShort(int shortIndex, short value) { + throw new UnsupportedOperationException("Not supported on slices"); + } + + @Override + public void zero() { + throw new UnsupportedOperationException("Not supported on slices"); + } + + @Override + public int length() { + return length; + } + + @Override + public ByteSequence copy() { + byte[] newData = Arrays.copyOfRange(data, offset, offset + length); + return new ArrayByteSequence(newData); + } + + @Override + public ByteSequence slice(int sliceOffset, int sliceLength) { + if (sliceOffset < 0 || sliceLength < 0 || sliceOffset + sliceLength > length) { + throw new IllegalArgumentException("Invalid slice parameters"); + } + if (sliceOffset == 0 && sliceLength == length) { + return this; + } + return new ArraySliceByteSequence(data, offset + sliceOffset, sliceLength); + } + + @Override + public long ramBytesUsed() { + // Only count the overhead of this slice object, not the underlying array + // since that's shared and counted elsewhere + return RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + + (3 * Integer.BYTES); // offset, length, and reference to data + } + + @Override + public void copyFrom(ByteSequence src, int srcOffset, int destOffset, int copyLength) { + throw new UnsupportedOperationException("Not supported on slices"); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("["); + for (int i = 0; i < Math.min(length, 25); i++) { + sb.append(get(i)); + if (i < length - 1) { + sb.append(", "); + } + } + if (length > 25) { + sb.append("..."); + } + sb.append("]"); + return sb.toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ArraySliceByteSequence that = (ArraySliceByteSequence) o; + if (this.length != that.length) return false; + for (int i = 0; i < length; i++) { + if (this.get(i) != that.get(i)) return false; + } + return true; + } + + @Override + public int hashCode() { + int result = 1; + for (int i = 0; i < length; i++) { + result = 31 * result + get(i); + } + return result; + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/ByteSequence.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/ByteSequence.java index 1f3c18fd..6c317931 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/ByteSequence.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/ByteSequence.java @@ -25,6 +25,8 @@ public interface ByteSequence extends 
      */
     T get();
 
+    int offset();
+
     int length();
 
     byte get(int i);
@@ -42,4 +44,6 @@ public interface ByteSequence<T> extends Accountable
     void copyFrom(ByteSequence<T> src, int srcOffset, int destOffset, int length);
 
     ByteSequence<T> copy();
+
+    ByteSequence<T> slice(int offset, int length);
 }
diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java
index f6053775..83603cc0 100644
--- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java
+++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java
@@ -215,8 +215,8 @@ public static void siftDiskAnnLTM(List<VectorFloat<?>> baseVectors, List<VectorFloat<?>>
-        List<ByteSequence<?>> incrementallyCompressedVectors = new ArrayList<>();
-        PQVectors pqv = new PQVectors(pq, incrementallyCompressedVectors);
+        ByteSequence<?>[] incrementallyCompressedVectors = new ByteSequence<?>[baseVectors.size()];
+        PQVectors pqv = new PQVectors(pq, incrementallyCompressedVectors, baseVectors.size(), 1);
         BuildScoreProvider bsp = BuildScoreProvider.pqBuildScoreProvider(VectorSimilarityFunction.EUCLIDEAN, pqv);
 
         Path indexPath = Files.createTempFile("siftsmall", ".inline");
@@ -232,10 +232,10 @@ public static void siftDiskAnnLTM(List<VectorFloat<?>> baseVectors, List<VectorFloat<?>>
 
-            for (VectorFloat<?> v : baseVectors) {
+            for (int ordinal = 0; ordinal < baseVectors.size(); ordinal++) {
+                VectorFloat<?> v = baseVectors.get(ordinal);
                 // compress the new vector and add it to the PQVectors (via incrementallyCompressedVectors)
-                int ordinal = incrementallyCompressedVectors.size();
-                incrementallyCompressedVectors.add(pq.encode(v));
+                incrementallyCompressedVectors[ordinal] = pq.encode(v);
                 // write the full vector to disk
                 writer.writeInline(ordinal, Feature.singleState(FeatureId.INLINE_VECTORS, new InlineVectors.State(v)));
                 // now add it to the graph -- the previous steps must be completed first since the PQVectors
diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentByteSequence.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentByteSequence.java
index 4c541856..448477f8 100644
--- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentByteSequence.java
+++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentByteSequence.java
@@ -49,6 +49,11 @@ public class MemorySegmentByteSequence implements ByteSequence<MemorySegment> {
         this.length = data.length;
     }
 
+    private MemorySegmentByteSequence(MemorySegment segment) {
+        this.segment = segment;
+        this.length = Math.toIntExact(segment.byteSize());
+    }
+
     @Override
     public long ramBytesUsed() {
         int OH_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER;
@@ -67,6 +72,11 @@ public MemorySegment get() {
         return segment;
     }
 
+    @Override
+    public int offset() {
+        return 0;
+    }
+
     @Override
     public byte get(int n) {
         return segment.get(ValueLayout.JAVA_BYTE, n);
@@ -99,6 +109,11 @@ public ByteSequence<MemorySegment> copy() {
         return copy;
     }
 
+    @Override
+    public MemorySegmentByteSequence slice(int offset, int length) {
+        return new MemorySegmentByteSequence(segment.asSlice(offset, length));
+    }
+
     @Override
     public String toString() {
         StringBuilder sb = new StringBuilder();
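Illustrative sketch (hypothetical values, not from the patch): unlike the array-backed slice,
MemorySegment.asSlice returns a self-contained segment whose own indexing starts at zero, which
is why offset() can return 0 here and why the native code below can assert offset() == 0.
Assuming the java.lang.foreign API (finalized in JDK 22):

    try (Arena arena = Arena.ofConfined()) {
        MemorySegment segment = arena.allocate(32);
        segment.set(ValueLayout.JAVA_BYTE, 20, (byte) 7);
        MemorySegment slice = segment.asSlice(16, 8);   // bytes 16..23 of the parent
        byte b = slice.get(ValueLayout.JAVA_BYTE, 4);   // == 7; slice index 4 is parent byte 20
    }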
diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java
index 0af7ebcb..335c1b1c 100644
--- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java
+++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/NativeVectorUtilSupport.java
@@ -97,6 +97,7 @@ public VectorFloat<?> sub(VectorFloat<?> a, int aOffset, VectorFloat<?> b, int b
     @Override
     public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> baseOffsets) {
+        assert baseOffsets.offset() == 0 : "Base offsets are expected to have an offset of 0. Found: " + baseOffsets.offset();
         return NativeSimdOps.assemble_and_sum_f32_512(((MemorySegmentVectorFloat)data).get(), dataBase, ((MemorySegmentByteSequence)baseOffsets).get(), baseOffsets.length());
     }
 
@@ -140,6 +141,7 @@ public void quantizePartials(float delta, VectorFloat<?> partials, VectorFloat<?>
     public void bulkShuffleQuantizedSimilarity(ByteSequence<?> shuffles, int codebookCount, ByteSequence<?> quantizedPartials, float delta, float bestDistance, VectorSimilarityFunction vsf, VectorFloat<?> results) {
+        assert shuffles.offset() == 0 : "Bulk shuffle shuffles are expected to have an offset of 0. Found: " + shuffles.offset();
         switch (vsf) {
             case DOT_PRODUCT -> NativeSimdOps.bulk_quantized_shuffle_dot_f32_512(((MemorySegmentByteSequence) shuffles).get(), codebookCount, ((MemorySegmentByteSequence) quantizedPartials).get(), delta, bestDistance, ((MemorySegmentVectorFloat) results).get());
             case EUCLIDEAN -> NativeSimdOps.bulk_quantized_shuffle_euclidean_f32_512(((MemorySegmentByteSequence) shuffles).get(), codebookCount, ((MemorySegmentByteSequence) quantizedPartials).get(), delta, bestDistance, ((MemorySegmentVectorFloat) results).get());
@@ -152,6 +154,7 @@ public void bulkShuffleQuantizedSimilarityCosine(ByteSequence<?> shuffles, int c
                                                      ByteSequence<?> quantizedPartialSums, float sumDelta, float minDistance,
                                                      ByteSequence<?> quantizedPartialSquaredMagnitudes, float magnitudeDelta, float minMagnitude,
                                                      float queryMagnitudeSquared, VectorFloat<?> results) {
+        assert shuffles.offset() == 0 : "Bulk shuffle shuffles are expected to have an offset of 0. Found: " + shuffles.offset();
         NativeSimdOps.bulk_quantized_shuffle_cosine_f32_512(((MemorySegmentByteSequence) shuffles).get(), codebookCount, ((MemorySegmentByteSequence) quantizedPartialSums).get(), sumDelta, minDistance, ((MemorySegmentByteSequence) quantizedPartialSquaredMagnitudes).get(), magnitudeDelta, minMagnitude, queryMagnitudeSquared, ((MemorySegmentVectorFloat) results).get());
     }
 
@@ -159,6 +162,7 @@ public void bulkShuffleQuantizedSimilarityCosine(ByteSequence<?> shuffles, int c
     @Override
     public float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude) {
+        assert encoded.offset() == 0 : "Encoded vectors are expected to have an offset of 0. Found: " + encoded.offset();
         return NativeSimdOps.pq_decoded_cosine_similarity_f32_512(((MemorySegmentByteSequence) encoded).get(), encoded.length(), clusterCount, ((MemorySegmentVectorFloat) partialSums).get(), ((MemorySegmentVectorFloat) aMagnitude).get(), bMagnitude);
     }
 }
diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestCompressedVectors.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestCompressedVectors.java
index c357e609..496b2849 100644
--- a/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestCompressedVectors.java
+++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestCompressedVectors.java
@@ -59,6 +59,7 @@ public void testSaveLoadPQ() throws Exception {
         // Read compressed vectors
         try (var in = new SimpleMappedReader(cvFile.getAbsolutePath())) {
             var cv2 = PQVectors.load(in, 0);
+            assertEquals(cv.hashCode(), cv2.hashCode());
             assertEquals(cv, cv2);
         }
     }
diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java
index 4b2602cb..a68fea3a 100644
--- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java
+++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java
@@ -159,7 +159,7 @@ public void quantizePartials(float delta, VectorFloat<?> partials, VectorFloat<?>
     @Override
     public float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude) {
-        return SimdOps.pqDecodedCosineSimilarity((ArrayByteSequence) encoded, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude);
+        return SimdOps.pqDecodedCosineSimilarity((ByteSequence<byte[]>) encoded, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude);
     }
 }
diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java
index 034aa987..124f8619 100644
--- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java
+++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java
@@ -16,6 +16,7 @@
 
 package io.github.jbellis.jvector.vector;
 
+import io.github.jbellis.jvector.vector.types.ByteSequence;
 import io.github.jbellis.jvector.vector.types.VectorFloat;
 import jdk.incubator.vector.ByteVector;
 import jdk.incubator.vector.FloatVector;
@@ -660,13 +661,13 @@ public static void quantizePartials(float delta, ArrayVectorFloat partials, Arra
         }
     }
 
-    public static float pqDecodedCosineSimilarity(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
+    public static float pqDecodedCosineSimilarity(ByteSequence<byte[]> encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
         return HAS_AVX512 ?
                pqDecodedCosineSimilarity512(encoded, clusterCount, partialSums, aMagnitude, bMagnitude) :
                pqDecodedCosineSimilarity256(encoded, clusterCount, partialSums, aMagnitude, bMagnitude);
     }
 
-    public static float pqDecodedCosineSimilarity512(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
+    public static float pqDecodedCosineSimilarity512(ByteSequence<byte[]> encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
         var sum = FloatVector.zero(FloatVector.SPECIES_512);
         var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_512);
         var baseOffsets = encoded.get();
@@ -674,8 +675,9 @@ public static float pqDecodedCosineSimilarity512(ArrayByteSequence encoded, int
         var aMagnitudeArray = aMagnitude.get();
         int[] convOffsets = scratchInt512.get();
 
-        int i = 0;
-        int limit = ByteVector.SPECIES_128.loopBound(baseOffsets.length);
+        int i = encoded.offset();
+        int length = encoded.length();
+        int limit = i + ByteVector.SPECIES_128.loopBound(length);
 
         var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(clusterCount);
 
@@ -696,7 +698,7 @@ public static float pqDecodedCosineSimilarity512(ArrayByteSequence encoded, int
         float sumResult = sum.reduceLanes(VectorOperators.ADD);
         float aMagnitudeResult = vaMagnitude.reduceLanes(VectorOperators.ADD);
 
-        for (; i < baseOffsets.length; i++) {
-            int offset = clusterCount * i + Byte.toUnsignedInt(baseOffsets[i]);
+        for (; i < encoded.offset() + length; i++) { // i is an absolute index into baseOffsets
+            int offset = clusterCount * (i - encoded.offset()) + Byte.toUnsignedInt(baseOffsets[i]);
             sumResult += partialSumsArray[offset];
             aMagnitudeResult += aMagnitudeArray[offset];
         }
 
         return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude));
     }
 
-    public static float pqDecodedCosineSimilarity256(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
+    public static float pqDecodedCosineSimilarity256(ByteSequence<byte[]> encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
         var sum = FloatVector.zero(FloatVector.SPECIES_256);
         var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_256);
         var baseOffsets = encoded.get();
@@ -713,8 +715,9 @@ public static float pqDecodedCosineSimilarity256(ArrayByteSequence encoded, int
         var aMagnitudeArray = aMagnitude.get();
         int[] convOffsets = scratchInt256.get();
 
-        int i = 0;
-        int limit = ByteVector.SPECIES_64.loopBound(baseOffsets.length);
+        int i = encoded.offset();
+        int length = encoded.length();
+        int limit = i + ByteVector.SPECIES_64.loopBound(length);
 
         var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(clusterCount);
 
@@ -735,7 +738,7 @@ public static float pqDecodedCosineSimilarity256(ArrayByteSequence encoded, int
         float sumResult = sum.reduceLanes(VectorOperators.ADD);
         float aMagnitudeResult = vaMagnitude.reduceLanes(VectorOperators.ADD);
 
-        for (; i < baseOffsets.length; i++) {
-            int offset = clusterCount * i + Byte.toUnsignedInt(baseOffsets[i]);
+        for (; i < encoded.offset() + length; i++) { // i is an absolute index into baseOffsets
+            int offset = clusterCount * (i - encoded.offset()) + Byte.toUnsignedInt(baseOffsets[i]);
             sumResult += partialSumsArray[offset];
             aMagnitudeResult += aMagnitudeArray[offset];
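Illustrative sketch (hypothetical names, not from the patch): the dense layout above maps an
ordinal to a slice of a chunk in O(1), assuming uniform chunks except possibly the last, as in
PQVectors.load:

    // Given: ordinal in [0, vectorCount), vectorsPerChunk >= 1, subspaceCount bytes per encoded vector
    int chunkIndex = ordinal / vectorsPerChunk;            // which dense chunk holds the vector
    int vectorIndexInChunk = ordinal % vectorsPerChunk;    // position of the vector within that chunk
    int start = vectorIndexInChunk * subspaceCount;        // byte offset of its code within the chunk
    ByteSequence<?> code = chunks[chunkIndex].slice(start, subspaceCount); // zero-copy view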