From af96fed3094c7347c2bd81b161c664e3c4d524d6 Mon Sep 17 00:00:00 2001 From: Vijayan Balasubramanian Date: Tue, 1 Oct 2024 16:11:33 -0700 Subject: [PATCH] Allow build graph greedily for quantization scenarios Previously we only added support to build greedily for the non-quantization scenario. In this commit we remove that constraint; however, we cannot skip writing the quantization state, since it is required irrespective of which type of search is executed later. Signed-off-by: Vijayan Balasubramanian --- .../NativeEngines990KnnVectorsWriter.java | 10 +- .../org/opensearch/knn/index/FaissIT.java | 118 ++++++++++ ...eEngines990KnnVectorsWriterFlushTests.java | 206 +++++++++++++++++- 3 files changed, 329 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java index 1251b9763..b1d6865e1 100644 --- a/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriter.java @@ -100,8 +100,9 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException { field.getVectors() ); final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs); - // Will consider building vector data structure based on threshold only for non quantization indices - if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) { + // Check only after quantization state writer finish writing its state, since it is required + // even if there are no graph files in segment, which will be later used by exact search + if (shouldSkipBuildingVectorDataStructure(totalLiveDocs)) { log.info( "Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during flush", fieldInfo.name, @@ 
-139,8 +140,9 @@ public void mergeOneField(final FieldInfo fieldInfo, final MergeState mergeState } final QuantizationState quantizationState = train(fieldInfo, knnVectorValuesSupplier, totalLiveDocs); - // Will consider building vector data structure based on threshold only for non quantization indices - if (quantizationState == null && shouldSkipBuildingVectorDataStructure(totalLiveDocs)) { + // Check only after quantization state writer finish writing its state, since it is required + // even if there are no graph files in segment, which will be later used by exact search + if (shouldSkipBuildingVectorDataStructure(totalLiveDocs)) { log.info( "Skip building vector data structure for field: {}, as liveDoc: {} is less than the threshold {} during merge", fieldInfo.name, diff --git a/src/test/java/org/opensearch/knn/index/FaissIT.java b/src/test/java/org/opensearch/knn/index/FaissIT.java index 2df1d8a60..eec520a63 100644 --- a/src/test/java/org/opensearch/knn/index/FaissIT.java +++ b/src/test/java/org/opensearch/knn/index/FaissIT.java @@ -575,6 +575,124 @@ public void testHNSWSQFP16_whenIndexedAndQueried_thenSucceed() { validateGraphEviction(); } + @SneakyThrows + public void testHNSWSQFP16_whenGraphThresholdIsNegative_whenIndexed_thenSkipCreatingGraph() { + final String indexName = "test-index-hnsw-sqfp16"; + final String fieldName = "test-field-hnsw-sqfp16"; + final SpaceType[] spaceTypes = { SpaceType.L2, SpaceType.INNER_PRODUCT }; + final Random random = new Random(); + final SpaceType spaceType = spaceTypes[random.nextInt(spaceTypes.length)]; + + final int dimension = 128; + final int numDocs = 100; + + // Create an index + final XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, 
KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_SQ) + .startObject(PARAMETERS) + .field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + final Map mappingMap = xContentBuilderToMap(builder); + final String mapping = builder.toString(); + final Settings knnIndexSettings = buildKNNIndexSettings(-1); + createKnnIndex(indexName, knnIndexSettings, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); + indexTestData(indexName, fieldName, dimension, numDocs); + + final float[] queryVector = new float[dimension]; + Arrays.fill(queryVector, (float) numDocs); + + // Assert we have the right number of documents in the index + assertEquals(numDocs, getDocCount(indexName)); + // KNN Query should return empty result + final Response searchResponse = searchKNNIndex(indexName, buildSearchQuery(fieldName, 1, queryVector, null), 1); + final List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName); + assertEquals(0, results.size()); + + deleteKNNIndex(indexName); + validateGraphEviction(); + } + + @SneakyThrows + public void testHNSWSQFP16_whenGraphThresholdIsMetDuringMerge_thenCreateGraph() { + final String indexName = "test-index-hnsw-sqfp16"; + final String fieldName = "test-field-hnsw-sqfp16"; + final SpaceType[] spaceTypes = { SpaceType.L2, SpaceType.INNER_PRODUCT }; + final Random random = new Random(); + final SpaceType spaceType = spaceTypes[random.nextInt(spaceTypes.length)]; + final int dimension = 128; + final int numDocs = 100; + + // Create an index + final XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(fieldName) + .field("type", "knn_vector") + .field("dimension", dimension) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + 
.field(METHOD_PARAMETER_SPACE_TYPE, spaceType.getValue()) + .field(KNN_ENGINE, KNNEngine.FAISS.getName()) + .startObject(PARAMETERS) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, ENCODER_SQ) + .startObject(PARAMETERS) + .field(FAISS_SQ_TYPE, FAISS_SQ_ENCODER_FP16) + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject() + .endObject(); + + final Map mappingMap = xContentBuilderToMap(builder); + final String mapping = builder.toString(); + final Settings knnIndexSettings = buildKNNIndexSettings(numDocs); + createKnnIndex(indexName, knnIndexSettings, mapping); + assertEquals(new TreeMap<>(mappingMap), new TreeMap<>(getIndexMappingAsMap(indexName))); + indexTestData(indexName, fieldName, dimension, numDocs); + + final float[] queryVector = new float[dimension]; + Arrays.fill(queryVector, (float) numDocs); + + // Assert we have the right number of documents in the index + assertEquals(numDocs, getDocCount(indexName)); + + // KNN Query should return empty result + final Response searchResponse = searchKNNIndex(indexName, buildSearchQuery(fieldName, 1, queryVector, null), 1); + final List results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName); + assertEquals(0, results.size()); + + // update index setting to build graph and do force merge + // update build vector data structure setting + forceMergeKnnIndex(indexName, 1); + + queryTestData(indexName, fieldName, dimension, numDocs); + + deleteKNNIndex(indexName); + validateGraphEviction(); + } + @SneakyThrows public void testIVFSQFP16_whenIndexedAndQueried_thenSucceed() { diff --git a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java index 659c980ff..57366753e 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java +++ 
b/src/test/java/org/opensearch/knn/index/codec/KNN990Codec/NativeEngines990KnnVectorsWriterFlushTests.java @@ -602,7 +602,6 @@ public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWr verify(flatVectorsWriter).flush(5, null); if (vectorsPerField.size() > 0) { assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); - assertTrue((long) KNNGraphValue.REFRESH_TOTAL_TIME_IN_MILLIS.getValue() > 0); } IntStream.range(0, vectorsPerField.size()).forEach(i -> { try { @@ -618,6 +617,211 @@ public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWr } } + public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThresholdIsNotMet_thenSkipBuildingGraph() + throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + final Map sizeMap = new HashMap<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + sizeMap.put(i, randomVectorValues.size()); + expectedVectorValues.add(knnVectorValues); + + }); + final int maxThreshold = sizeMap.values().stream().filter(count -> count != 0).max(Integer::compareTo).orElse(0); + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + maxThreshold + 1 // to avoid building graph using max doc threshold, the same can be achieved by -1 too + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic 
nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); + try { + when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) + .thenReturn(quantizationState); + } catch (Exception e) { + throw new RuntimeException(e); + } + + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) + .thenReturn(nativeIndexWriter); + }); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + }).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + 
verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState); + } else { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + verifyNoInteractions(nativeIndexWriter); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + if (vectorsPerField.get(i).isEmpty()) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0), never()).writeState(i, quantizationState); + } else { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), + times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled)) + ); + } + } + + public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThresholdIsNegative_thenSkipBuildingGraph() + throws IOException { + // Given + List> expectedVectorValues = new ArrayList<>(); + final Map sizeMap = new HashMap<>(); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final TestVectorValues.PreDefinedFloatVectorValues randomVectorValues = new TestVectorValues.PreDefinedFloatVectorValues( + new ArrayList<>(vectorsPerField.get(i).values()) + ); + final KNNVectorValues knnVectorValues = KNNVectorValuesFactory.getVectorValues( + VectorDataType.FLOAT, + randomVectorValues + ); + sizeMap.put(i, randomVectorValues.size()); + expectedVectorValues.add(knnVectorValues); + + }); + final NativeEngines990KnnVectorsWriter nativeEngineWriter = new NativeEngines990KnnVectorsWriter( + segmentWriteState, + flatVectorsWriter, + BUILD_GRAPH_NEVER_THRESHOLD + ); + + try ( + MockedStatic fieldWriterMockedStatic = mockStatic(NativeEngineFieldVectorsWriter.class); + MockedStatic 
knnVectorValuesFactoryMockedStatic = mockStatic(KNNVectorValuesFactory.class); + MockedStatic quantizationServiceMockedStatic = mockStatic(QuantizationService.class); + MockedStatic nativeIndexWriterMockedStatic = mockStatic(NativeIndexWriter.class); + MockedConstruction knn990QuantWriterMockedConstruction = mockConstruction( + KNN990QuantizationStateWriter.class + ); + ) { + quantizationServiceMockedStatic.when(() -> QuantizationService.getInstance()).thenReturn(quantizationService); + + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + final FieldInfo fieldInfo = fieldInfo( + i, + VectorEncoding.FLOAT32, + Map.of(KNNConstants.VECTOR_DATA_TYPE_FIELD, "float", KNNConstants.KNN_ENGINE, "faiss") + ); + + NativeEngineFieldVectorsWriter field = nativeEngineFieldVectorsWriter(fieldInfo, vectorsPerField.get(i)); + fieldWriterMockedStatic.when(() -> NativeEngineFieldVectorsWriter.create(fieldInfo, segmentWriteState.infoStream)) + .thenReturn(field); + + try { + nativeEngineWriter.addField(fieldInfo); + } catch (Exception e) { + throw new RuntimeException(e); + } + + DocsWithFieldSet docsWithFieldSet = field.getDocsWithField(); + knnVectorValuesFactoryMockedStatic.when( + () -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i)) + ).thenReturn(expectedVectorValues.get(i)); + + when(quantizationService.getQuantizationParams(fieldInfo)).thenReturn(quantizationParams); + try { + when(quantizationService.train(quantizationParams, expectedVectorValues.get(i), vectorsPerField.get(i).size())) + .thenReturn(quantizationState); + } catch (Exception e) { + throw new RuntimeException(e); + } + + nativeIndexWriterMockedStatic.when(() -> NativeIndexWriter.getWriter(fieldInfo, segmentWriteState, quantizationState)) + .thenReturn(nativeIndexWriter); + }); + doAnswer(answer -> { + Thread.sleep(2); // Need this for KNNGraph value assertion, removing this will fail the assertion + return null; + 
}).when(nativeIndexWriter).flushIndex(any(), anyInt()); + + // When + nativeEngineWriter.flush(5, null); + + // Then + verify(flatVectorsWriter).flush(5, null); + if (vectorsPerField.size() > 0) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeHeader(segmentWriteState); + } else { + assertEquals(0, knn990QuantWriterMockedConstruction.constructed().size()); + } + verifyNoInteractions(nativeIndexWriter); + IntStream.range(0, vectorsPerField.size()).forEach(i -> { + try { + if (vectorsPerField.get(i).isEmpty()) { + verify(knn990QuantWriterMockedConstruction.constructed().get(0), never()).writeState(i, quantizationState); + } else { + verify(knn990QuantWriterMockedConstruction.constructed().get(0)).writeState(i, quantizationState); + } + } catch (Exception e) { + throw new RuntimeException(e); + } + }); + final Long expectedTimesGetVectorValuesIsCalled = vectorsPerField.stream().filter(Predicate.not(Map::isEmpty)).count(); + knnVectorValuesFactoryMockedStatic.verify( + () -> KNNVectorValuesFactory.getVectorValues(any(VectorDataType.class), any(DocsWithFieldSet.class), any()), + times(Math.toIntExact(expectedTimesGetVectorValuesIsCalled)) + ); + } + } + private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map attributes) { FieldInfo fieldInfo = mock(FieldInfo.class); when(fieldInfo.getFieldNumber()).thenReturn(fieldNumber);