Store compressed vectors in dense ByteSequence for PQVectors
michaeljmarshall committed Nov 25, 2024
1 parent da08d40 commit 650b4d2
Showing 11 changed files with 274 additions and 48 deletions.
102 changes: 70 additions & 32 deletions jvector-base/src/main/java/io/github/jbellis/jvector/pq/PQVectors.java
@@ -28,34 +28,34 @@

import java.io.DataOutput;
import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
+import java.util.Arrays;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;

public class PQVectors implements CompressedVectors {
private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();
+private static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 16; // standard Java array size limit with some headroom
final ProductQuantization pq;
-private final List<ByteSequence<?>> compressedVectors;
+private final ByteSequence<?>[] compressedDataChunks;
+private final int vectorCount;
+private final int vectorsPerChunk;

/**
 * Initialize the PQVectors with an array of compressed vectors, one vector per chunk. The array
 * may be mutated, but the caller is responsible for thread safety when doing so.
 */
-public PQVectors(ProductQuantization pq, List<ByteSequence<?>> compressedVectors)
+public PQVectors(ProductQuantization pq, ByteSequence<?>[] compressedDataChunks)
{
-this.pq = pq;
-this.compressedVectors = compressedVectors;
+this(pq, compressedDataChunks, compressedDataChunks.length, 1);
}

-public PQVectors(ProductQuantization pq, ByteSequence<?>[] compressedVectors)
+public PQVectors(ProductQuantization pq, ByteSequence<?>[] compressedDataChunks, int vectorCount, int vectorsPerChunk)
{
-this(pq, List.of(compressedVectors));
+this.pq = pq;
+this.compressedDataChunks = compressedDataChunks;
+this.vectorCount = vectorCount;
+this.vectorsPerChunk = vectorsPerChunk;
}

@Override
public int count() {
-return compressedVectors.size();
+return vectorCount;
}

@Override
@@ -65,10 +65,10 @@ public void write(DataOutput out, int version) throws IOException
pq.write(out, version);

// compressed vectors
-out.writeInt(compressedVectors.size());
+out.writeInt(vectorCount);
out.writeInt(pq.getSubspaceCount());
-for (var v : compressedVectors) {
-vectorTypeSupport.writeByteSequence(out, v);
+for (ByteSequence<?> chunk : compressedDataChunks) {
+vectorTypeSupport.writeByteSequence(out, chunk);
}
}
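For reference, the stream layout produced by write() and consumed by load() below (a descriptive sketch; the layout follows from the code, but the field annotations are mine, not constants in the codebase):

// On-disk layout for PQVectors (descriptive comment only, not code from the commit):
//
//   [ProductQuantization]               written by pq.write(out, version)
//   int32  vectorCount                  number of compressed vectors
//   int32  subspaceCount                bytes per compressed vector
//   byte[vectorCount * subspaceCount]   PQ codes, concatenated chunk by chunk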

@@ -77,44 +77,76 @@ public static PQVectors load(RandomAccessReader in) throws IOException {
var pq = ProductQuantization.load(in);

// read the vectors
-int size = in.readInt();
-if (size < 0) {
-throw new IOException("Invalid compressed vector count " + size);
+int vectorCount = in.readInt();
+if (vectorCount < 0) {
+throw new IOException("Invalid compressed vector count " + vectorCount);
}
-List<ByteSequence<?>> compressedVectors = new ArrayList<>(size);

int compressedDimension = in.readInt();
if (compressedDimension < 0) {
throw new IOException("Invalid compressed vector dimension " + compressedDimension);
}

-for (int i = 0; i < size; i++)
-{
-ByteSequence<?> vector = vectorTypeSupport.readByteSequence(in, compressedDimension);
-compressedVectors.add(vector);
+// Calculate if we need to split into multiple chunks
+long totalSize = (long) vectorCount * compressedDimension;
+int vectorsPerChunk = totalSize <= MAX_CHUNK_SIZE ? vectorCount : MAX_CHUNK_SIZE / compressedDimension;
+
+// Round up so that a trailing partial chunk still gets a slot
+int numChunks = (int) (((long) vectorCount + vectorsPerChunk - 1) / vectorsPerChunk);
+ByteSequence<?>[] chunks = new ByteSequence<?>[numChunks];
+
+for (int i = 0; i < numChunks - 1; i++) {
+int chunkSize = vectorsPerChunk * compressedDimension;
+chunks[i] = vectorTypeSupport.readByteSequence(in, chunkSize);
}

-return new PQVectors(pq, compressedVectors);
+// Last chunk might be smaller
+int remainingVectors = vectorCount - (vectorsPerChunk * (numChunks - 1));
+chunks[numChunks - 1] = vectorTypeSupport.readByteSequence(in, remainingVectors * compressedDimension);
+
+return new PQVectors(pq, chunks, vectorCount, vectorsPerChunk);
}

public static PQVectors load(RandomAccessReader in, long offset) throws IOException {
in.seek(offset);
return load(in);
}
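To make the chunk sizing above concrete, here is a standalone sketch of the same arithmetic load() performs before reading any bytes (my example; the counts are hypothetical and the class is not part of the commit):

// Illustration of PQVectors.load() chunk sizing with made-up numbers.
public class ChunkMathSketch {
    static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 16;

    public static void main(String[] args) {
        int vectorCount = 100_000_000;  // hypothetical corpus size
        int compressedDimension = 64;   // hypothetical bytes per compressed vector

        long totalSize = (long) vectorCount * compressedDimension; // 6.4 GB, too big for one array
        int vectorsPerChunk = totalSize <= MAX_CHUNK_SIZE
                ? vectorCount
                : MAX_CHUNK_SIZE / compressedDimension;            // 33,554,431
        int numChunks = (int) (((long) vectorCount + vectorsPerChunk - 1) / vectorsPerChunk); // 3
        int lastChunkVectors = vectorCount - vectorsPerChunk * (numChunks - 1);               // 32,891,138

        System.out.printf("vectorsPerChunk=%d numChunks=%d lastChunkVectors=%d%n",
                vectorsPerChunk, numChunks, lastChunkVectors);
    }
}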

+/**
+ * We consider two PQVectors equal when their PQs are equal and their compressed data is equal. We ignore the
+ * chunking strategy in the comparison since this is an implementation detail.
+ * @param o the object to check for equality
+ * @return true if the objects are equal, false otherwise
+ */
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

PQVectors that = (PQVectors) o;
if (!Objects.equals(pq, that.pq)) return false;
-return Objects.equals(compressedVectors, that.compressedVectors);
+if (this.count() != that.count()) return false;
+// TODO how do we want to determine equality? With the current change, we are willing to write one
+// thing and materialize another. It seems like the real concern should be whether the compressedVectors have
+// the same data, not whether they are in a MemorySegment or a byte[] and not whether there is one chunk of many
+// vectors or many chunks of one vector. This technically goes against the implementation of each of the
+// ByteSequence#equals methods, which raises the question of whether this is valid. I primarily updated this
+// code to get testSaveLoadPQ to pass.
+for (int i = 0; i < this.count(); i++) {
+var thisNode = this.get(i);
+var thatNode = that.get(i);
+if (thisNode.length() != thatNode.length()) return false;
+for (int j = 0; j < thisNode.length(); j++) {
+if (thisNode.get(j) != thatNode.get(j)) return false;
+}
+}
+return true;
}

@Override
public int hashCode() {
-return Objects.hash(pq, compressedVectors);
+// We don't use the array structure in the hash code calculation because we allow for different chunking
+// strategies. Instead, we use the first entry in the first chunk to provide a stable hash code.
+return Objects.hash(pq, count(), count() > 0 ? get(0).get(0) : 0);
}

@Override
@@ -188,7 +220,10 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
}

public ByteSequence<?> get(int ordinal) {
-return compressedVectors.get(ordinal);
+int chunkIndex = ordinal / vectorsPerChunk;
+int vectorIndexInChunk = ordinal % vectorsPerChunk;
+int start = vectorIndexInChunk * pq.getSubspaceCount();
+return compressedDataChunks[chunkIndex].slice(start, pq.getSubspaceCount());
}
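As a quick trace of the addressing above (values made up for illustration): with vectorsPerChunk = 4 and pq.getSubspaceCount() = 8, ordinal 10 lands in chunk 2, at vector 2 within that chunk, i.e. an 8-byte slice starting at byte 16:

// Hypothetical walk-through of get()'s index math; not code from the commit.
public class GetIndexSketch {
    public static void main(String[] args) {
        int vectorsPerChunk = 4; // assumed chunk capacity
        int subspaceCount = 8;   // assumed bytes per compressed vector
        int ordinal = 10;

        int chunkIndex = ordinal / vectorsPerChunk;         // 2
        int vectorIndexInChunk = ordinal % vectorsPerChunk; // 2
        int start = vectorIndexInChunk * subspaceCount;     // 16
        System.out.printf("chunk=%d start=%d length=%d%n", chunkIndex, start, subspaceCount);
    }
}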

public ProductQuantization getProductQuantization() {
@@ -225,16 +260,19 @@ public long ramBytesUsed() {
int AH_BYTES = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;

long codebooksSize = pq.ramBytesUsed();
-long listSize = (long) REF_BYTES * (1 + compressedVectors.size());
-long dataSize = (long) (OH_BYTES + AH_BYTES + pq.compressedVectorSize()) * compressedVectors.size();
-return codebooksSize + listSize + dataSize;
+long chunksArraySize = OH_BYTES + AH_BYTES + (long) compressedDataChunks.length * REF_BYTES;
+long dataSize = 0;
+for (ByteSequence<?> chunk : compressedDataChunks) {
+dataSize += chunk.ramBytesUsed();
+}
+return codebooksSize + chunksArraySize + dataSize;
}

@Override
public String toString() {
return "PQVectors{" +
"pq=" + pq +
-", count=" + compressedVectors.size() +
+", count=" + vectorCount +
'}';
}
}
jvector-base/src/main/java/io/github/jbellis/jvector/pq/ProductQuantization.java
@@ -221,7 +221,7 @@ public ProductQuantization refine(RandomAccessVectorValues ravv,

@Override
public CompressedVectors createCompressedVectors(Object[] compressedVectors) {
-return new PQVectors(this, (ByteSequence<?>[]) compressedVectors);
+return new PQVectors(this, (ByteSequence<?>[]) compressedVectors, compressedVectors.length, 1);
}

/**
jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArrayByteSequence.java
@@ -41,6 +41,11 @@ public byte[] get() {
return data;
}

+@Override
+public int offset() {
+return 0;
+}

@Override
public byte get(int n) {
return data[n];
@@ -72,6 +77,14 @@ public ArrayByteSequence copy() {
return new ArrayByteSequence(Arrays.copyOf(data, data.length));
}

+@Override
+public ByteSequence<byte[]> slice(int offset, int length) {
+if (offset == 0 && length == data.length) {
+return this;
+}
+return new ArraySliceByteSequence(data, offset, length);
+}

@Override
public long ramBytesUsed() {
int OH_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER;
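A minimal usage sketch for the new slice support (my example, not a test from the commit; it assumes the byte[] constructor that copy() uses above):

import io.github.jbellis.jvector.vector.ArrayByteSequence;
import io.github.jbellis.jvector.vector.types.ByteSequence;

public class SliceSketch {
    public static void main(String[] args) {
        byte[] backing = {0, 1, 2, 3, 4, 5, 6, 7};
        ByteSequence<byte[]> whole = new ArrayByteSequence(backing);

        // A sub-range comes back as a view over the same array; no copying.
        ByteSequence<byte[]> view = whole.slice(2, 4);
        view.set(0, (byte) 42);             // writes through to backing[2]
        assert backing[2] == 42;

        // Slicing the full range short-circuits and returns the same instance.
        assert whole.slice(0, backing.length) == whole;
    }
}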
jvector-base/src/main/java/io/github/jbellis/jvector/vector/ArraySliceByteSequence.java (new file)
@@ -0,0 +1,148 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.github.jbellis.jvector.vector;

import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import java.util.Arrays;

/**
 * A {@link ByteSequence} implementation that wraps an array and provides a view into a slice of it.
 * Single-byte writes go through to the backing array; bulk mutation operations are unsupported.
 */
public class ArraySliceByteSequence implements ByteSequence<byte[]> {
private final byte[] data;
private final int offset;
private final int length;

public ArraySliceByteSequence(byte[] data, int offset, int length) {
if (offset < 0 || length < 0 || offset + length > data.length) {
throw new IllegalArgumentException("Invalid offset or length");
}
this.data = data;
this.offset = offset;
this.length = length;
}

@Override
public byte[] get() {
return data;
}

@Override
public int offset() {
return offset;
}

@Override
public byte get(int n) {
if (n < 0 || n >= length) {
throw new IndexOutOfBoundsException("Index: " + n + ", Length: " + length);
}
return data[offset + n];
}

@Override
public void set(int n, byte value) {
if (n < 0 || n >= length) {
throw new IndexOutOfBoundsException("Index: " + n + ", Length: " + length);
}
data[offset + n] = value;
}

@Override
public void setLittleEndianShort(int shortIndex, short value) {
throw new UnsupportedOperationException("Not supported on slices");
}

@Override
public void zero() {
throw new UnsupportedOperationException("Not supported on slices");
}

@Override
public int length() {
return length;
}

@Override
public ByteSequence<byte[]> copy() {
byte[] newData = Arrays.copyOfRange(data, offset, offset + length);
return new ArrayByteSequence(newData);
}

@Override
public ByteSequence<byte[]> slice(int sliceOffset, int sliceLength) {
if (sliceOffset < 0 || sliceLength < 0 || sliceOffset + sliceLength > length) {
throw new IllegalArgumentException("Invalid slice parameters");
}
if (sliceOffset == 0 && sliceLength == length) {
return this;
}
return new ArraySliceByteSequence(data, offset + sliceOffset, sliceLength);
}

@Override
public long ramBytesUsed() {
// Only count the overhead of this slice object, not the underlying array
// since that's shared and counted elsewhere
return RamUsageEstimator.NUM_BYTES_OBJECT_HEADER +
RamUsageEstimator.NUM_BYTES_OBJECT_REF + // the reference to data
(2 * Integer.BYTES); // offset and length
}

@Override
public void copyFrom(ByteSequence<?> src, int srcOffset, int destOffset, int copyLength) {
throw new UnsupportedOperationException("Not supported on slices");
}

@Override
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append("[");
for (int i = 0; i < Math.min(length, 25); i++) {
sb.append(get(i));
if (i < length - 1) {
sb.append(", ");
}
}
if (length > 25) {
sb.append("...");
}
sb.append("]");
return sb.toString();
}

@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
ArraySliceByteSequence that = (ArraySliceByteSequence) o;
if (this.length != that.length) return false;
for (int i = 0; i < length; i++) {
if (this.get(i) != that.get(i)) return false;
}
return true;
}

@Override
public int hashCode() {
int result = 1;
for (int i = 0; i < length; i++) {
result = 31 * result + get(i);
}
return result;
}
}
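And a sketch of how nested slices compose (again my example, not from the commit): slicing a slice re-bases against the original array, so offset() reports the absolute position in the backing storage:

import io.github.jbellis.jvector.vector.ArraySliceByteSequence;
import io.github.jbellis.jvector.vector.types.ByteSequence;

public class NestedSliceSketch {
    public static void main(String[] args) {
        byte[] data = new byte[16];
        for (int i = 0; i < data.length; i++) data[i] = (byte) i;

        ByteSequence<byte[]> outer = new ArraySliceByteSequence(data, 4, 8); // view of bytes 4..11
        ByteSequence<byte[]> inner = outer.slice(2, 3);                      // view of bytes 6..8

        assert inner.offset() == 6; // 4 (outer offset) + 2 (slice offset)
        assert inner.get(0) == 6;   // reads data[6]
    }
}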
jvector-base/src/main/java/io/github/jbellis/jvector/vector/types/ByteSequence.java
@@ -25,6 +25,8 @@ public interface ByteSequence<T> extends Accountable
*/
T get();

+int offset(); // position of this view within its backing storage (0 unless this is a slice)

int length();

byte get(int i);
@@ -42,4 +44,6 @@
void copyFrom(ByteSequence<?> src, int srcOffset, int destOffset, int length);

ByteSequence<T> copy();

+ByteSequence<T> slice(int offset, int length); // a view over [offset, offset + length) of this sequence
}