Skip to content

Commit

Permalink
Reenable SimdOps.assembleAndSum; implement Panama/Native equivalent f…
Browse files Browse the repository at this point in the history
…or CosineDecoder acceleration (#368)

Co-authored-by: Joel Knighton <[email protected]>
  • Loading branch information
michaeljmarshall and jkni authored Nov 22, 2024
1 parent 8f115d7 commit da08d40
Show file tree
Hide file tree
Showing 9 changed files with 218 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,18 +131,10 @@ public float similarityTo(int node2) {
}

protected float decodedCosine(int node2) {
float sum = 0.0f;
float aMag = 0.0f;

ByteSequence<?> encoded = cv.get(node2);

for (int m = 0; m < encoded.length(); ++m) {
int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
sum += partialSums.get((m * cv.pq.getClusterCount()) + centroidIndex);
aMag += aMagnitude.get((m * cv.pq.getClusterCount()) + centroidIndex);
}

return (float) (sum / Math.sqrt(aMag * bMagnitude));
return VectorUtil.pqDecodedCosineSimilarity(encoded, cv.pq.getClusterCount(), partialSums, aMagnitude, bMagnitude);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,8 @@ public static float max(VectorFloat<?> v) {
public static float min(VectorFloat<?> v) {
return impl.min(v);
}

public static float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude) {
return impl.pqDecodedCosineSimilarity(encoded, clusterCount, partialSums, aMagnitude, bMagnitude);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,18 @@ default void bulkShuffleQuantizedSimilarityCosine(ByteSequence<?> shuffles, int
float max(VectorFloat<?> v);
float min(VectorFloat<?> v);

default float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
{
float sum = 0.0f;
float aMag = 0.0f;

for (int m = 0; m < encoded.length(); ++m) {
int centroidIndex = Byte.toUnsignedInt(encoded.get(m));
var index = m * clusterCount + centroidIndex;
sum += partialSums.get(index);
aMag += aMagnitude.get(index);
}

return (float) (sum / Math.sqrt(aMag * bMagnitude));
}
}
43 changes: 43 additions & 0 deletions jvector-native/src/main/c/jvector_simd.c
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,49 @@ float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned c
return res;
}

float pq_decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude) {
__m512 sum = _mm512_setzero_ps();
__m512 vaMagnitude = _mm512_setzero_ps();
int i = 0;
int limit = baseOffsetsLength - (baseOffsetsLength % 16);
__m512i indexRegister = initialIndexRegister;
__m512i scale = _mm512_set1_epi32(clusterCount);


for (; i < limit; i += 16) {
// Load and convert baseOffsets to integers
__m128i baseOffsetsRaw = _mm_loadu_si128((__m128i *)(baseOffsets + i));
__m512i baseOffsetsInt = _mm512_cvtepu8_epi32(baseOffsetsRaw);

indexRegister = _mm512_add_epi32(indexRegister, indexIncrement);
// Scale the baseOffsets by the cluster count
__m512i scaledOffsets = _mm512_mullo_epi32(indexRegister, scale);

// Calculate the final convOffsets by adding the scaled indexes and the base offsets
__m512i convOffsets = _mm512_add_epi32(scaledOffsets, baseOffsetsInt);

// Gather and sum values for partial sums and a magnitude
__m512 partialSumVals = _mm512_i32gather_ps(convOffsets, partialSums, 4);
sum = _mm512_add_ps(sum, partialSumVals);

__m512 aMagnitudeVals = _mm512_i32gather_ps(convOffsets, aMagnitude, 4);
vaMagnitude = _mm512_add_ps(vaMagnitude, aMagnitudeVals);
}

// Reduce sums
float sumResult = _mm512_reduce_add_ps(sum);
float aMagnitudeResult = _mm512_reduce_add_ps(vaMagnitude);

// Handle the remaining elements
for (; i < baseOffsetsLength; i++) {
int offset = clusterCount * i + baseOffsets[i];
sumResult += partialSums[offset];
aMagnitudeResult += aMagnitude[offset];
}

return sumResult / sqrtf(aMagnitudeResult * bMagnitude);
}

void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookIndex, int size, int clusterCount, const float* query, int queryOffset, float* partialSums) {
int codebookBase = codebookIndex * clusterCount;
for (int i = 0; i < clusterCount; i++) {
Expand Down
1 change: 1 addition & 0 deletions jvector-native/src/main/c/jvector_simd.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ void bulk_quantized_shuffle_dot_f32_512(const unsigned char* shuffles, int codeb
void bulk_quantized_shuffle_euclidean_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartials, float delta, float minDistance, float* results);
void bulk_quantized_shuffle_cosine_f32_512(const unsigned char* shuffles, int codebookCount, const char* quantizedPartialSums, float sumDelta, float minDistance, const char* quantizedPartialMagnitudes, float magnitudeDelta, float minMagnitude, float queryMagnitudeSquared, float* results);
float assemble_and_sum_f32_512(const float* data, int dataBase, const unsigned char* baseOffsets, int baseOffsetsLength);
float pq_decoded_cosine_similarity_f32_512(const unsigned char* baseOffsets, int baseOffsetsLength, int clusterCount, const float* partialSums, const float* aMagnitude, float bMagnitude);
void calculate_partial_sums_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums);
void calculate_partial_sums_euclidean_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums);
void calculate_partial_sums_best_dot_f32_512(const float* codebook, int codebookBase, int size, int clusterCount, const float* query, int queryOffset, float* partialSums, float* partialBestDistances);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,4 +155,10 @@ public void bulkShuffleQuantizedSimilarityCosine(ByteSequence<?> shuffles, int c
NativeSimdOps.bulk_quantized_shuffle_cosine_f32_512(((MemorySegmentByteSequence) shuffles).get(), codebookCount, ((MemorySegmentByteSequence) quantizedPartialSums).get(), sumDelta, minDistance,
((MemorySegmentByteSequence) quantizedPartialSquaredMagnitudes).get(), magnitudeDelta, minMagnitude, queryMagnitudeSquared, ((MemorySegmentVectorFloat) results).get());
}

