From 87d95b3de03fe5fd2efae4eec76a7c7920af17aa Mon Sep 17 00:00:00 2001 From: Vijayan Balasubramanian Date: Tue, 1 Oct 2024 16:11:33 -0700 Subject: [PATCH] Allow build graph greedily for quantization scenarios Previosuly we only added support to build greedily for non quantization scenario. In this commit, we can remove that constraint, however, we cannot skip writing quanitization state since it is required irrespective of type of search is executed later. Signed-off-by: Vijayan Balasubramanian --- .../NativeEngines990KnnVectorsWriter.java | 10 +- ...eEngines990KnnVectorsWriterFlushTests.java | 206 +++++++++++++++++- 2 files changed, 211 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 1251b9763..b1d6865e1 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -100,8 +100,9 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { field.getVectors() ); final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs); - // Will consider building vector data structure based on threshold only for non quantization indices - if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) { + // Check only after quantization state writer finish writing its state, since it is required + // even if there are no graph files in segment, which will be later used by exact search + if (shouldSkipBuildingVectorDataStructure(totalLiveDocs)) { log.info( "Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during flush", fieldInfo.name, @@ -139,8 +140,9 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState } final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs); - // Will consider building vector data structure based on threshold only for non quantization indices - if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) { + // Check only after quantization state writer finish writing its state, since it is required + // even if there are no graph files in segment, which will be later used by exact search + if (shouldSkipBuildingVectorDataStructure(totalLiveDocs)) { log.info( "Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during merge", fieldInfo.name, diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java index 659c980ff..57366753e 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -602,7 +602,6 @@ public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWr verify(flatVectorsWriter).flush(5, null); if (vectorsPerField.size() > 0) { assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); - assertTrue((long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() > 0); } IntStream.range(0, vectorsPerField.size()).forEach(i -> { try { @@ -618,6 +617,211 @@ public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWr } } + public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThresholdIsNotMet_thenSkipBuildingGraph() + throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + final Map sizeMap = new HashMap<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + sizeMap.put(i, randomVectorValues.size()); + expectedVectorValues.add(knnVectorValues); + + }); + final int maxThreshold = sizeMap.values().stream().filter(count -> count != 0).max(Integer::compareTo).orElse(0); + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + maxThreshold + 1 // to avoid building graph using max doc threshold, the same can be achieved by -1 too + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); + try { + when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) + .thenReturn(quantizationState); + } catch (Exception e) { + throw new RuntimeException(e); + } + + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) + .thenReturn(nativeIndexWriter); + }); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState); + } else { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + verifyNoInteractions(nativeIndexWriter); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + if (vectorsPerField.get(i).isEmpty()) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0), never()).writeState(i, quantizationState); + } else { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), + times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled)) + ); + } + } + + public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThresholdIsNegative_thenSkipBuildingGraph() + throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + final Map sizeMap = new HashMap<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + sizeMap.put(i, randomVectorValues.size()); + expectedVectorValues.add(knnVectorValues); + + }); + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + BUILD_GRAPH_NEVER_THRESHOLD + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); + try { + when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) + .thenReturn(quantizationState); + } catch (Exception e) { + throw new RuntimeException(e); + } + + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) + .thenReturn(nativeIndexWriter); + }); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState); + } else { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + verifyNoInteractions(nativeIndexWriter); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + if (vectorsPerField.get(i).isEmpty()) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0), never()).writeState(i, quantizationState); + } else { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), + times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled)) + ); + } + } + private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map attributes) { FieldInfo fieldInfo = mock(FieldInfo.class); when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber);