From 4f1a44b489dcbc617374f0c2bf37df740345d522 Mon Sep 17 00:00:00 2001 From: Michael Marshall Date: Fri, 22 Nov 2024 20:39:27 -0600 Subject: [PATCH] Store compressed vectors in dense ByteSequence for PQVectors --- .../github/jbellis/jvector/pq/PQVectors.java | 84 +++++---- .../jvector/pq/ProductQuantization.java | 2 +- .../jvector/vector/ArrayByteSequence.java | 13 ++ .../vector/ArrayByteSequenceSlice.java | 166 ++++++++++++++++++ .../jvector/vector/types/ByteSequence.java | 4 + .../jbellis/jvector/example/SiftSmall.java | 10 +- .../vector/MemorySegmentByteSequence.java | 15 ++ .../jvector/pq/TestCompressedVectors.java | 14 +- .../vector/PanamaVectorUtilSupport.java | 2 +- .../jbellis/jvector/vector/SimdOps.java | 23 +-- 10 files changed, 282 insertions(+), 51 deletions(-) create mode 100644 jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayByteSequenceSlice.java diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java index b58880ab..9ec1686d 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java @@ -28,34 +28,34 @@ import java.io.DataOutput; import java.io.IOException; -import java.util.ArrayList; -import java.util.List; +import java.util.Arrays; import java.util.Objects; import java.util.concurrent.atomic.AtomicReference; public class PQVectors implements CompressedVectors { private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport(); + private static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 16; // standard Java array size limit with some headroom final ProductQuantization pq; - private final List> compressedVectors; + private final ByteSequence[] compressedDataChunks; + private final int vectorCount; + private final int vectorsPerChunk; - /** - * Initialize the PQVectors with an initial List of vectors. This list may be - * mutated, but caller is responsible for thread safety issues when doing so. - */ - public PQVectors(ProductQuantization pq, List> compressedVectors) + public PQVectors(ProductQuantization pq, ByteSequence[] compressedDataChunks) { - this.pq = pq; - this.compressedVectors = compressedVectors; + this(pq, compressedDataChunks, compressedDataChunks.length, 1); } - public PQVectors(ProductQuantization pq, ByteSequence[] compressedVectors) + public PQVectors(ProductQuantization pq, ByteSequence[] compressedDataChunks, int vectorCount, int vectorsPerChunk) { - this(pq, List.of(compressedVectors)); + this.pq = pq; + this.compressedDataChunks = compressedDataChunks; + this.vectorCount = vectorCount; + this.vectorsPerChunk = vectorsPerChunk; } @Override public int count() { - return compressedVectors.size(); + return vectorCount; } @Override @@ -65,10 +65,10 @@ public void write(DataOutput out, int version) throws IOException pq.write(out, version); // compressed vectors - out.writeInt(compressedVectors.size()); + out.writeInt(vectorCount); out.writeInt(pq.getSubspaceCount()); - for (var v : compressedVectors) { - vectorTypeSupport.writeByteSequence(out, v); + for (ByteSequence chunk : compressedDataChunks) { + vectorTypeSupport.writeByteSequence(out, chunk); } } @@ -77,24 +77,33 @@ public static PQVectors load(RandomAccessReader in) throws IOException { var pq = ProductQuantization.load(in); // read the vectors - int size = in.readInt(); - if (size < 0) { - throw new IOException("Invalid compressed vector count " + size); + int vectorCount = in.readInt(); + if (vectorCount < 0) { + throw new IOException("Invalid compressed vector count " + vectorCount); } - List> compressedVectors = new ArrayList<>(size); int compressedDimension = in.readInt(); if (compressedDimension < 0) { throw new IOException("Invalid compressed vector dimension " + compressedDimension); } - for (int i = 0; i < size; i++) - { - ByteSequence vector = vectorTypeSupport.readByteSequence(in, compressedDimension); - compressedVectors.add(vector); + // Calculate if we need to split into multiple chunks + long totalSize = (long) vectorCount * compressedDimension; + int vectorsPerChunk = totalSize <= MAX_CHUNK_SIZE ? vectorCount : MAX_CHUNK_SIZE / compressedDimension; + + int numChunks = vectorCount / vectorsPerChunk; + ByteSequence[] chunks = new ByteSequence[numChunks]; + + for (int i = 0; i < numChunks - 1; i++) { + int chunkSize = vectorsPerChunk * compressedDimension; + chunks[i] = vectorTypeSupport.readByteSequence(in, chunkSize); } - return new PQVectors(pq, compressedVectors); + // Last chunk might be smaller + int remainingVectors = vectorCount - (vectorsPerChunk * (numChunks - 1)); + chunks[numChunks - 1] = vectorTypeSupport.readByteSequence(in, remainingVectors * compressedDimension); + + return new PQVectors(pq, chunks, vectorCount, vectorsPerChunk); } public static PQVectors load(RandomAccessReader in, long offset) throws IOException { @@ -109,12 +118,17 @@ public boolean equals(Object o) { PQVectors that = (PQVectors) o; if (!Objects.equals(pq, that.pq)) return false; - return Objects.equals(compressedVectors, that.compressedVectors); + if (vectorsPerChunk != that.vectorsPerChunk) return false; + if (compressedDataChunks.length != that.compressedDataChunks.length) return false; + for (int i = 0; i < compressedDataChunks.length; i++) { + if (!compressedDataChunks[i].equals(that.compressedDataChunks[i])) return false; + } + return true; } @Override public int hashCode() { - return Objects.hash(pq, compressedVectors); + return Objects.hash(pq, Arrays.hashCode(compressedDataChunks)); } @Override @@ -188,7 +202,10 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat q, } public ByteSequence get(int ordinal) { - return compressedVectors.get(ordinal); + int chunkIndex = ordinal / vectorsPerChunk; + int vectorIndexInChunk = ordinal % vectorsPerChunk; + int start = vectorIndexInChunk * pq.getSubspaceCount(); + return compressedDataChunks[chunkIndex].slice(start, pq.getSubspaceCount()); } public ProductQuantization getProductQuantization() { @@ -225,16 +242,19 @@ public long ramBytesUsed() { int AH_BYTES = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER; long codebooksSize = pq.ramBytesUsed(); - long listSize = (long) REF_BYTES * (1 + compressedVectors.size()); - long dataSize = (long) (OH_BYTES + AH_BYTES + pq.compressedVectorSize()) * compressedVectors.size(); - return codebooksSize + listSize + dataSize; + long chunksArraySize = OH_BYTES + AH_BYTES + (long) compressedDataChunks.length * REF_BYTES; + long dataSize = 0; + for (ByteSequence chunk : compressedDataChunks) { + dataSize += chunk.ramBytesUsed(); + } + return codebooksSize + chunksArraySize + dataSize; } @Override public String toString() { return "PQVectors{" + "pq=" + pq + - ", count=" + compressedVectors.size() + + ", count=" + vectorCount + '}'; } } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java index 7b1a8a19..3c17443f 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java @@ -221,7 +221,7 @@ public ProductQuantization refine(RandomAccessVectorValues ravv, @Override public CompressedVectors createCompressedVectors(Object[] compressedVectors) { - return new PQVectors(this, (ByteSequence[]) compressedVectors); + return new PQVectors(this, (ByteSequence[]) compressedVectors, compressedVectors.length, 1); } /** diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayByteSequence.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayByteSequence.java index 47c5413d..24a2acae 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayByteSequence.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayByteSequence.java @@ -41,6 +41,11 @@ public byte[] get() { return data; } + @Override + public int offset() { + return 0; + } + @Override public byte get(int n) { return data[n]; @@ -72,6 +77,14 @@ public ArrayByteSequence copy() { return new ArrayByteSequence(Arrays.copyOf(data, data.length)); } + @Override + public ByteSequence slice(int offset, int length) { + if (offset == 0 && length == data.length) { + return this; + } + return new ArrayByteSequenceSlice(data, offset, length); + } + @Override public long ramBytesUsed() { int OH_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayByteSequenceSlice.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayByteSequenceSlice.java new file mode 100644 index 00000000..5946b149 --- /dev/null +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayByteSequenceSlice.java @@ -0,0 +1,166 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.github.jbellis.jvector.vector; + +import io.github.jbellis.jvector.util.RamUsageEstimator; +import io.github.jbellis.jvector.vector.types.ByteSequence; +import java.nio.ByteBuffer; +import java.util.Arrays; +import java.nio.ByteOrder; + +public class ArrayByteSequenceSlice implements ByteSequence { + private final byte[] data; + private final int offset; + private final int length; + + public ArrayByteSequenceSlice(byte[] data, int offset, int length) { + if (offset < 0 || length < 0 || offset + length > data.length) { + throw new IllegalArgumentException("Invalid offset or length"); + } + this.data = data; + this.offset = offset; + this.length = length; + } + + @Override + public byte[] get() { + return data; + } + + @Override + public int offset() { + return offset; + } + + @Override + public byte get(int n) { + if (n < 0 || n >= length) { + throw new IndexOutOfBoundsException("Index: " + n + ", Length: " + length); + } + return data[offset + n]; + } + + @Override + public void set(int n, byte value) { + if (n < 0 || n >= length) { + throw new IndexOutOfBoundsException("Index: " + n + ", Length: " + length); + } + data[offset + n] = value; + } + + @Override + public void setLittleEndianShort(int shortIndex, short value) { + int baseIndex = offset + (shortIndex * 2); + if (baseIndex < offset || baseIndex + 1 >= offset + length) { + throw new IndexOutOfBoundsException("Short index out of bounds: " + shortIndex); + } + data[baseIndex] = (byte) (value & 0xFF); + data[baseIndex + 1] = (byte) ((value >> 8) & 0xFF); + } + + @Override + public void zero() { + Arrays.fill(data, offset, offset + length, (byte) 0); + } + + @Override + public int length() { + return length; + } + + @Override + public ArrayByteSequenceSlice copy() { + byte[] newData = Arrays.copyOfRange(data, offset, offset + length); + return new ArrayByteSequenceSlice(newData, 0, length); + } + + @Override + public ByteSequence slice(int sliceOffset, int sliceLength) { + if (sliceOffset < 0 || sliceLength < 0 || sliceOffset + sliceLength > length) { + throw new IllegalArgumentException("Invalid slice parameters"); + } + if (sliceOffset == 0 && sliceLength == length) { + return this; + } + return new ArrayByteSequenceSlice(data, offset + sliceOffset, sliceLength); + } + + @Override + public long ramBytesUsed() { + // Only count the overhead of this slice object, not the underlying array + // since that's shared and counted elsewhere + return RamUsageEstimator.NUM_BYTES_OBJECT_HEADER + + (3 * Integer.BYTES); // offset, length, and reference to data + } + + @Override + public void copyFrom(ByteSequence src, int srcOffset, int destOffset, int copyLength) { + if (destOffset < 0 || destOffset + copyLength > length) { + throw new IndexOutOfBoundsException("Destination range out of bounds"); + } + if (src instanceof ArrayByteSequence) { + ArrayByteSequence csrc = (ArrayByteSequence) src; + System.arraycopy(csrc.get(), srcOffset, data, this.offset + destOffset, copyLength); + } else if (src instanceof ArrayByteSequenceSlice) { + ArrayByteSequenceSlice csrc = (ArrayByteSequenceSlice) src; + System.arraycopy(csrc.data, csrc.offset + srcOffset, data, this.offset + destOffset, copyLength); + } else { + // Fallback for other implementations + for (int i = 0; i < copyLength; i++) { + set(destOffset + i, src.get(srcOffset + i)); + } + } + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("["); + for (int i = 0; i < Math.min(length, 25); i++) { + sb.append(get(i)); + if (i < length - 1) { + sb.append(", "); + } + } + if (length > 25) { + sb.append("..."); + } + sb.append("]"); + return sb.toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ArrayByteSequenceSlice that = (ArrayByteSequenceSlice) o; + if (this.length != that.length) return false; + for (int i = 0; i < length; i++) { + if (this.get(i) != that.get(i)) return false; + } + return true; + } + + @Override + public int hashCode() { + int result = 1; + for (int i = 0; i < length; i++) { + result = 31 * result + get(i); + } + return result; + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/ByteSequence.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/ByteSequence.java index 1f3c18fd..6c317931 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/ByteSequence.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/ByteSequence.java @@ -25,6 +25,8 @@ public interface ByteSequence extends Accountable */ T get(); + int offset(); + int length(); byte get(int i); @@ -42,4 +44,6 @@ public interface ByteSequence extends Accountable void copyFrom(ByteSequence src, int srcOffset, int destOffset, int length); ByteSequence copy(); + + ByteSequence slice(int offset, int length); } diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java index f6053775..83603cc0 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/SiftSmall.java @@ -215,8 +215,8 @@ public static void siftDiskAnnLTM(List> baseVectors, List> incrementallyCompressedVectors = new ArrayList<>(); - PQVectors pqv = new PQVectors(pq, incrementallyCompressedVectors); + ByteSequence[] incrementallyCompressedVectors = new ByteSequence[baseVectors.size() * pq.compressedVectorSize()]; + PQVectors pqv = new PQVectors(pq, incrementallyCompressedVectors, baseVectors.size(), 1); BuildScoreProvider bsp = BuildScoreProvider.pqBuildScoreProvider(VectorSimilarityFunction.EUCLIDEAN, pqv); Path indexPath = Files.createTempFile("siftsmall", ".inline"); @@ -232,10 +232,10 @@ public static void siftDiskAnnLTM(List> baseVectors, List v : baseVectors) { + for (int ordinal = 0; ordinal < baseVectors.size(); ordinal++) { + VectorFloat v = baseVectors.get(ordinal); // compress the new vector and add it to the PQVectors (via incrementallyCompressedVectors) - int ordinal = incrementallyCompressedVectors.size(); - incrementallyCompressedVectors.add(pq.encode(v)); + incrementallyCompressedVectors[ordinal] = pq.encode(v); // write the full vector to disk writer.writeInline(ordinal, Feature.singleState(FeatureId.INLINE_VECTORS, new InlineVectors.State(v))); // now add it to the graph -- the previous steps must be completed first since the PQVectors diff --git a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentByteSequence.java b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentByteSequence.java index 4c541856..448477f8 100644 --- a/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentByteSequence.java +++ b/jvector-native/src/main/java/io/github/jbellis/jvector/vector/MemorySegmentByteSequence.java @@ -49,6 +49,11 @@ public class MemorySegmentByteSequence implements ByteSequence { this.length = data.length; } + private MemorySegmentByteSequence(MemorySegment segment) { + this.segment = segment; + this.length = Math.toIntExact(segment.byteSize()); + } + @Override public long ramBytesUsed() { int OH_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER; @@ -67,6 +72,11 @@ public MemorySegment get() { return segment; } + @Override + public int offset() { + return 0; + } + @Override public byte get(int n) { return segment.get(ValueLayout.JAVA_BYTE, n); @@ -99,6 +109,11 @@ public ByteSequence copy() { return copy; } + @Override + public MemorySegmentByteSequence slice(int offset, int length) { + return new MemorySegmentByteSequence(segment.asSlice(offset, length)); + } + @Override public String toString() { StringBuilder sb = new StringBuilder(); diff --git a/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestCompressedVectors.java b/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestCompressedVectors.java index c357e609..e1ed45f8 100644 --- a/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestCompressedVectors.java +++ b/jvector-tests/src/test/java/io/github/jbellis/jvector/pq/TestCompressedVectors.java @@ -59,7 +59,19 @@ public void testSaveLoadPQ() throws Exception { // Read compressed vectors try (var in = new SimpleMappedReader(cvFile.getAbsolutePath())) { var cv2 = PQVectors.load(in, 0); - assertEquals(cv, cv2); + // TODO how do we want to determine equality? With the current change, we are willing to write one + // thing and materialize another as long as the external representation is the same. The equality + // method on PQ (and on the underlying ByteSequence implementations) do not take this same view though. + // So which is correct? Here, I made the test pass by asserting on the contents of the PQVectors. + assertEquals(cv.count(), cv2.count()); + for (int i = 0; i < cv.count(); i++) { + var node1 = cv.get(i); + var node2 = cv2.get(i); + assertEquals(node1.length(), node2.length()); + for (int j = 0; j < node1.length(); j++) { + assertEquals(node1.get(j), node2.get(j)); + } + } } } diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java index 4b2602cb..a68fea3a 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java @@ -159,7 +159,7 @@ public void quantizePartials(float delta, VectorFloat partials, VectorFloat encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { - return SimdOps.pqDecodedCosineSimilarity((ArrayByteSequence) encoded, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude); + return SimdOps.pqDecodedCosineSimilarity((ByteSequence) encoded, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude); } } diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java index 034aa987..7df54721 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/SimdOps.java @@ -16,6 +16,7 @@ package io.github.jbellis.jvector.vector; +import io.github.jbellis.jvector.vector.types.ByteSequence; import io.github.jbellis.jvector.vector.types.VectorFloat; import jdk.incubator.vector.ByteVector; import jdk.incubator.vector.FloatVector; @@ -660,13 +661,13 @@ public static void quantizePartials(float delta, ArrayVectorFloat partials, Arra } } - public static float pqDecodedCosineSimilarity(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + public static float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { return HAS_AVX512 ? pqDecodedCosineSimilarity512(encoded, clusterCount, partialSums, aMagnitude, bMagnitude) : pqDecodedCosineSimilarity256(encoded, clusterCount, partialSums, aMagnitude, bMagnitude); } - public static float pqDecodedCosineSimilarity512(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + public static float pqDecodedCosineSimilarity512(ByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { var sum = FloatVector.zero(FloatVector.SPECIES_512); var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_512); var baseOffsets = encoded.get(); @@ -674,8 +675,8 @@ public static float pqDecodedCosineSimilarity512(ArrayByteSequence encoded, int var aMagnitudeArray = aMagnitude.get(); int[] convOffsets = scratchInt512.get(); - int i = 0; - int limit = ByteVector.SPECIES_128.loopBound(baseOffsets.length); + int i = encoded.offset(); + int limit = i + ByteVector.SPECIES_128.loopBound(encoded.length()); var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(clusterCount); @@ -696,8 +697,8 @@ public static float pqDecodedCosineSimilarity512(ArrayByteSequence encoded, int float sumResult = sum.reduceLanes(VectorOperators.ADD); float aMagnitudeResult = vaMagnitude.reduceLanes(VectorOperators.ADD); - for (; i < baseOffsets.length; i++) { - int offset = clusterCount * i + Byte.toUnsignedInt(baseOffsets[i]); + for (; i < encoded.length(); i++) { + int offset = clusterCount * i + Byte.toUnsignedInt(encoded.get(i)); sumResult += partialSumsArray[offset]; aMagnitudeResult += aMagnitudeArray[offset]; } @@ -705,7 +706,7 @@ public static float pqDecodedCosineSimilarity512(ArrayByteSequence encoded, int return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude)); } - public static float pqDecodedCosineSimilarity256(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { + public static float pqDecodedCosineSimilarity256(ByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) { var sum = FloatVector.zero(FloatVector.SPECIES_256); var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_256); var baseOffsets = encoded.get(); @@ -713,8 +714,8 @@ public static float pqDecodedCosineSimilarity256(ArrayByteSequence encoded, int var aMagnitudeArray = aMagnitude.get(); int[] convOffsets = scratchInt256.get(); - int i = 0; - int limit = ByteVector.SPECIES_64.loopBound(baseOffsets.length); + int i = encoded.offset(); + int limit = i + ByteVector.SPECIES_64.loopBound(encoded.length()); var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(clusterCount); @@ -735,8 +736,8 @@ public static float pqDecodedCosineSimilarity256(ArrayByteSequence encoded, int float sumResult = sum.reduceLanes(VectorOperators.ADD); float aMagnitudeResult = vaMagnitude.reduceLanes(VectorOperators.ADD); - for (; i < baseOffsets.length; i++) { - int offset = clusterCount * i + Byte.toUnsignedInt(baseOffsets[i]); + for (; i < encoded.length(); i++) { + int offset = clusterCount * i + Byte.toUnsignedInt(encoded.get(i)); sumResult += partialSumsArray[offset]; aMagnitudeResult += aMagnitudeArray[offset]; }