@Override
public float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
{
return NativeSimdOps.pq_decoded_cosine_similarity_f32_512(((MemorySegmentByteSequence) encoded).get(), encoded.length(), clusterCount, ((MemorySegmentVectorFloat) partialSums).get(), ((MemorySegmentVectorFloat) aMagnitude).get(), bMagnitude);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,58 @@ public static float assemble_and_sum_f32_512(MemorySegment data, int dataBase, M
}
}

private static class pq_decoded_cosine_similarity_f32_512 {
public static final FunctionDescriptor DESC = FunctionDescriptor.of(
NativeSimdOps.C_FLOAT,
NativeSimdOps.C_POINTER,
NativeSimdOps.C_INT,
NativeSimdOps.C_INT,
NativeSimdOps.C_POINTER,
NativeSimdOps.C_POINTER,
NativeSimdOps.C_FLOAT
);

public static final MethodHandle HANDLE = Linker.nativeLinker().downcallHandle(
NativeSimdOps.findOrThrow("pq_decoded_cosine_similarity_f32_512"),
DESC, Linker.Option.critical(true));
}

/**
* Function descriptor for:
* {@snippet lang=c :
* float pq_decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude)
* }
*/
public static FunctionDescriptor pq_decoded_cosine_similarity_f32_512$descriptor() {
return pq_decoded_cosine_similarity_f32_512.DESC;
}

/**
* Downcall method handle for:
* {@snippet lang=c :
* float pq_decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude)
* }
*/
public static MethodHandle pq_decoded_cosine_similarity_f32_512$handle() {
return pq_decoded_cosine_similarity_f32_512.HANDLE;
}
/**
* {@snippet lang=c :
* float pq_decoded_cosine_similarity_f32_512(const unsigned char *baseOffsets, int baseOffsetsLength, int clusterCount, const float *partialSums, const float *aMagnitude, float bMagnitude)
* }
*/
public static float pq_decoded_cosine_similarity_f32_512(MemorySegment baseOffsets, int baseOffsetsLength, int clusterCount, MemorySegment partialSums, MemorySegment aMagnitude, float bMagnitude) {
var mh$ = pq_decoded_cosine_similarity_f32_512.HANDLE;
try {
if (TRACE_DOWNCALLS) {
traceDowncall("pq_decoded_cosine_similarity_f32_512", baseOffsets, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude);
}
return (float)mh$.invokeExact(baseOffsets, baseOffsetsLength, clusterCount, partialSums, aMagnitude, bMagnitude);
} catch (Throwable ex$) {
throw new AssertionError("should not reach here", ex$);
}
}

