Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Faiss efficient filter exact search using byte vector datatype #2165

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions src/main/java/org/opensearch/knn/index/query/ExactSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,15 @@
import org.opensearch.knn.common.FieldInfoExtractor;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.query.iterators.BinaryVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.ByteVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.NestedBinaryVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.VectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.KNNIterator;
import org.opensearch.knn.index.query.iterators.NestedByteVectorIdsKNNIterator;
import org.opensearch.knn.index.query.iterators.NestedVectorIdsKNNIterator;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory;
Expand Down Expand Up @@ -111,21 +114,35 @@ private KNNIterator getKNNIterator(LeafReaderContext leafReaderContext, ExactSea
if (VectorDataType.BINARY == knnQuery.getVectorDataType()) {
final KNNVectorValues<byte[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
if (isNestedRequired) {
return new NestedByteVectorIdsKNNIterator(
return new NestedBinaryVectorIdsKNNIterator(
matchedDocs,
knnQuery.getByteQueryVector(),
(KNNBinaryVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
);
}
return new ByteVectorIdsKNNIterator(
return new BinaryVectorIdsKNNIterator(
matchedDocs,
knnQuery.getByteQueryVector(),
(KNNBinaryVectorValues) vectorValues,
spaceType
);
}

if (VectorDataType.BYTE == knnQuery.getVectorDataType()) {
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
final KNNVectorValues<byte[]> vectorValues = KNNVectorValuesFactory.getVectorValues(fieldInfo, reader);
if (isNestedRequired) {
return new NestedByteVectorIdsKNNIterator(
matchedDocs,
knnQuery.getQueryVector(),
(KNNByteVectorValues) vectorValues,
spaceType,
knnQuery.getParentsFilter().getBitSet(leafReaderContext)
);
}
return new ByteVectorIdsKNNIterator(matchedDocs, knnQuery.getQueryVector(), (KNNByteVectorValues) vectorValues, spaceType);
}
final byte[] quantizedQueryVector;
final SegmentLevelQuantizationInfo segmentLevelQuantizationInfo;
if (exactSearcherContext.isUseQuantizedVectorsForSearch()) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.iterators;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.apache.lucene.util.BitSetIterator;
import org.opensearch.common.Nullable;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;

import java.io.IOException;

/**
* Inspired by DiversifyingChildrenFloatKnnVectorQuery in lucene
* https://github.com/apache/lucene/blob/7b8aece125aabff2823626d5b939abf4747f63a7/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java#L162
*
* The class is used in KNNWeight to score all docs, but, it iterates over filterIdsArray if filter is provided
*/
public class BinaryVectorIdsKNNIterator implements KNNIterator {
protected final BitSetIterator bitSetIterator;
protected final byte[] queryVector;
protected final KNNBinaryVectorValues binaryVectorValues;
protected final SpaceType spaceType;
protected float currentScore = Float.NEGATIVE_INFINITY;
protected int docId;

public BinaryVectorIdsKNNIterator(
@Nullable final BitSet filterIdsBitSet,
final byte[] queryVector,
final KNNBinaryVectorValues binaryVectorValues,
final SpaceType spaceType
) throws IOException {
this.bitSetIterator = filterIdsBitSet == null ? null : new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length());
this.queryVector = queryVector;
this.binaryVectorValues = binaryVectorValues;
this.spaceType = spaceType;
// This cannot be moved inside nextDoc() method since it will break when we have nested field, where
// nextDoc should already be referring to next knnVectorValues
this.docId = getNextDocId();
}

public BinaryVectorIdsKNNIterator(final byte[] queryVector, final KNNBinaryVectorValues binaryVectorValues, final SpaceType spaceType)
throws IOException {
this(null, queryVector, binaryVectorValues, spaceType);
}

/**
* Advance to the next doc and update score value with score of the next doc.
* DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs
*
* @return next doc id
*/
@Override
public int nextDoc() throws IOException {

if (docId == DocIdSetIterator.NO_MORE_DOCS) {
return DocIdSetIterator.NO_MORE_DOCS;
}
currentScore = computeScore();
int currentDocId = docId;
docId = getNextDocId();
return currentDocId;
}

@Override
public float score() {
return currentScore;
}

protected float computeScore() throws IOException {
final byte[] vector = binaryVectorValues.getVector();
// Calculates a similarity score between the two vectors with a specified function. Higher similarity
// scores correspond to closer vectors.
return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector);
}

protected int getNextDocId() throws IOException {
if (bitSetIterator == null) {
return binaryVectorValues.nextDoc();
}
int nextDocID = this.bitSetIterator.nextDoc();
// For filter case, advance vector values to corresponding doc id from filter bit set
if (nextDocID != DocIdSetIterator.NO_MORE_DOCS) {
binaryVectorValues.advance(nextDocID);
}
return nextDocID;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import org.apache.lucene.util.BitSetIterator;
import org.opensearch.common.Nullable;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues;

import java.io.IOException;

Expand All @@ -22,30 +22,30 @@
*/
public class ByteVectorIdsKNNIterator implements KNNIterator {
protected final BitSetIterator bitSetIterator;
protected final byte[] queryVector;
protected final KNNBinaryVectorValues binaryVectorValues;
protected final float[] queryVector;
protected final KNNByteVectorValues byteVectorValues;
protected final SpaceType spaceType;
protected float currentScore = Float.NEGATIVE_INFINITY;
protected int docId;

public ByteVectorIdsKNNIterator(
@Nullable final BitSet filterIdsBitSet,
final byte[] queryVector,
final KNNBinaryVectorValues binaryVectorValues,
final float[] queryVector,
final KNNByteVectorValues byteVectorValues,
final SpaceType spaceType
) throws IOException {
this.bitSetIterator = filterIdsBitSet == null ? null : new BitSetIterator(filterIdsBitSet, filterIdsBitSet.length());
this.queryVector = queryVector;
this.binaryVectorValues = binaryVectorValues;
this.byteVectorValues = byteVectorValues;
this.spaceType = spaceType;
// This cannot be moved inside nextDoc() method since it will break when we have nested field, where
// nextDoc should already be referring to next knnVectorValues
this.docId = getNextDocId();
}

public ByteVectorIdsKNNIterator(final byte[] queryVector, final KNNBinaryVectorValues binaryVectorValues, final SpaceType spaceType)
public ByteVectorIdsKNNIterator(final float[] queryVector, final KNNByteVectorValues byteVectorValues, final SpaceType spaceType)
throws IOException {
this(null, queryVector, binaryVectorValues, spaceType);
this(null, queryVector, byteVectorValues, spaceType);
}

/**
Expand All @@ -72,20 +72,28 @@ public float score() {
}

protected float computeScore() throws IOException {
final byte[] vector = binaryVectorValues.getVector();
final byte[] vector = byteVectorValues.getVector();
// Calculates a similarity score between the two vectors with a specified function. Higher similarity
// scores correspond to closer vectors.
return spaceType.getKnnVectorSimilarityFunction().compare(queryVector, vector);

// The query vector of Faiss byte vector is a Float array because ScalarQuantizer accepts it as float array.
// Now, to compute the score between this query vector and each vector in KNNByteVectorValues we are casting this query vector into
// byte array.
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
final byte[] byteQueryVector = new byte[queryVector.length];
for (int i = 0; i < queryVector.length; i++) {
byteQueryVector[i] = (byte) queryVector[i];
}
return spaceType.getKnnVectorSimilarityFunction().compare(byteQueryVector, vector);
}

protected int getNextDocId() throws IOException {
if (bitSetIterator == null) {
return binaryVectorValues.nextDoc();
return byteVectorValues.nextDoc();
}
int nextDocID = this.bitSetIterator.nextDoc();
// For filter case, advance vector values to corresponding doc id from filter bit set
if (nextDocID != DocIdSetIterator.NO_MORE_DOCS) {
binaryVectorValues.advance(nextDocID);
byteVectorValues.advance(nextDocID);
}
return nextDocID;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.query.iterators;

import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.BitSet;
import org.opensearch.common.Nullable;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;

import java.io.IOException;

/**
* This iterator iterates filterIdsArray to scoreif filter is provided else it iterates over all docs.
* However, it dedupe docs per each parent doc
* of which ID is set in parentBitSet and only return best child doc with the highest score.
*/
public class NestedBinaryVectorIdsKNNIterator extends BinaryVectorIdsKNNIterator {
private final BitSet parentBitSet;

public NestedBinaryVectorIdsKNNIterator(
@Nullable final BitSet filterIdsArray,
final byte[] queryVector,
final KNNBinaryVectorValues binaryVectorValues,
final SpaceType spaceType,
final BitSet parentBitSet
) throws IOException {
super(filterIdsArray, queryVector, binaryVectorValues, spaceType);
this.parentBitSet = parentBitSet;
}

public NestedBinaryVectorIdsKNNIterator(
final byte[] queryVector,
final KNNBinaryVectorValues binaryVectorValues,
final SpaceType spaceType,
final BitSet parentBitSet
) throws IOException {
super(null, queryVector, binaryVectorValues, spaceType);
this.parentBitSet = parentBitSet;
}

/**
* Advance to the next best child doc per parent and update score with the best score among child docs from the parent.
* DocIdSetIterator.NO_MORE_DOCS is returned when there is no more docs
*
* @return next best child doc id
*/
@Override
public int nextDoc() throws IOException {
if (docId == DocIdSetIterator.NO_MORE_DOCS) {
return DocIdSetIterator.NO_MORE_DOCS;
}

currentScore = Float.NEGATIVE_INFINITY;
int currentParent = parentBitSet.nextSetBit(docId);
int bestChild = -1;

// In order to traverse all children for given parent, we have to use docId < parentId, because,
// kNNVectorValues will not have parent id since DocId is unique per segment. For ex: let's say for doc id 1, there is one child
// and for doc id 5, there are three children. In that case knnVectorValues iterator will have [0, 2, 3, 4]
// and parentBitSet will have [1,5]
// Hence, we have to iterate till docId from knnVectorValues is less than parentId instead of till equal to parentId
while (docId != DocIdSetIterator.NO_MORE_DOCS && docId < currentParent) {
float score = computeScore();
if (score > currentScore) {
bestChild = docId;
currentScore = score;
}
docId = getNextDocId();
}

return bestChild;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import org.apache.lucene.util.BitSet;
import org.opensearch.common.Nullable;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.vectorvalues.KNNBinaryVectorValues;
import org.opensearch.knn.index.vectorvalues.KNNByteVectorValues;

import java.io.IOException;

Expand All @@ -23,18 +23,18 @@ public class NestedByteVectorIdsKNNIterator extends ByteVectorIdsKNNIterator {

public NestedByteVectorIdsKNNIterator(
@Nullable final BitSet filterIdsArray,
final byte[] queryVector,
final KNNBinaryVectorValues binaryVectorValues,
final float[] queryVector,
final KNNByteVectorValues byteVectorValues,
final SpaceType spaceType,
final BitSet parentBitSet
) throws IOException {
super(filterIdsArray, queryVector, binaryVectorValues, spaceType);
super(filterIdsArray, queryVector, byteVectorValues, spaceType);
this.parentBitSet = parentBitSet;
}

public NestedByteVectorIdsKNNIterator(
final byte[] queryVector,
final KNNBinaryVectorValues binaryVectorValues,
final float[] queryVector,
final KNNByteVectorValues binaryVectorValues,
final SpaceType spaceType,
final BitSet parentBitSet
) throws IOException {
Expand Down
Loading
Loading