Skip to content

Commit

Permalink
Fixing the MediaTypeRegistry and XContent parsers due to breaking cha…
Browse files Browse the repository at this point in the history
…nges in core

Signed-off-by: Navneet Verma <[email protected]>
  • Loading branch information
navneet1v committed Aug 1, 2023
1 parent 870144f commit 761421f
Show file tree
Hide file tree
Showing 13 changed files with 67 additions and 138 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
import lombok.NonNull;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.store.ChecksumIndexInput;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.MediaTypeParserRegistry;
import org.opensearch.common.xcontent.XContentHelper;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.core.xcontent.DeprecationHandler;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.knn.index.KNNSettings;
Expand Down Expand Up @@ -176,9 +177,12 @@ private void createKNNIndexFromScratch(FieldInfo fieldInfo, KNNCodecUtil.Pair pa
parameters.put(PARAMETERS, algoParams);
} else {
parameters.putAll(
XContentFactory.xContent(MediaTypeParserRegistry.getDefaultMediaType())
.createParser(NamedXContentRegistry.EMPTY, DeprecationHandler.THROW_UNSUPPORTED_OPERATION, parametersString)
.map()
XContentHelper.createParser(
NamedXContentRegistry.EMPTY,
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
new BytesArray(parametersString),
MediaTypeRegistry.getDefaultMediaType()
).map()
);
}

Expand Down
75 changes: 1 addition & 74 deletions src/test/java/org/opensearch/knn/index/FaissIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Floats;
import lombok.SneakyThrows;
import org.apache.hc.core5.http.ParseException;
import org.apache.hc.core5.http.io.entity.EntityUtils;
import org.junit.BeforeClass;
import org.opensearch.client.Response;
Expand Down Expand Up @@ -223,7 +222,7 @@ public void testDocDeletion() throws IOException {
}

@SneakyThrows
public void testKNNQuery_withModelDifferentCombination_thenSuccess() {
public void testKNNQuery_withModelDifferentCombination_thenSuccess() {
String modelId = "test-model";
int dimension = 128;

Expand Down Expand Up @@ -319,78 +318,6 @@ public void testKNNQuery_withModelDifferentCombination_thenSuccess() {
}

@SneakyThrows
<<<<<<< Updated upstream
public void testQueryWithFilter_withModelDifferentCombination_thenSuccess() {
String modelId = "test-model";
int dimension = 128;

String trainingIndexName = "train-index";
String trainingFieldName = "train-field";

// Add training data
createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension);
int trainingDataCount = 200;
bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension);

// Call train API - IVF with nlists = 1 is brute force, but will require training
XContentBuilder builder = XContentFactory.jsonBuilder()
.startObject()
.field(NAME, "ivf")
.field(KNN_ENGINE, "faiss")
.field(METHOD_PARAMETER_SPACE_TYPE, "l2")
.startObject(PARAMETERS)
.field(METHOD_PARAMETER_NLIST, 1)
.endObject()
.endObject();
Map<String, Object> method = xContentBuilderToMap(builder);

trainModel(modelId, trainingIndexName, trainingFieldName, dimension, method, "faiss test description");

// Make sure training succeeds after 30 seconds
assertTrainingSucceeds(modelId, 30, 1000);

// Create knn index from model
String fieldName = "test-field-name";
String indexName = "test-index-name";
String indexMapping = Strings.toString(
XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(fieldName)
.field("type", "knn_vector")
.field(MODEL_ID, modelId)
.endObject()
.endObject()
.endObject()
);

createKnnIndex(indexName, getKNNDefaultIndexSettings(), indexMapping);

// Index some documents
int numDocs = 100;
for (int i = 0; i < numDocs; i++) {
float[] indexVector = new float[dimension];
Arrays.fill(indexVector, (float) i);
addKnnDocWithAttributes(indexName, Integer.toString(i), fieldName, indexVector, ImmutableMap.of("rating",
String.valueOf(i)));
}

// Run search and ensure that the values returned are expected
float[] queryVector = new float[dimension];
Arrays.fill(queryVector, (float) numDocs);
int k = 10;

Response searchResponse = searchKNNIndex(indexName, new KNNQueryBuilder(fieldName, queryVector, k,
QueryBuilders.rangeQuery("rating").gte("40").lte("45")), k);
List<KNNResult> results = parseSearchResponse(EntityUtils.toString(searchResponse.getEntity()), fieldName);
for (int i = 0; i < k; i++) {
assertEquals(numDocs - i - 1, Integer.parseInt(results.get(i).getDocId()));
}
}