private static class calculate_partial_sums_dot_f32_512 {
public static final FunctionDescriptor DESC = FunctionDescriptor.ofVoid(
NativeSimdOps.C_POINTER,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,7 @@ public VectorFloat<?> sub(VectorFloat<?> a, int aOffset, VectorFloat<?> b, int b

@Override
public float assembleAndSum(VectorFloat<?> data, int dataBase, ByteSequence<?> baseOffsets) {
float sum = 0f;
for (int i = 0; i < baseOffsets.length(); i++) {
sum += data.get(dataBase * i + Byte.toUnsignedInt(baseOffsets.get(i)));
}
return sum;
return SimdOps.assembleAndSum(((ArrayVectorFloat) data).get(), dataBase, ((ArrayByteSequence) baseOffsets).get());
}

@Override
Expand Down Expand Up @@ -159,5 +155,11 @@ public void calculatePartialSums(VectorFloat<?> codebook, int codebookIndex, int
public void quantizePartials(float delta, VectorFloat<?> partials, VectorFloat<?> partialBases, ByteSequence<?> quantizedPartials) {
SimdOps.quantizePartials(delta, (ArrayVectorFloat) partials, (ArrayVectorFloat) partialBases, (ArrayByteSequence) quantizedPartials);
}

@Override
public float pqDecodedCosineSimilarity(ByteSequence<?> encoded, int clusterCount, VectorFloat<?> partialSums, VectorFloat<?> aMagnitude, float bMagnitude)
{
return SimdOps.pqDecodedCosineSimilarity((ArrayByteSequence) encoded, clusterCount, (ArrayVectorFloat) partialSums, (ArrayVectorFloat) aMagnitude, bMagnitude);
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -525,18 +525,18 @@ static float assembleAndSum512(float[] data, int dataBase, byte[] baseOffsets) {
FloatVector sum = FloatVector.zero(FloatVector.SPECIES_512);
int i = 0;
int limit = ByteVector.SPECIES_128.loopBound(baseOffsets.length);
var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(dataBase);

for (; i < limit; i += ByteVector.SPECIES_128.length()) {
var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(1).add(i).mul(dataBase);

ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets, i)
.convertShape(VectorOperators.B2I, IntVector.SPECIES_512, 0)
.lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_512)
.reinterpretAsInts()
.add(scale)
.intoArray(convOffsets,0);

sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_512, data, 0, convOffsets, 0));
var offset = i * dataBase;
sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_512, data, offset, convOffsets, 0));
}

float res = sum.reduceLanes(VectorOperators.ADD);
Expand All @@ -553,9 +553,9 @@ static float assembleAndSum256(float[] data, int dataBase, byte[] baseOffsets) {
FloatVector sum = FloatVector.zero(FloatVector.SPECIES_256);
int i = 0;
int limit = ByteVector.SPECIES_64.loopBound(baseOffsets.length);
var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(dataBase);

for (; i < limit; i += ByteVector.SPECIES_64.length()) {
var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(1).add(i).mul(dataBase);

ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets, i)
.convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0)
Expand All @@ -564,7 +564,8 @@ static float assembleAndSum256(float[] data, int dataBase, byte[] baseOffsets) {
.add(scale)
.intoArray(convOffsets,0);

sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_256, data, 0, convOffsets, 0));
var offset = i * dataBase;
sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_256, data, offset, convOffsets, 0));
}

