
Commit

Added more detailed error messages for KNN model training (#2378)
* Add more detailed error messages for KNN model training

Signed-off-by: AnnTian Shao <[email protected]>

* Add validation check for training parameters in engine method abstraction

Signed-off-by: AnnTian Shao <[email protected]>

* Fixes for bwc and IT tests

Signed-off-by: AnnTian Shao <[email protected]>

---------

Signed-off-by: AnnTian Shao <[email protected]>
Co-authored-by: AnnTian Shao <[email protected]>
(cherry picked from commit 4058a53)
anntians authored and AnnTian Shao committed Jan 25, 2025
1 parent d142366 commit 1982169
Showing 20 changed files with 385 additions and 46 deletions.
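In short, the validation added to the engine method abstraction below rejects a training request up front when the vector dimension is not divisible by the PQ m parameter, or when the training index contains fewer vectors than the faiss clustering minimum of max(nlist, 2^code_size), for example max(128, 2^8) = 256, with a default floor of 1000 vectors when those parameters are not both set. That floor is also why the integration tests below raise their training data counts to 1100.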
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -26,6 +26,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Introduced a writing layer in native engines which relies on the writing interface to process IO. (#2241)[https://github.com/opensearch-project/k-NN/pull/2241]
- Allow method parameter override for training based indices (#2290)[https://github.com/opensearch-project/k-NN/pull/2290]
- Optimizes lucene query execution to prevent unnecessary rewrites (#2305)[https://github.com/opensearch-project/k-NN/pull/2305]
- Added more detailed error messages for KNN model training (#2378)[https://github.com/opensearch-project/k-NN/pull/2378]
- Add check to directly use ANN Search when filters match all docs. (#2320)[https://github.com/opensearch-project/k-NN/pull/2320]
- Use one formula to calculate cosine similarity (#2357)[https://github.com/opensearch-project/k-NN/pull/2357]
- Add WithFieldName implementation to KNNQueryBuilder (#2398)[https://github.com/opensearch-project/k-NN/pull/2398]
@@ -226,7 +226,7 @@ public void testIVFSQFP16_onUpgradeWhenIndexedAndQueried_thenSucceed() throws Ex

// Add training data
createBasicKnnIndex(TRAIN_INDEX, TRAIN_TEST_FIELD, DIMENSION);
int trainingDataCount = 200;
int trainingDataCount = 1100;
bulkIngestRandomVectors(TRAIN_INDEX, TRAIN_TEST_FIELD, trainingDataCount, DIMENSION);

XContentBuilder builder = XContentFactory.jsonBuilder()
@@ -279,7 +279,7 @@ public void testIVFSQFP16_onUpgradeWhenClipToFp16isTrueAndIndexedWithOutOfFP16Ra

// Add training data
createBasicKnnIndex(TRAIN_INDEX, TRAIN_TEST_FIELD, dimension);
int trainingDataCount = 200;
int trainingDataCount = 1100;
bulkIngestRandomVectors(TRAIN_INDEX, TRAIN_TEST_FIELD, trainingDataCount, dimension);

XContentBuilder builder = XContentFactory.jsonBuilder()
@@ -22,6 +22,12 @@
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;

import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M;
import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER;
import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST;
import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE;

/**
* Abstract class for KNN methods. This class provides the common functionality for all KNN methods.
@@ -108,6 +114,55 @@ protected PerDimensionProcessor doGetPerDimensionProcessor(
return PerDimensionProcessor.NOOP_PROCESSOR;
}

protected Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> doGetTrainingConfigValidationSetup() {
return (trainingConfigValidationInput) -> {

KNNMethodContext knnMethodContext = trainingConfigValidationInput.getKnnMethodContext();
KNNMethodConfigContext knnMethodConfigContext = trainingConfigValidationInput.getKnnMethodConfigContext();
Long trainingVectors = trainingConfigValidationInput.getTrainingVectorsCount();

TrainingConfigValidationOutput.TrainingConfigValidationOutputBuilder builder = TrainingConfigValidationOutput.builder();

// validate that the vector dimension is divisible by ENCODER_PARAMETER_PQ_M
if (knnMethodContext != null && knnMethodConfigContext != null) {
if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(ENCODER_PARAMETER_PQ_M)
&& knnMethodConfigContext.getDimension() % (Integer) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(ENCODER_PARAMETER_PQ_M) != 0) {
builder.valid(false);
return builder.build();
} else {
builder.valid(true);
}
}

// validate that the number of training points meets the minimum clustering criteria defined in faiss
if (knnMethodContext != null && trainingVectors != null) {
long minTrainingVectorCount = 1000;

MethodComponentContext encoderContext = (MethodComponentContext) knnMethodContext.getMethodComponentContext()
.getParameters()
.get(METHOD_ENCODER_PARAMETER);

if (knnMethodContext.getMethodComponentContext().getParameters().containsKey(METHOD_PARAMETER_NLIST)
&& encoderContext != null
&& encoderContext.getParameters().containsKey(ENCODER_PARAMETER_PQ_CODE_SIZE)) {

int nlist = ((Integer) knnMethodContext.getMethodComponentContext().getParameters().get(METHOD_PARAMETER_NLIST));
int codeSize = ((Integer) encoderContext.getParameters().get(ENCODER_PARAMETER_PQ_CODE_SIZE));
minTrainingVectorCount = (long) Math.max(nlist, Math.pow(2, codeSize));
}
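// For example, nlist = 4 with code_size = 8 gives a minimum of max(4, 2^8) = 256 training vectors;
// when nlist or code_size is not specified, the default floor of 1000 above applies.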

if (trainingVectors < minTrainingVectorCount) {
builder.valid(false).minTrainingVectorCount(minTrainingVectorCount);
return builder.build();
} else {
builder.valid(true);
}
}
return builder.build();
};
}

protected VectorTransformer getVectorTransformer(SpaceType spaceType) {
return VectorTransformerFactory.NOOP_VECTOR_TRANSFORMER;
}
@@ -131,6 +186,7 @@ public KNNLibraryIndexingContext getKNNLibraryIndexingContext(
.perDimensionValidator(doGetPerDimensionValidator(knnMethodContext, knnMethodConfigContext))
.perDimensionProcessor(doGetPerDimensionProcessor(knnMethodContext, knnMethodConfigContext))
.vectorTransformer(getVectorTransformer(knnMethodContext.getSpaceType()))
.trainingConfigValidationSetup(doGetTrainingConfigValidationSetup())
.build();
}

@@ -12,6 +12,7 @@
import org.opensearch.knn.index.mapper.VectorValidator;

import java.util.Map;
import java.util.function.Function;

/**
* Context a library gives to build one of its indices
@@ -49,6 +50,12 @@ public interface KNNLibraryIndexingContext {
*/
PerDimensionProcessor getPerDimensionProcessor();

/**
* @return function that validates training model parameters
*/
Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> getTrainingConfigValidationSetup();

/**
* Get the vector transformer that will be used to transform the vector before indexing.
* This will be applied at vector level once entire vector is parsed and validated.
@@ -14,6 +14,7 @@

import java.util.Collections;
import java.util.Map;
import java.util.function.Function;

/**
* Simple implementation of {@link KNNLibraryIndexingContext}
@@ -29,6 +30,7 @@ public class KNNLibraryIndexingContextImpl implements KNNLibraryIndexingContext
private Map<String, Object> parameters = Collections.emptyMap();
@Builder.Default
private QuantizationConfig quantizationConfig = QuantizationConfig.EMPTY;
private Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> trainingConfigValidationSetup;

@Override
public Map<String, Object> getLibraryParameters() {
@@ -59,4 +61,9 @@ public PerDimensionValidator getPerDimensionValidator() {
public PerDimensionProcessor getPerDimensionProcessor() {
return perDimensionProcessor;
}

@Override
public Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> getTrainingConfigValidationSetup() {
return trainingConfigValidationSetup;
}
}
@@ -0,0 +1,26 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.Setter;

/**
* This object provides the input for the validation checks on training model parameters.
* The values in this object are set dynamically, so calling code must handle
* the possibility that some values have not been set.
*/
@Setter
@Getter
@Builder
@AllArgsConstructor
public class TrainingConfigValidationInput {
private Long trainingVectorsCount;
private KNNMethodContext knnMethodContext;
private KNNMethodConfigContext knnMethodConfigContext;
}
@@ -0,0 +1,25 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index.engine;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.Setter;

/**
* This object provides the output of the validation checks on training model parameters.
* The values in this object are set dynamically, so calling code must handle
* the possibility that some values have not been set.
*/
@Setter
@Getter
@Builder
@AllArgsConstructor
public class TrainingConfigValidationOutput {
private boolean valid;
private long minTrainingVectorCount;
}
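A rough usage sketch of the new hook, assuming the method and config contexts have already been resolved (local variable names here are illustrative, not taken from this commit); the transport-action and request changes below follow the same pattern:

KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine()
    .getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> validateTrainingConfig =
    knnLibraryIndexingContext.getTrainingConfigValidationSetup();

TrainingConfigValidationOutput validation = validateTrainingConfig.apply(
    TrainingConfigValidationInput.builder()
        .trainingVectorsCount(1100L)
        .knnMethodContext(knnMethodContext)
        .knnMethodConfigContext(knnMethodConfigContext)
        .build()
);
if (!validation.isValid()) {
    // e.g. report validation.getMinTrainingVectorCount() back to the caller as a validation error
}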
@@ -23,12 +23,18 @@
import org.opensearch.common.ValidationException;
import org.opensearch.common.inject.Inject;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
import org.opensearch.knn.index.engine.KNNMethodConfigContext;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportRequestOptions;
import org.opensearch.transport.TransportService;

import java.util.Map;
import java.util.function.Function;

import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;
import static org.opensearch.search.internal.SearchContext.DEFAULT_TERMINATE_AFTER;
@@ -134,6 +140,29 @@ protected void getTrainingIndexSizeInKB(TrainingModelRequest trainingModelReques
trainingVectors = trainingModelRequest.getMaximumVectorCount();
}

KNNMethodContext knnMethodContext = trainingModelRequest.getKnnMethodContext();
KNNMethodConfigContext knnMethodConfigContext = trainingModelRequest.getKnnMethodConfigContext();

KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine()
.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);

Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> validateTrainingConfig = knnLibraryIndexingContext
.getTrainingConfigValidationSetup();

TrainingConfigValidationInput.TrainingConfigValidationInputBuilder inputBuilder = TrainingConfigValidationInput.builder();

TrainingConfigValidationOutput validation = validateTrainingConfig.apply(
inputBuilder.trainingVectorsCount(trainingVectors).knnMethodContext(knnMethodContext).build()
);
if (!validation.isValid()) {
ValidationException exception = new ValidationException();
exception.addValidationError(
String.format("Number of training points should be greater than %d", validation.getMinTrainingVectorCount())
);
listener.onFailure(exception);
return;
}

listener.onResponse(
estimateVectorSetSizeInKB(trainingVectors, trainingModelRequest.getDimension(), trainingModelRequest.getVectorDataType())
);
@@ -30,10 +30,14 @@
import org.opensearch.knn.index.engine.EngineResolver;
import org.opensearch.knn.index.util.IndexUtil;
import org.opensearch.knn.index.engine.KNNMethodContext;
import org.opensearch.knn.index.engine.KNNLibraryIndexingContext;
import org.opensearch.knn.index.engine.TrainingConfigValidationInput;
import org.opensearch.knn.index.engine.TrainingConfigValidationOutput;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.indices.ModelDao;

import java.io.IOException;
import java.util.function.Function;

/**
* Request to train and serialize a model
@@ -283,6 +287,21 @@ public ActionRequestValidationException validate() {
exception.addValidationError("Description exceeds limit of " + KNNConstants.MAX_MODEL_DESCRIPTION_LENGTH + " characters");
}

KNNLibraryIndexingContext knnLibraryIndexingContext = knnMethodContext.getKnnEngine()
.getKNNLibraryIndexingContext(knnMethodContext, knnMethodConfigContext);
Function<TrainingConfigValidationInput, TrainingConfigValidationOutput> validateTrainingConfig = knnLibraryIndexingContext
.getTrainingConfigValidationSetup();
TrainingConfigValidationInput.TrainingConfigValidationInputBuilder inputBuilder = TrainingConfigValidationInput.builder();
TrainingConfigValidationOutput validation = validateTrainingConfig.apply(
inputBuilder.knnMethodConfigContext(knnMethodConfigContext).knnMethodContext(knnMethodContext).build()
);

// Check that the vector dimension is divisible by ENCODER_PARAMETER_PQ_M
if (!validation.isValid()) {
exception = exception == null ? new ActionRequestValidationException() : exception;
exception.addValidationError("Training request ENCODER_PARAMETER_PQ_M is not divisible by vector dimensions");
}
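// For example, a request with dimension 10 and m = 3 is rejected here (10 % 3 != 0), while dimension 12 with m = 3 passes.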

// Validate training index exists
IndexMetadata indexMetadata = clusterService.state().metadata().index(trainingIndex);
if (indexMetadata == null) {
4 changes: 1 addition & 3 deletions src/main/java/org/opensearch/knn/training/TrainingJob.java
@@ -203,9 +203,7 @@ public void run() {
} catch (Exception e) {
logger.error("Failed to run training job for model \"" + modelId + "\": ", e);
modelMetadata.setState(ModelState.FAILED);
modelMetadata.setError(
"Failed to execute training. May be caused by an invalid method definition or " + "not enough memory to perform training."
);
modelMetadata.setError("Failed to execute training. " + e.getMessage());

KNNCounter.TRAINING_ERRORS.increment();

16 changes: 8 additions & 8 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
@@ -304,7 +304,7 @@ public void testEndToEnd_whenDoRadiusSearch_whenDistanceThreshold_whenMethodIsHN

// training data needs to be at least equal to the number of centroids for PQ
// which is 2^8 = 256. 8 because that's the only valid code_size for HNSWPQ
int trainingDataCount = 256;
int trainingDataCount = 1100;

SpaceType spaceType = SpaceType.L2;

@@ -468,7 +468,7 @@ public void testEndToEnd_whenMethodIsHNSWPQ_thenSucceed() {

// training data needs to be at least equal to the number of centroids for PQ
// which is 2^8 = 256. 8 because that's the only valid code_size for HNSWPQ
int trainingDataCount = 256;
int trainingDataCount = 1100;

SpaceType spaceType = SpaceType.L2;

@@ -736,7 +736,7 @@ public void testIVFSQFP16_whenIndexedAndQueried_thenSucceed() {

// Add training data
createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
int trainingDataCount = 200;
int trainingDataCount = 1100;
bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);

XContentBuilder builder = XContentFactory.jsonBuilder()
@@ -960,7 +960,7 @@ public void testIVFSQFP16_whenIndexedWithOutOfFP16Range_thenThrowException() {

// Add training data
createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
int trainingDataCount = 200;
int trainingDataCount = 1100;
bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);

XContentBuilder builder = XContentFactory.jsonBuilder()
@@ -1064,7 +1064,7 @@ public void testIVFSQFP16_whenClipToFp16isTrueAndIndexedWithOutOfFP16Range_thenS

// Add training data
createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
int trainingDataCount = 200;
int trainingDataCount = 1100;
bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);

XContentBuilder builder = XContentFactory.jsonBuilder()
@@ -1144,7 +1144,7 @@ public void testEndToEnd_whenMethodIsHNSWPQAndHyperParametersNotSet_thenSucceed(

// training data needs to be at least equal to the number of centroids for PQ
// which is 2^8 = 256. 8 because that's the only valid code_size for HNSWPQ
int trainingDataCount = 256;
int trainingDataCount = 1100;

SpaceType spaceType = SpaceType.L2;

@@ -1414,7 +1414,7 @@ public void testKNNQuery_withModelDifferentCombination_thenSuccess() {

// Add training data
createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
int trainingDataCount = 200;
int trainingDataCount = 1100;
bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);

// Call train API - IVF with nlists = 1 is brute force, but will require training
@@ -1769,7 +1769,7 @@ public void testIVF_whenBinaryFormat_whenIVF_thenSuccess() {

createKnnIndex(trainingIndexName, trainIndexMapping);

int trainingDataCount = 40;
int trainingDataCount = 1100;
bulkIngestRandomBinaryVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);

XContentBuilder trainModelXContentBuilder = XContentFactory.jsonBuilder()
@@ -603,7 +603,7 @@ public void testIVFByteVector_whenIndexedAndQueried_thenSucceed() {
.toString();
createKnnIndex(INDEX_NAME, trainIndexMapping);

int trainingDataCount = 100;
int trainingDataCount = 1100;
bulkIngestRandomByteVectors(INDEX_NAME, FIELD_NAME, trainingDataCount, dimension);

XContentBuilder trainModelXContentBuilder = XContentFactory.jsonBuilder()
@@ -629,7 +629,7 @@ public void testKNNScriptScoreOnModelBasedIndex() throws Exception {
int dimensions = randomIntBetween(2, 10);
String trainMapping = createKnnIndexMapping(TRAIN_FIELD_PARAMETER, dimensions);
createKnnIndex(TRAIN_INDEX_PARAMETER, trainMapping);
bulkIngestRandomVectors(TRAIN_INDEX_PARAMETER, TRAIN_FIELD_PARAMETER, dimensions * 3, dimensions);
bulkIngestRandomVectors(TRAIN_INDEX_PARAMETER, TRAIN_FIELD_PARAMETER, 1100, dimensions);

XContentBuilder methodBuilder = XContentFactory.jsonBuilder()
.startObject()
@@ -52,7 +52,7 @@ public class ModeAndCompressionIT extends KNNRestTestCase {

private static final String TRAINING_INDEX_NAME = "training_index";
private static final String TRAINING_FIELD_NAME = "training_field";
private static final int TRAINING_VECS = 20;
private static final int TRAINING_VECS = 1100;

private static final int DIMENSION = 16;
private static final int NUM_DOCS = 20;