Skip to content

Commit

Permalink
Store compressed vectors in dense ByteSequence for PQVectors
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeljmarshall committed Nov 23, 2024
1 parent da08d40 commit 4f1a44b
Show file tree
Hide file tree
Showing 10 changed files with 282 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,34 +28,34 @@

import java.io.DataOutput;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Arrays;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicReference;

public class PQVectors implements CompressedVectors {
private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();
private static final int MAX_CHUNK_SIZE = Integer.MAX_VALUE - 16; // standard Java array size limit with some headroom
final ProductQuantization pq;
private final List<ByteSequence<?>> compressedVectors;
private final ByteSequence<?>[] compressedDataChunks;
private final int vectorCount;
private final int vectorsPerChunk;

/**
* Initialize the PQVectors with an initial List of vectors. This list may be
* mutated, but caller is responsible for thread safety issues when doing so.
*/
public PQVectors(ProductQuantization pq, List<ByteSequence<?>> compressedVectors)
public PQVectors(ProductQuantization pq, ByteSequence<?>[] compressedDataChunks)
{
this.pq = pq;
this.compressedVectors = compressedVectors;
this(pq, compressedDataChunks, compressedDataChunks.length, 1);
}

public PQVectors(ProductQuantization pq, ByteSequence<?>[] compressedVectors)
public PQVectors(ProductQuantization pq, ByteSequence<?>[] compressedDataChunks, int vectorCount, int vectorsPerChunk)
{
this(pq, List.of(compressedVectors));
this.pq = pq;
this.compressedDataChunks = compressedDataChunks;
this.vectorCount = vectorCount;
this.vectorsPerChunk = vectorsPerChunk;
}

@Override
public int count() {
return compressedVectors.size();
return vectorCount;
}

@Override
Expand All @@ -65,10 +65,10 @@ public void write(DataOutput out, int version) throws IOException
pq.write(out, version);

// compressed vectors
out.writeInt(compressedVectors.size());
out.writeInt(vectorCount);
out.writeInt(pq.getSubspaceCount());
for (var v : compressedVectors) {
vectorTypeSupport.writeByteSequence(out, v);
for (ByteSequence<?> chunk : compressedDataChunks) {
vectorTypeSupport.writeByteSequence(out, chunk);
}
}

Expand All @@ -77,24 +77,33 @@ public static PQVectors load(RandomAccessReader in) throws IOException {
var pq = ProductQuantization.load(in);

// read the vectors
int size = in.readInt();
if (size < 0) {
throw new IOException("Invalid compressed vector count " + size);
int vectorCount = in.readInt();
if (vectorCount < 0) {
throw new IOException("Invalid compressed vector count " + vectorCount);
}
List<ByteSequence<?>> compressedVectors = new ArrayList<>(size);

int compressedDimension = in.readInt();
if (compressedDimension < 0) {
throw new IOException("Invalid compressed vector dimension " + compressedDimension);
}

for (int i = 0; i < size; i++)
{
ByteSequence<?> vector = vectorTypeSupport.readByteSequence(in, compressedDimension);
compressedVectors.add(vector);
// Calculate if we need to split into multiple chunks
long totalSize = (long) vectorCount * compressedDimension;
int vectorsPerChunk = totalSize <= MAX_CHUNK_SIZE ? vectorCount : MAX_CHUNK_SIZE / compressedDimension;

int numChunks = vectorCount / vectorsPerChunk;
ByteSequence<?>[] chunks = new ByteSequence<?>[numChunks];

for (int i = 0; i < numChunks - 1; i++) {
int chunkSize = vectorsPerChunk * compressedDimension;
chunks[i] = vectorTypeSupport.readByteSequence(in, chunkSize);
}

return new PQVectors(pq, compressedVectors);
// Last chunk might be smaller
int remainingVectors = vectorCount - (vectorsPerChunk * (numChunks - 1));
chunks[numChunks - 1] = vectorTypeSupport.readByteSequence(in, remainingVectors * compressedDimension);

return new PQVectors(pq, chunks, vectorCount, vectorsPerChunk);
}

public static PQVectors load(RandomAccessReader in, long offset) throws IOException {
Expand All @@ -109,12 +118,17 @@ public boolean equals(Object o) {

PQVectors that = (PQVectors) o;
if (!Objects.equals(pq, that.pq)) return false;
return Objects.equals(compressedVectors, that.compressedVectors);
if (vectorsPerChunk != that.vectorsPerChunk) return false;
if (compressedDataChunks.length != that.compressedDataChunks.length) return false;
for (int i = 0; i < compressedDataChunks.length; i++) {
if (!compressedDataChunks[i].equals(that.compressedDataChunks[i])) return false;
}
return true;
}

@Override
public int hashCode() {
return Objects.hash(pq, compressedVectors);
return Objects.hash(pq, Arrays.hashCode(compressedDataChunks));
}

@Override
Expand Down Expand Up @@ -188,7 +202,10 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat<?> q,
}

public ByteSequence<?> get(int ordinal) {
return compressedVectors.get(ordinal);
int chunkIndex = ordinal / vectorsPerChunk;
int vectorIndexInChunk = ordinal % vectorsPerChunk;
int start = vectorIndexInChunk * pq.getSubspaceCount();
return compressedDataChunks[chunkIndex].slice(start, pq.getSubspaceCount());
}

public ProductQuantization getProductQuantization() {
Expand Down Expand Up @@ -225,16 +242,19 @@ public long ramBytesUsed() {
int AH_BYTES = RamUsageEstimator.NUM_BYTES_ARRAY_HEADER;

long codebooksSize = pq.ramBytesUsed();
long listSize = (long) REF_BYTES * (1 + compressedVectors.size());
long dataSize = (long) (OH_BYTES + AH_BYTES + pq.compressedVectorSize()) * compressedVectors.size();
return codebooksSize + listSize + dataSize;
long chunksArraySize = OH_BYTES + AH_BYTES + (long) compressedDataChunks.length * REF_BYTES;
long dataSize = 0;
for (ByteSequence<?> chunk : compressedDataChunks) {
dataSize += chunk.ramBytesUsed();
}
return codebooksSize + chunksArraySize + dataSize;
}

@Override
public String toString() {
return "PQVectors{" +
"pq=" + pq +
", count=" + compressedVectors.size() +
", count=" + vectorCount +
'}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ public ProductQuantization refine(RandomAccessVectorValues ravv,

@Override
public CompressedVectors createCompressedVectors(Object[] compressedVectors) {
return new PQVectors(this, (ByteSequence<?>[]) compressedVectors);
return new PQVectors(this, (ByteSequence<?>[]) compressedVectors, compressedVectors.length, 1);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ public byte[] get() {
return data;
}

@Override
public int offset() {
return 0;
}

@Override
public byte get(int n) {
return data[n];
Expand Down Expand Up @@ -72,6 +77,14 @@ public ArrayByteSequence copy() {
return new ArrayByteSequence(Arrays.copyOf(data, data.length));
}

@Override
public ByteSequence<byte[]> slice(int offset, int length) {
if (offset == 0 && length == data.length) {
return this;
}
return new ArrayByteSequenceSlice(data, offset, length);
}

@Override
public long ramBytesUsed() {
int OH_BYTES = RamUsageEstimator.NUM_BYTES_OBJECT_HEADER;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/*
* Copyright DataStax, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.github.jbellis.jvector.vector;

import io.github.jbellis.jvector.util.RamUsageEstimator;
import io.github.jbellis.jvector.vector.types.ByteSequence;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.nio.ByteOrder;

/**
 * A {@link ByteSequence} view over a sub-range of a backing {@code byte[]}.
 * <p>
 * The slice does not copy the underlying array: reads and writes through this
 * object are visible to every other view of the same array. Index arguments to
 * the accessors are relative to the slice, i.e. {@code get(0)} reads
 * {@code data[offset]}.
 * <p>
 * Not thread-safe for concurrent mutation, matching the other
 * {@code ByteSequence} implementations.
 */
public class ArrayByteSequenceSlice implements ByteSequence<byte[]> {
    // Shared backing storage; never reallocated by this class.
    private final byte[] data;
    // First index of the slice within {@link #data}.
    private final int offset;
    // Number of bytes visible through this slice.
    private final int length;

    /**
     * Creates a view of {@code data[offset .. offset + length)}.
     *
     * @param data   backing array (retained, not copied)
     * @param offset start index within {@code data}, must be >= 0
     * @param length slice length, must be >= 0 and fit within {@code data}
     * @throws IllegalArgumentException if the range does not lie within {@code data}
     */
    public ArrayByteSequenceSlice(byte[] data, int offset, int length) {
        // Written as a subtraction so that offset + length cannot overflow int
        // and silently pass the check for huge arguments.
        if (offset < 0 || length < 0 || offset > data.length - length) {
            throw new IllegalArgumentException("Invalid offset or length");
        }
        this.data = data;
        this.offset = offset;
        this.length = length;
    }

    /**
     * Returns the full backing array, NOT just the slice's range; callers must
     * combine this with {@link #offset()} and {@link #length()}.
     */
    @Override
    public byte[] get() {
        return data;
    }

    @Override
    public int offset() {
        return offset;
    }

    @Override
    public byte get(int n) {
        if (n < 0 || n >= length) {
            throw new IndexOutOfBoundsException("Index: " + n + ", Length: " + length);
        }
        return data[offset + n];
    }

    @Override
    public void set(int n, byte value) {
        if (n < 0 || n >= length) {
            throw new IndexOutOfBoundsException("Index: " + n + ", Length: " + length);
        }
        data[offset + n] = value;
    }

    /**
     * Writes {@code value} in little-endian byte order at slice-relative short
     * index {@code shortIndex} (i.e. byte index {@code 2 * shortIndex}).
     *
     * @throws IndexOutOfBoundsException if the two bytes do not fit in the slice
     */
    @Override
    public void setLittleEndianShort(int shortIndex, short value) {
        int baseIndex = offset + (shortIndex * 2);
        if (baseIndex < offset || baseIndex + 1 >= offset + length) {
            throw new IndexOutOfBoundsException("Short index out of bounds: " + shortIndex);
        }
        data[baseIndex] = (byte) (value & 0xFF);
        data[baseIndex + 1] = (byte) ((value >> 8) & 0xFF);
    }

    /** Zeroes only this slice's range of the backing array. */
    @Override
    public void zero() {
        Arrays.fill(data, offset, offset + length, (byte) 0);
    }

    @Override
    public int length() {
        return length;
    }

    /**
     * Returns a deep copy: a new slice over a freshly allocated array holding
     * only this slice's bytes, detached from the original backing storage.
     */
    @Override
    public ArrayByteSequenceSlice copy() {
        byte[] newData = Arrays.copyOfRange(data, offset, offset + length);
        return new ArrayByteSequenceSlice(newData, 0, length);
    }

    /**
     * Returns a sub-view of this slice ({@code this} when the requested range
     * covers the whole slice). The result shares the backing array.
     *
     * @throws IllegalArgumentException if the range does not lie within this slice
     */
    @Override
    public ByteSequence<byte[]> slice(int sliceOffset, int sliceLength) {
        if (sliceOffset < 0 || sliceLength < 0 || sliceOffset + sliceLength > length) {
            throw new IllegalArgumentException("Invalid slice parameters");
        }
        if (sliceOffset == 0 && sliceLength == length) {
            return this;
        }
        return new ArrayByteSequenceSlice(data, offset + sliceOffset, sliceLength);
    }

    @Override
    public long ramBytesUsed() {
        // Only count the overhead of this slice object, not the underlying array,
        // since that is shared and accounted for by its owner (e.g. PQVectors).
        // The reference to `data` is an object reference, not an int, so use
        // NUM_BYTES_OBJECT_REF for it (consistent with PQVectors.ramBytesUsed).
        return RamUsageEstimator.NUM_BYTES_OBJECT_HEADER +
                RamUsageEstimator.NUM_BYTES_OBJECT_REF + // data
                (2 * Integer.BYTES);                     // offset, length
    }

    /**
     * Copies {@code copyLength} bytes from {@code src[srcOffset..]} into this
     * slice starting at slice-relative {@code destOffset}. Uses
     * {@link System#arraycopy} for the known array-backed implementations and
     * falls back to a byte-by-byte loop otherwise.
     */
    @Override
    public void copyFrom(ByteSequence<?> src, int srcOffset, int destOffset, int copyLength) {
        if (destOffset < 0 || destOffset + copyLength > length) {
            throw new IndexOutOfBoundsException("Destination range out of bounds");
        }
        if (src instanceof ArrayByteSequence) {
            ArrayByteSequence csrc = (ArrayByteSequence) src;
            System.arraycopy(csrc.get(), srcOffset, data, this.offset + destOffset, copyLength);
        } else if (src instanceof ArrayByteSequenceSlice) {
            ArrayByteSequenceSlice csrc = (ArrayByteSequenceSlice) src;
            System.arraycopy(csrc.data, csrc.offset + srcOffset, data, this.offset + destOffset, copyLength);
        } else {
            // Fallback for other implementations (bounds-checked per element by set/get)
            for (int i = 0; i < copyLength; i++) {
                set(destOffset + i, src.get(srcOffset + i));
            }
        }
    }

    /** Renders at most the first 25 bytes, appending "..." when truncated. */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("[");
        for (int i = 0; i < Math.min(length, 25); i++) {
            sb.append(get(i));
            if (i < length - 1) {
                sb.append(", ");
            }
        }
        if (length > 25) {
            sb.append("...");
        }
        sb.append("]");
        return sb.toString();
    }

    /**
     * Content equality over the visible range only; two slices with different
     * offsets into different arrays are equal when their bytes match.
     * Intentionally restricted to this class (getClass comparison), matching
     * the original implementation.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) return true;
        if (o == null || getClass() != o.getClass()) return false;
        ArrayByteSequenceSlice that = (ArrayByteSequenceSlice) o;
        if (this.length != that.length) return false;
        for (int i = 0; i < length; i++) {
            if (this.get(i) != that.get(i)) return false;
        }
        return true;
    }

    /** Order-sensitive content hash over the visible range, consistent with {@link #equals}. */
    @Override
    public int hashCode() {
        int result = 1;
        for (int i = 0; i < length; i++) {
            result = 31 * result + get(i);
        }
        return result;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ public interface ByteSequence<T> extends Accountable
*/
T get();

int offset();

int length();

byte get(int i);
Expand All @@ -42,4 +44,6 @@ public interface ByteSequence<T> extends Accountable
void copyFrom(ByteSequence<?> src, int srcOffset, int destOffset, int length);

ByteSequence<T> copy();

ByteSequence<T> slice(int offset, int length);
}
Loading

0 comments on commit 4f1a44b

Please sign in to comment.