float res = sum.reduceLanes(VectorOperators.ADD);
Expand Down Expand Up @@ -658,4 +659,88 @@ public static void quantizePartials(float delta, ArrayVectorFloat partials, Arra
}
}
}

public static float pqDecodedCosineSimilarity(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
return HAS_AVX512
? pqDecodedCosineSimilarity512(encoded, clusterCount, partialSums, aMagnitude, bMagnitude)
: pqDecodedCosineSimilarity256(encoded, clusterCount, partialSums, aMagnitude, bMagnitude);
}

public static float pqDecodedCosineSimilarity512(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
var sum = FloatVector.zero(FloatVector.SPECIES_512);
var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_512);
var baseOffsets = encoded.get();
var partialSumsArray = partialSums.get();
var aMagnitudeArray = aMagnitude.get();

int[] convOffsets = scratchInt512.get();
int i = 0;
int limit = ByteVector.SPECIES_128.loopBound(baseOffsets.length);

var scale = IntVector.zero(IntVector.SPECIES_512).addIndex(clusterCount);

for (; i < limit; i += ByteVector.SPECIES_128.length()) {

ByteVector.fromArray(ByteVector.SPECIES_128, baseOffsets, i)
.convertShape(VectorOperators.B2I, IntVector.SPECIES_512, 0)
.lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_512)
.reinterpretAsInts()
.add(scale)
.intoArray(convOffsets,0);

var offset = i * clusterCount;
sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_512, partialSumsArray, offset, convOffsets, 0));
vaMagnitude = vaMagnitude.add(FloatVector.fromArray(FloatVector.SPECIES_512, aMagnitudeArray, offset, convOffsets, 0));
}

float sumResult = sum.reduceLanes(VectorOperators.ADD);
float aMagnitudeResult = vaMagnitude.reduceLanes(VectorOperators.ADD);

for (; i < baseOffsets.length; i++) {
int offset = clusterCount * i + Byte.toUnsignedInt(baseOffsets[i]);
sumResult += partialSumsArray[offset];
aMagnitudeResult += aMagnitudeArray[offset];
}

return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude));
}

public static float pqDecodedCosineSimilarity256(ArrayByteSequence encoded, int clusterCount, ArrayVectorFloat partialSums, ArrayVectorFloat aMagnitude, float bMagnitude) {
var sum = FloatVector.zero(FloatVector.SPECIES_256);
var vaMagnitude = FloatVector.zero(FloatVector.SPECIES_256);
var baseOffsets = encoded.get();
var partialSumsArray = partialSums.get();
var aMagnitudeArray = aMagnitude.get();

int[] convOffsets = scratchInt256.get();
int i = 0;
int limit = ByteVector.SPECIES_64.loopBound(baseOffsets.length);

var scale = IntVector.zero(IntVector.SPECIES_256).addIndex(clusterCount);

for (; i < limit; i += ByteVector.SPECIES_64.length()) {

ByteVector.fromArray(ByteVector.SPECIES_64, baseOffsets, i)
.convertShape(VectorOperators.B2I, IntVector.SPECIES_256, 0)
.lanewise(VectorOperators.AND, BYTE_TO_INT_MASK_256)
.reinterpretAsInts()
.add(scale)
.intoArray(convOffsets,0);

var offset = i * clusterCount;
sum = sum.add(FloatVector.fromArray(FloatVector.SPECIES_256, partialSumsArray, offset, convOffsets, 0));
vaMagnitude = vaMagnitude.add(FloatVector.fromArray(FloatVector.SPECIES_256, aMagnitudeArray, offset, convOffsets, 0));
}

float sumResult = sum.reduceLanes(VectorOperators.ADD);
float aMagnitudeResult = vaMagnitude.reduceLanes(VectorOperators.ADD);

for (; i < baseOffsets.length; i++) {
int offset = clusterCount * i + Byte.toUnsignedInt(baseOffsets[i]);
sumResult += partialSumsArray[offset];
aMagnitudeResult += aMagnitudeArray[offset];
}

return (float) (sumResult / Math.sqrt(aMagnitudeResult * bMagnitude));
}
}

0 comments on commit da08d40

Please sign in to comment.