@SneakyThrows
=======
>>>>>>> Stashed changes
public void testQueryWithFilter_withDifferentCombination_thenSuccess() {
setupKNNIndexForFilterQuery();
final float[] searchVector = { 6.0f, 6.0f, 4.1f };
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import org.opensearch.client.ResponseException;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.MediaTypeParserRegistry;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.core.rest.RestStatus;
Expand Down Expand Up @@ -62,7 +62,7 @@ public void testDeleteModelExists() throws Exception {
String responseBody = EntityUtils.toString(getModelResponse.getEntity());
assertNotNull(responseBody);

Map<String, Object> responseMap = createParser(MediaTypeParserRegistry.getDefaultMediaType().xContent(), responseBody).map();
Map<String, Object> responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map();

assertEquals(modelId, responseMap.get(MODEL_ID));

Expand Down Expand Up @@ -99,7 +99,7 @@ public void testDeleteTrainingModel() throws Exception {
String responseBody = EntityUtils.toString(getModelResponse.getEntity());
assertNotNull(responseBody);

Map<String, Object> responseMap = createParser(MediaTypeParserRegistry.getDefaultMediaType().xContent(), responseBody).map();
Map<String, Object> responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map();

assertEquals(modelId, responseMap.get(MODEL_ID));

Expand Down Expand Up @@ -205,7 +205,7 @@ private void trainModel(String modelId, String trainingIndexName, String trainin
String responseBody = EntityUtils.toString(getResponse.getEntity());
assertNotNull(responseBody);

Map<String, Object> responseMap = createParser(MediaTypeParserRegistry.getDefaultMediaType().xContent(), responseBody).map();
Map<String, Object> responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map();

assertEquals(modelId, responseMap.get(MODEL_ID));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.core.xcontent.MediaTypeParserRegistry;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.core.rest.RestStatus;
Expand Down Expand Up @@ -74,7 +74,7 @@ public void testGetModelExists() throws Exception {
String responseBody = EntityUtils.toString(response.getEntity());
assertNotNull(responseBody);

Map<String, Object> responseMap = createParser(MediaTypeParserRegistry.getDefaultMediaType().xContent(), responseBody).map();
Map<String, Object> responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map();
assertEquals(modelId, responseMap.get(MODEL_ID));
assertEquals(modelDescription, responseMap.get(MODEL_DESCRIPTION));
assertEquals(FAISS.getName(), responseMap.get(KNN_ENGINE));
Expand Down Expand Up @@ -106,7 +106,7 @@ public void testGetModelExistsWithFilter() throws Exception {
String responseBody = EntityUtils.toString(response.getEntity());
assertNotNull(responseBody);

Map<String, Object> responseMap = createParser(MediaTypeParserRegistry.getDefaultMediaType().xContent(), responseBody).map();
Map<String, Object> responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map();

assertTrue(responseMap.size() == filteredPath.size());
assertEquals(modelId, responseMap.get(MODEL_ID));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.MediaTypeParserRegistry;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
Expand Down Expand Up @@ -347,7 +347,7 @@ public void testModelIndexHealthMetricsStats() throws Exception {

final Response response = getKnnStats(Collections.emptyList(), Arrays.asList(modelIndexStatusName));
final String responseBody = EntityUtils.toString(response.getEntity());
final Map<String, Object> statsMap = createParser(MediaTypeParserRegistry.getDefaultMediaType().xContent(), responseBody).map();
final Map<String, Object> statsMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map();

// Check that model health status is null since model index is not created to system yet
assertNull(statsMap.get(StatNames.MODEL_INDEX_STATUS.getName()));
Expand All @@ -358,7 +358,7 @@ public void testModelIndexHealthMetricsStats() throws Exception {
Response response = getKnnStats(Collections.emptyList(), Arrays.asList(modelIndexStatusName));

final String responseBody = EntityUtils.toString(response.getEntity());
final Map<String, Object> statsMap = createParser(MediaTypeParserRegistry.getDefaultMediaType().xContent(), responseBody).map();
final Map<String, Object> statsMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map();

// Check that model health status is not null
assertNotNull(statsMap.get(modelIndexStatusName));
Expand Down Expand Up @@ -452,7 +452,7 @@ public void validateModelCreated(String modelId) throws Exception {
String responseBody = EntityUtils.toString(getResponse.getEntity());
assertNotNull(responseBody);

Map<String, Object> responseMap = createParser(MediaTypeParserRegistry.getDefaultMediaType().xContent(), responseBody).map();
Map<String, Object> responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map();
assertEquals(modelId, responseMap.get(MODEL_ID));
assertTrainingSucceeds(modelId, NUM_OF_ATTEMPTS, DELAY_MILLI_SEC);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.core.xcontent.MediaTypeParserRegistry;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.util.KNNEngine;
Expand Down Expand Up @@ -73,7 +73,7 @@ public void testNoModelExists() throws Exception {
String responseBody = EntityUtils.toString(response.getEntity());
assertNotNull(responseBody);

XContentParser parser = createParser(MediaTypeParserRegistry.getDefaultMediaType().xContent(), responseBody);
XContentParser parser = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody);
SearchResponse searchResponse = SearchResponse.fromXContent(parser);
assertNotNull(searchResponse);
assertEquals(searchResponse.getHits().getHits().length, 0);
Expand Down Expand Up @@ -133,7 +133,7 @@ public void testSearchModelExists() throws Exception {
String responseBody = EntityUtils.toString(response.getEntity());
assertNotNull(responseBody);

XContentParser parser = createParser(MediaTypeParserRegistry.getDefaultMediaType().xContent(), responseBody);
XContentParser parser = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody);
SearchResponse searchResponse = SearchResponse.fromXContent(parser);
assertNotNull(searchResponse);

Expand Down Expand Up @@ -177,7 +177,7 @@ public void testSearchModelWithoutSource() throws Exception {
String responseBody = EntityUtils.toString(response.getEntity());
assertNotNull(responseBody);

XContentParser parser = createParser(MediaTypeParserRegistry.getDefaultMediaType().xContent(), responseBody);
XContentParser parser = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody);
SearchResponse searchResponse = SearchResponse.fromXContent(parser);
assertNotNull(searchResponse);

Expand Down Expand Up @@ -225,7 +225,7 @@ public void testSearchModelWithSourceFilteringIncludes() throws Exception {
String responseBody = EntityUtils.toString(response.getEntity());
assertNotNull(responseBody);

XContentParser parser = createParser(MediaTypeParserRegistry.getDefaultMediaType().xContent(), responseBody);
XContentParser parser = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody);
SearchResponse searchResponse = SearchResponse.fromXContent(parser);
assertNotNull(searchResponse);

Expand Down Expand Up @@ -277,7 +277,7 @@ public void testSearchModelWithSourceFilteringExcludes() throws Exception {
String responseBody = EntityUtils.toString(response.getEntity());
assertNotNull(responseBody);

XContentParser parser = createParser(MediaTypeParserRegistry.getDefaultMediaType().xContent(), responseBody);
XContentParser parser = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody);
SearchResponse searchResponse = SearchResponse.fromXContent(parser);
assertNotNull(searchResponse);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import org.opensearch.client.Response;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.xcontent.MediaTypeParserRegistry;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.knn.KNNRestTestCase;
import org.opensearch.core.rest.RestStatus;

Expand Down Expand Up @@ -96,8 +96,7 @@ public void testTrainModel_fail_notEnoughData() throws Exception {
String trainResponseBody = EntityUtils.toString(trainResponse.getEntity());
assertNotNull(trainResponseBody);

Map<String, Object> trainResponseMap = createParser(MediaTypeParserRegistry.getDefaultMediaType().xContent(), trainResponseBody)
.map();
Map<String, Object> trainResponseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), trainResponseBody).map();
String modelId = (String) trainResponseMap.get(MODEL_ID);
assertNotNull(modelId);

Expand All @@ -106,7 +105,7 @@ public void testTrainModel_fail_notEnoughData() throws Exception {
String responseBody = EntityUtils.toString(getResponse.getEntity());
assertNotNull(responseBody);

Map<String, Object> responseMap = createParser(MediaTypeParserRegistry.getDefaultMediaType().xContent(), responseBody).map();
Map<String, Object> responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map();

assertEquals(modelId, responseMap.get(MODEL_ID));

Expand Down Expand Up @@ -177,8 +176,7 @@ public void testTrainModel_fail_tooMuchData() throws Exception {
String trainResponseBody = EntityUtils.toString(trainResponse.getEntity());
assertNotNull(trainResponseBody);

Map<String, Object> trainResponseMap = createParser(MediaTypeParserRegistry.getDefaultMediaType().xContent(), trainResponseBody)
.map();
Map<String, Object> trainResponseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), trainResponseBody).map();
String modelId = (String) trainResponseMap.get(MODEL_ID);
assertNotNull(modelId);

Expand All @@ -187,7 +185,7 @@ public void testTrainModel_fail_tooMuchData() throws Exception {
String responseBody = EntityUtils.toString(getResponse.getEntity());
assertNotNull(responseBody);

Map<String, Object> responseMap = createParser(MediaTypeParserRegistry.getDefaultMediaType().xContent(), responseBody).map();
Map<String, Object> responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map();

assertEquals(modelId, responseMap.get(MODEL_ID));

Expand Down Expand Up @@ -258,7 +256,7 @@ public void testTrainModel_success_withId() throws Exception {
String responseBody = EntityUtils.toString(getResponse.getEntity());
assertNotNull(responseBody);

Map<String, Object> responseMap = createParser(MediaTypeParserRegistry.getDefaultMediaType().xContent(), responseBody).map();
Map<String, Object> responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map();

assertEquals(modelId, responseMap.get(MODEL_ID));

Expand Down Expand Up @@ -328,8 +326,7 @@ public void testTrainModel_success_noId() throws Exception {
String trainResponseBody = EntityUtils.toString(trainResponse.getEntity());
assertNotNull(trainResponseBody);

Map<String, Object> trainResponseMap = createParser(MediaTypeParserRegistry.getDefaultMediaType().xContent(), trainResponseBody)
.map();
Map<String, Object> trainResponseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), trainResponseBody).map();
String modelId = (String) trainResponseMap.get(MODEL_ID);
assertNotNull(modelId);

Expand All @@ -338,7 +335,7 @@ public void testTrainModel_success_noId() throws Exception {
String responseBody = EntityUtils.toString(getResponse.getEntity());
assertNotNull(responseBody);

Map<String, Object> responseMap = createParser(MediaTypeParserRegistry.getDefaultMediaType().xContent(), responseBody).map();
Map<String, Object> responseMap = createParser(MediaTypeRegistry.getDefaultMediaType().xContent(), responseBody).map();

assertEquals(modelId, responseMap.get(MODEL_ID));

Expand Down
Loading

0 comments on commit 761421f

Please sign in to comment.