Skip to content

Commit

Permalink
Allow build graph greedily for quantization scenarios
Browse files Browse the repository at this point in the history
Previosuly we only added support to build greedily for
non quantization scenario. In this commit, we can remove
that constraint, however, we cannot skip writing quanitization
state since it is required irrespective of type of search
is executed later.

Signed-off-by: Vijayan Balasubramanian <[email protected]>
  • Loading branch information
VijayanB committed Oct 1, 2024
1 parent d61e7d4 commit 87d95b3
Show file tree
Hide file tree
Showing 2 changed files with 211 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,9 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
field.getVectors()
);
final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
// Will consider building vector data structure based on threshold only for non quantization indices
if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
// Check only after quantization state writer finish writing its state, since it is required
// even if there are no graph files in segment, which will be later used by exact search
if (shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
log.info(
"Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during flush",
fieldInfo.name,
Expand Down Expand Up @@ -139,8 +140,9 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState
}

final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs);
// Will consider building vector data structure based on threshold only for non quantization indices
if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
// Check only after quantization state writer finish writing its state, since it is required
// even if there are no graph files in segment, which will be later used by exact search
if (shouldSkipBuildingVectorDataStructure(totalLiveDocs)) {
log.info(
"Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during merge",
fieldInfo.name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -602,7 +602,6 @@ public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWr
verify(flatVectorsWriter).flush(5, null);
if (vectorsPerField.size() > 0) {
assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size());
assertTrue((long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() > 0);
}
IntStream.range(0, vectorsPerField.size()).forEach(i -> {
try {
Expand All @@ -618,6 +617,211 @@ public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWr
}
}

public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThresholdIsNotMet_thenSkipBuildingGraph()
throws IOException {
// Given
List<KNNVectorValues<float[]>> expectedVectorValues = new ArrayList<>();
final Map<Integer, Integer> sizeMap = new HashMap<>();
IntStream.range(0, vectorsPerField.size()).forEach(i -> {
final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues(
new ArrayList<>(vectorsPerField.get(i).values())
);
final KNNVectorValues<float[]> knnVectorValues = KNNVectorValuesFactory.getVectorValues(
VectorDataType.FLOAT,
randomVectorValues
);
sizeMap.put(i, randomVectorValues.size());
expectedVectorValues.add(knnVectorValues);

});
final int maxThreshold = sizeMap.values().stream().filter(count -> count != 0).max(Integer::compareTo).orElse(0);
final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter(
segmentWriteState,
flatVectorsWriter,
maxThreshold + 1 // to avoid building graph using max doc threshold, the same can be achieved by -1 too
);

try (
MockedStatic<NativeEngineFieldVectorsWriter> fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class);
MockedStatic<KNNVectorValuesFactory> knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class);
MockedStatic<QuantizationService> quantizationServiceMockedStatic = mockStatic(QuantizationService.class);
MockedStatic<NativeIndexWriter> nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class);
MockedConstruction<KNN990QuantizationStateWriter> knn990QuantWriterMockedConstruction = mockConstruction(
KNN990QuantizationStateWriter.class
);
) {
quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService);

IntStream.range(0, vectorsPerField.size()).forEach(i -> {
final FieldInfo fieldInfo = fieldInfo(
i,
VectorEncoding.FLOAT32,
Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss")
);

NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i));
fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream))
.thenReturn(field);

try {
nativeEngineWriter.addField(fieldInfo);
} catch (Exception e) {
throw new RuntimeException(e);
}

DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
knnVectorValuesFactoryMockedStatic.when(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
).thenReturn(expectedVectorValues.get(i));

when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
try {
when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size()))
.thenReturn(quantizationState);
} catch (Exception e) {
throw new RuntimeException(e);
}

nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState))
.thenReturn(nativeIndexWriter);
});
doAnswer(answer -> {
Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion
return null;
}).when(nativeIndexWriter).flushIndex(any(), anyInt());

// When
nativeEngineWriter.flush(5, null);

// Then
verify(flatVectorsWriter).flush(5, null);
if (vectorsPerField.size() > 0) {
verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState);
} else {
assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size());
}
verifyNoInteractions(nativeIndexWriter);
IntStream.range(0, vectorsPerField.size()).forEach(i -> {
try {
if (vectorsPerField.get(i).isEmpty()) {
verify(knn990QuantWriterMockedConstruction.constructed().get(0), never()).writeState(i, quantizationState);
} else {
verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState);
}
} catch (Exception e) {
throw new RuntimeException(e);
}
});
final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count();
knnVectorValuesFactoryMockedStatic.verify(
() -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()),
times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled))
);
}
}

public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThresholdIsNegative_thenSkipBuildingGraph()
throws IOException {
// Given
List<KNNVectorValues<float[]>> expectedVectorValues = new ArrayList<>();
final Map<Integer, Integer> sizeMap = new HashMap<>();
IntStream.range(0, vectorsPerField.size()).forEach(i -> {
final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues(
new ArrayList<>(vectorsPerField.get(i).values())
);
final KNNVectorValues<float[]> knnVectorValues = KNNVectorValuesFactory.getVectorValues(
VectorDataType.FLOAT,
randomVectorValues
);
sizeMap.put(i, randomVectorValues.size());
expectedVectorValues.add(knnVectorValues);

});
final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter(
segmentWriteState,
flatVectorsWriter,
BUILD_GRAPH_NEVER_THRESHOLD
);

try (
MockedStatic<NativeEngineFieldVectorsWriter> fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class);
MockedStatic<KNNVectorValuesFactory> knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class);
MockedStatic<QuantizationService> quantizationServiceMockedStatic = mockStatic(QuantizationService.class);
MockedStatic<NativeIndexWriter> nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class);
MockedConstruction<KNN990QuantizationStateWriter> knn990QuantWriterMockedConstruction = mockConstruction(
KNN990QuantizationStateWriter.class
);
) {
quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService);

IntStream.range(0, vectorsPerField.size()).forEach(i -> {
final FieldInfo fieldInfo = fieldInfo(
i,
VectorEncoding.FLOAT32,
Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss")
);

NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i));
fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream))
.thenReturn(field);

try {
nativeEngineWriter.addField(fieldInfo);
} catch (Exception e) {
throw new RuntimeException(e);
}

DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
knnVectorValuesFactoryMockedStatic.when(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
).thenReturn(expectedVectorValues.get(i));

when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams);
try {
when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size()))
.thenReturn(quantizationState);
} catch (Exception e) {
throw new RuntimeException(e);
}

nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState))
.thenReturn(nativeIndexWriter);
});
doAnswer(answer -> {
Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion
return null;
}).when(nativeIndexWriter).flushIndex(any(), anyInt());

// When
nativeEngineWriter.flush(5, null);

// Then
verify(flatVectorsWriter).flush(5, null);
if (vectorsPerField.size() > 0) {
verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState);
} else {
assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size());
}
verifyNoInteractions(nativeIndexWriter);
IntStream.range(0, vectorsPerField.size()).forEach(i -> {
try {
if (vectorsPerField.get(i).isEmpty()) {
verify(knn990QuantWriterMockedConstruction.constructed().get(0), never()).writeState(i, quantizationState);
} else {
verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState);
}
} catch (Exception e) {
throw new RuntimeException(e);
}
});
final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count();
knnVectorValuesFactoryMockedStatic.verify(
() -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()),
times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled))
);
}
}

private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map<String, String> attributes) {
FieldInfo fieldInfo = mock(FieldInfo.class);
when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber);
Expand Down

0 comments on commit 87d95b3

Please sign in to comment.