Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/139409.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 139409
summary: Use new bulk scoring dot product for max inner product
area: Vector Search
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,37 @@ float scoreFromSegments(MemorySegment a, float aOffset, MemorySegment b, float b
return adjustedDistance + 1;
}

@Override
protected void bulkScoreFromSegment(
MemorySegment vectors,
int vectorLength,
int vectorPitch,
int firstOrd,
MemorySegment ordinals,
MemorySegment scores,
int numNodes
) {
long firstByteOffset = (long) firstOrd * vectorPitch;
var firstVector = vectors.asSlice(firstByteOffset, vectorPitch);
Similarities.dotProduct7uBulkWithOffsets(vectors, firstVector, dims, vectorPitch, ordinals, numNodes, scores);

// Java-side adjustment
var aOffset = Float.intBitsToFloat(
vectors.asSlice(firstByteOffset + vectorLength, Float.BYTES).getAtIndex(ValueLayout.JAVA_INT_UNALIGNED, 0)
);
for (int i = 0; i < numNodes; ++i) {
var dotProduct = scores.getAtIndex(ValueLayout.JAVA_FLOAT, i);
var secondOrd = ordinals.getAtIndex(ValueLayout.JAVA_INT, i);
long secondByteOffset = (long) secondOrd * vectorPitch;
var bOffset = Float.intBitsToFloat(
vectors.asSlice(secondByteOffset + vectorLength, Float.BYTES).getAtIndex(ValueLayout.JAVA_INT_UNALIGNED, 0)
);
float adjustedDistance = dotProduct * scoreCorrectionConstant + aOffset + bOffset;
adjustedDistance = adjustedDistance < 0 ? 1 / (1 + -1 * adjustedDistance) : adjustedDistance + 1;
scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, adjustedDistance);
}
}

@Override
public MaxInnerProductSupplier copy() {
return new MaxInnerProductSupplier(input.clone(), values, scoreCorrectionConstant);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,30 @@ public float score(int node) throws IOException {
}
return adjustedDistance + 1;
}

@Override
public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException {
MemorySegment vectorsSeg = input.segmentSliceOrNull(0, input.length());
if (vectorsSeg == null) {
super.bulkScore(nodes, scores, numNodes);
} else {
var ordinalsSeg = MemorySegment.ofArray(nodes);
var scoresSeg = MemorySegment.ofArray(scores);

var vectorPitch = vectorByteSize + Float.BYTES;
dotProduct7uBulkWithOffsets(vectorsSeg, query, vectorByteSize, vectorPitch, ordinalsSeg, numNodes, scoresSeg);

for (int i = 0; i < numNodes; ++i) {
var dotProduct = scores[i];
var secondOrd = nodes[i];
long secondByteOffset = (long) secondOrd * vectorPitch;
var nodeCorrection = Float.intBitsToFloat(input.readInt(secondByteOffset + vectorByteSize));
float adjustedDistance = dotProduct * scoreCorrectionConstant + queryCorrection + nodeCorrection;
adjustedDistance = adjustedDistance < 0 ? 1 / (1 + -1 * adjustedDistance) : adjustedDistance + 1;
scores[i] = adjustedDistance;
}
}
}
}

static void checkDimensions(int queryLen, int fieldLen) {
Expand Down