Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport Main] Added more detailed error messages for KNN model training #2440

6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Introduced a writing layer in native engines which relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290)[https://github.com/opensearch-project/k-NN/pull/2290]
- Optimizes lucene query execution to prevent unnecessary rewrites (#2305)[https://github.com/opensearch-project/k-NN/pull/2305]
- Added more detailed error messages for KNN model training (#2378)[https://github.com/opensearch-project/k-NN/pull/2378]
- Add check to directly use ANN Search when filters match all docs. (#2320)[https://github.com/opensearch-project/k-NN/pull/2320]
- Use one formula to calculate cosine similarity (#2357)[https://github.com/opensearch-project/k-NN/pull/2357]
- Make the build work for M series MacOS without manual code changes and local JAVA_HOME config (#2397)[https://github.com/opensearch-project/k-NN/pull/2397]
- Remove DocsWithFieldSet reference from NativeEngineFieldVectorsWriter (#2408)[https://github.com/opensearch-project/k-NN/pull/2408]
- Remove skip building graph check for quantization use case (#2430)[https://github.com/opensearch-project/k-NN/pull/2430]
- Add check to directly use ANN Search when filters match all docs. (#2320)[https://github.com/opensearch-project/k-NN/pull/2320]
- Use one formula to calculate cosine similarity (#2357)[https://github.com/opensearch-project/k-NN/pull/2357]
- Add WithFieldName implementation to KNNQueryBuilder (#2398)[https://github.com/opensearch-project/k-NN/pull/2398]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ public void testIVFSQFP16_onUpgradeWhenIndexedAndQueried_thenSucceed() throws Ex

// Add training data
createBasicKnnIndex(TRAIN_INDEX, TRAIN_TEST_FIELD, DIMENSION);
int trainingDataCount = 200;
int trainingDataCount = 1100;
bulkIngestRandomVectors(TRAIN_INDEX, TRAIN_TEST_FIELD, trainingDataCount, DIMENSION);

XContentBuilder builder = XContentFactory.jsonBuilder()
Expand Down Expand Up @@ -279,7 +279,7 @@ public void testIVFSQFP16_onUpgradeWhenClipToFp16isTrueAndIndexedWithOutOfFP16Ra

// Add training data
createBasicKnnIndex(TRAIN_INDEX, TRAIN_TEST_FIELD, dimension);
int trainingDataCount = 200;
int trainingDataCount = 1100;
bulkIngestRandomVectors(TRAIN_INDEX, TRAIN_TEST_FIELD, trainingDataCount, dimension);

XContentBuilder builder = XContentFactory.jsonBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public class ModelIT extends AbstractRestartUpgradeTestCase {
private static final int DELAY_MILLI_SEC = 1000;
private static final int MIN_NUM_OF_MODELS = 2;
private static final int K = 5;
private static final int NUM_DOCS = 10;
private static final int NUM_DOCS = 1100;
private static final int NUM_DOCS_TEST_MODEL_INDEX = 100;
private static final int NUM_DOCS_TEST_MODEL_INDEX_DEFAULT = 100;
private static final int NUM_DOCS_TEST_MODEL_INDEX_FOR_NON_KNN_INDEX = 100;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;

/**
* Abstract class for KNN methods. This class provides the common functionality for all KNN methods.
Expand Down Expand Up @@ -108,6 +114,55 @@ protected PerDimensionProcessor doGetPerDimensionProcessor(
return PerDimensionProcessor.NOOP_PROCESSOR;
}

/**
 * Builds the validation function run before model training starts.
 * <p>
 * The returned function performs two independent checks, short-circuiting on the first failure:
 * <ol>
 *   <li>When a PQ encoder is configured, {@code ENCODER_PARAMETER_PQ_M} must evenly divide the
 *       vector dimension.</li>
 *   <li>The number of available training vectors must meet the minimum clustering criteria
 *       defined in faiss (max of {@code nlist} and {@code 2^code_size}, defaulting to 1000).</li>
 * </ol>
 * Inputs on {@link TrainingConfigValidationInput} may be unset (null); each check is skipped
 * when the data it needs is absent.
 *
 * @return function mapping training validation input to a validation result; on failure the
 *         result carries {@code valid=false} and, for the vector-count check, the required
 *         minimum training vector count
 */
protected Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> doGetTrainingConfigValidationSetup() {
    return (trainingConfigValidationInput) -> {

        KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
        KNNMethodConfigContext knnMethodConfigContext = trainingConfigValidationInput.getKnnMethodConfigContext();
        Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount();

        TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();

        // validate ENCODER_PARAMETER_PQ_M evenly divides the vector dimension
        if (knnMethodContext != null && knnMethodConfigContext != null) {
            Map<String, Object> methodParameters = knnMethodContext.getMethodComponentContext().getParameters();
            if (methodParameters.containsKey(ENCODER_PARAMETER_PQ_M)
                && knnMethodConfigContext.getDimension() % (Integer) methodParameters.get(ENCODER_PARAMETER_PQ_M) != 0) {
                builder.valid(false);
                return builder.build();
            } else {
                builder.valid(true);
            }
        }

        // validate number of training points should be greater than minimum clustering criteria defined in faiss
        if (knnMethodContext != null && trainingVectors != null) {
            long minTrainingVectorCount = 1000;

            MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext()
                .getParameters()
                .get(METHOD_ENCODER_PARAMETER);

            // Guard against a missing encoder: without the null check, a request that sets nlist
            // but no encoder parameter would throw an NPE instead of falling back to the default
            // minimum training vector count.
            if (encoderContext != null
                && knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_PARAMETER_NLIST)
                && encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) {

                int nlist = ((Integer) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_PARAMETER_NLIST));
                int codeSize = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE));
                minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, codeSize));
            }

            if (trainingVectors < minTrainingVectorCount) {
                builder.valid(false).minTrainingVectorCount(minTrainingVectorCount);
                return builder.build();
            } else {
                builder.valid(true);
            }
        }
        return builder.build();
    };
}

/**
 * Returns the transformer applied to vectors before indexing.
 * The default implementation performs no transformation regardless of space type;
 * engine-specific subclasses may override to apply one.
 *
 * @param spaceType space type of the index (unused by the default implementation)
 * @return the no-op vector transformer
 */
protected VectorTransformer getVectorTransformer(SpaceType spaceType) {
    final VectorTransformer defaultTransformer = VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER;
    return defaultTransformer;
}
Expand All @@ -131,6 +186,7 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
.perDimensionValidator(doGetPerDimensionValidator(knnMethodContext, knnMethodConfigContext))
.perDimensionProcessor(doGetPerDimensionProcessor(knnMethodContext, knnMethodConfigContext))
.vectorTransformer(getVectorTransformer(knnMethodContext.getSpaceType()))
.trainingConfigValidationSetup(doGetTrainingConfigValidationSetup())
.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import org.opensearch.knn.index.mapper.VectorValidator;

import java.util.Map;
import java.util.function.Function;

/**
* Context a library gives to build one of its indices
Expand Down Expand Up @@ -49,6 +50,12 @@ public interface KNNLibraryIndexingContext {
*/
PerDimensionProcessor getPerDimensionProcessor();

/**
*
* @return Get function that validates training model parameters
*/
Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> getTrainingConfigValidationSetup();

/**
* Get the vector transformer that will be used to transform the vector before indexing.
* This will be applied at vector level once entire vector is parsed and validated.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import java.util.Collections;
import java.util.Map;
import java.util.function.Function;

/**
* Simple implementation of {@link KNNLibraryIndexingContext}
Expand All @@ -29,6 +30,7 @@ public class KNNLibraryIndexingContextImpl implements KNNLibraryIndexingContext
private Map<String, Object> parameters = Collections.emptyMap();
@Builder.Default
private QuantizationConfig quantizationConfig = QuantizationConfig.EMPTY;
private Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> trainingConfigValidationSetup;

@Override
public Map<String, Object> getLibraryParameters() {
Expand Down Expand Up @@ -59,4 +61,9 @@ public PerDimensionValidator getPerDimensionValidator() {
public PerDimensionProcessor getPerDimensionProcessor() {
    // Plain accessor for the processor configured on this context instance.
    return this.perDimensionProcessor;
}

@Override
public Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> getTrainingConfigValidationSetup() {
    // Hand back the validation function supplied at build time; no @Builder.Default is declared,
    // so this may be null when the builder was not given one — callers must handle that.
    final Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> setup = this.trainingConfigValidationSetup;
    return setup;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.Setter;

/**
* This object provides the input of the validation checks for training model inputs.
* The values in this object need to be dynamically set and calling code needs to handle
* the possibility that the values have not been set.
*/
@Setter
@Getter
@Builder
@AllArgsConstructor
public class TrainingConfigValidationInput {
    // Number of vectors available for training; null when the caller has not computed it
    private Long trainingVectorsCount;
    // Method definition under validation; null when not supplied by the caller
    private KNNMethodContext knnMethodContext;
    // Resolved method configuration (e.g. dimension); null when not supplied by the caller
    private KNNMethodConfigContext knnMethodConfigContext;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.Setter;

/**
* This object provides the output of the validation checks for training model inputs.
* The values in this object need to be dynamically set and calling code needs to handle
* the possibility that the values have not been set.
*/
@Setter
@Getter
@Builder
@AllArgsConstructor
public class TrainingConfigValidationOutput {
    // Whether the training configuration passed the validation checks that were run
    private boolean valid;
    // Minimum required training vector count; only populated when the vector-count check fails
    private long minTrainingVectorCount;
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,18 @@
import org.opensearch.common.ValidationException;
import org.opensearch.common.inject.Inject;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportRequestOptions;
import org.opensearch.transport.TransportService;

import java.util.Map;
import java.util.function.Function;

import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;
import static org.opensearch.search.internal.SearchContext.DEFAULT_TERMINATE_AFTER;
Expand Down Expand Up @@ -134,6 +140,29 @@ protected void getTrainingIndexSizeInKB(TrainingModelRequest trainingModelReques
trainingVectors = trainingModelRequest.getMaximumVectorCount();
}

KNNMethodContext knnMethodContext = trainingModelRequest.getKnnMethodContext();
KNNMethodConfigContext knnMethodConfigContext = trainingModelRequest.getKnnMethodConfigContext();

KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine()
.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);

Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> validateTrainingConfig = knnLibraryIndexingContext
.getTrainingConfigValidationSetup();

TrainingConfigValidationInput.TrainingConfigValidationInputBuilder inputBuilder = TrainingConfigValidationInput.builder();

TrainingConfigValidationOutput validation = validateTrainingConfig.apply(
inputBuilder.trainingVectorsCount(trainingVectors).knnMethodContext(knnMethodContext).build()
);
if (!validation.isValid()) {
ValidationException exception = new ValidationException();
exception.addValidationError(
String.format("Number of training points should be greater than %d", validation.getMinTrainingVectorCount())
);
listener.onFailure(exception);
return;
}

listener.onResponse(
estimateVectorSetSizeInKB(trainingVectors, trainingModelRequest.getDimension(), trainingModelRequest.getVectorDataType())
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,14 @@
import org.opensearch.knn.index.engine.EngineResolver;
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.indices.ModelDao;

import java.io.IOException;
import java.util.function.Function;

/**
* Request to train and serialize a model
Expand Down Expand Up @@ -283,6 +287,21 @@ public ActionRequestValidationException validate() {
exception.addValidationError("Description exceeds limit of " + KNNConstants.MAX_MODEL_DESCRIPTION_LENGTH + " characters");
}

KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine()
.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> validateTrainingConfig = knnLibraryIndexingContext
.getTrainingConfigValidationSetup();
TrainingConfigValidationInput.TrainingConfigValidationInputBuilder inputBuilder = TrainingConfigValidationInput.builder();
TrainingConfigValidationOutput validation = validateTrainingConfig.apply(
inputBuilder.knnMethodConfigContext(knnMethodConfigContext).knnMethodContext(knnMethodContext).build()
);

// Check if ENCODER_PARAMETER_PQ_M is divisible by vector dimension
if (!validation.isValid()) {
exception = exception == null ? new ActionRequestValidationException() : exception;
exception.addValidationError("Training request ENCODER_PARAMETER_PQ_M is not divisible by vector dimensions");
}

// Validate training index exists
IndexMetadata indexMetadata = clusterService.state().metadata().index(trainingIndex);
if (indexMetadata == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,7 @@ public void run() {
} catch (Exception e) {
logger.error("Failed to run training job for model \"" + modelId + "\": ", e);
modelMetadata.setState(ModelState.FAILED);
modelMetadata.setError(
"Failed to execute training. May be caused by an invalid method definition or " + "not enough memory to perform training."
);
modelMetadata.setError("Failed to execute training. " + e.getMessage());

KNNCounter.TRAINING_ERRORS.increment();

Expand Down
16 changes: 8 additions & 8 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHN

// training data needs to be at least equal to the number of centroids for PQ
// which is 2^8 = 256. 8 because that's the only valid code_size for HNSWPQ
int trainingDataCount = 256;
int trainingDataCount = 1100;

SpaceType spaceType = SpaceType.L2;

Expand Down Expand Up @@ -468,7 +468,7 @@ public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() {

// training data needs to be at least equal to the number of centroids for PQ
// which is 2^8 = 256. 8 because that's the only valid code_size for HNSWPQ
int trainingDataCount = 256;
int trainingDataCount = 1100;

SpaceType spaceType = SpaceType.L2;

Expand Down Expand Up @@ -736,7 +736,7 @@ public void testIVFSQFP16_whenIndexedAndQueried_thenSucceed() {

// Add training data
createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
int trainingDataCount = 200;
int trainingDataCount = 1100;
bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);

XContentBuilder builder = XContentFactory.jsonBuilder()
Expand Down Expand Up @@ -960,7 +960,7 @@ public void testIVFSQFP16_whenIndexedWithOutOfFP16Range_thenThrowException() {

// Add training data
createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
int trainingDataCount = 200;
int trainingDataCount = 1100;
bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);

XContentBuilder builder = XContentFactory.jsonBuilder()
Expand Down Expand Up @@ -1064,7 +1064,7 @@ public void testIVFSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_thenS

// Add training data
createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
int trainingDataCount = 200;
int trainingDataCount = 1100;
bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);

XContentBuilder builder = XContentFactory.jsonBuilder()
Expand Down Expand Up @@ -1144,7 +1144,7 @@ public void testEndToEnd_whenMethodIsHNSWPQAndHyperParametersNotSet_thenSucceed(

// training data needs to be at least equal to the number of centroids for PQ
// which is 2^8 = 256. 8 because that's the only valid code_size for HNSWPQ
int trainingDataCount = 256;
int trainingDataCount = 1100;

SpaceType spaceType = SpaceType.L2;

Expand Down Expand Up @@ -1414,7 +1414,7 @@ public void testKNNQuery_withModelDifferentCombination_thenSuccess() {

// Add training data
createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
int trainingDataCount = 200;
int trainingDataCount = 1100;
bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);

// Call train API - IVF with nlists = 1 is brute force, but will require training
Expand Down Expand Up @@ -1769,7 +1769,7 @@ public void testIVF_whenBinaryFormat_whenIVF_thenSuccess() {

createKnnIndex(trainingIndexName, trainIndexMapping);

int trainingDataCount = 40;
int trainingDataCount = 1100;
bulkIngestRandomBinaryVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);

XContentBuilder trainModelXContentBuilder = XContentFactory.jsonBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -603,7 +603,7 @@ public void testIVFByteVector_whenIndexedAndQueried_thenSucceed() {
.toString();
createKnnIndex(INDEX_NAME, trainIndexMapping);

int trainingDataCount = 100;
int trainingDataCount = 1100;
bulkIngestRandomByteVectors(INDEX_NAME, FIELD_NAME, trainingDataCount, dimension);

XContentBuilder trainModelXContentBuilder = XContentFactory.jsonBuilder()
Expand Down
Loading
Loading