Skip to content

Commit

Permalink
Add more unit-tests for Lucene Byte Vector
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda committed Jul 13, 2023
1 parent afdf125 commit 1dea3a8
Show file tree
Hide file tree
Showing 8 changed files with 217 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ public static class Builder extends ParametrizedFieldMapper.Builder {
* data_type which defines the data type of the vector values. This is an optional parameter that is
* currently only relevant for the Lucene engine. The default value is float.
*/
private final Parameter<VectorDataType> vectorDataType = new Parameter<>(
protected final Parameter<VectorDataType> vectorDataType = new Parameter<>(
VECTOR_DATA_TYPE_FIELD,
false,
() -> DEFAULT_VECTOR_DATA_TYPE_FIELD,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,12 @@ public void tearDown() throws Exception {
directory.close();
}

public void testGetScriptValues() {
KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(
leafReaderContext.reader(),
MOCK_INDEX_FIELD_NAME,
VectorDataType.FLOAT
);
ScriptDocValues<float[]> scriptValues = leafFieldData.getScriptValues();
assertNotNull(scriptValues);
assertTrue(scriptValues instanceof KNNVectorScriptDocValues);
public void testGetScriptValuesFloatVectorDataType() {
validateGetScriptValuesWithVectorDataType(VectorDataType.FLOAT);
}

/** Runs the shared script-values check against the {@code BYTE} vector data type. */
public void testGetScriptValuesByteVectorDataType() {
    final VectorDataType byteType = VectorDataType.BYTE;
    validateGetScriptValuesWithVectorDataType(byteType);
}

public void testGetScriptValuesWrongFieldName() {
Expand All @@ -87,6 +84,17 @@ public void testGetScriptValuesWrongFieldType() {
expectThrows(IllegalStateException.class, () -> leafFieldData.getScriptValues());
}

/**
 * Builds a {@code KNNVectorDVLeafFieldData} over the test reader for the mock field and checks that
 * its script values are non-null and of type {@code KNNVectorScriptDocValues}.
 *
 * @param vectorDataType vector data type passed to the leaf field data under test
 */
private void validateGetScriptValuesWithVectorDataType(VectorDataType vectorDataType) {
    final ScriptDocValues<float[]> docValues = new KNNVectorDVLeafFieldData(
        leafReaderContext.reader(),
        MOCK_INDEX_FIELD_NAME,
        vectorDataType
    ).getScriptValues();

    assertNotNull(docValues);
    assertTrue(docValues instanceof KNNVectorScriptDocValues);
}

public void testRamBytesUsed() {
KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), "", VectorDataType.FLOAT);
assertEquals(0, leafFieldData.ramBytesUsed());
Expand Down
58 changes: 58 additions & 0 deletions src/test/java/org/opensearch/knn/index/VectorDataTypeTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.junit.Assert;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapperUtil;

import java.io.IOException;
import java.util.Locale;

import static org.opensearch.knn.common.KNNConstants.VECTOR_DATA_TYPE_FIELD;

public class VectorDataTypeTests extends KNNTestCase {

Expand Down Expand Up @@ -51,6 +55,60 @@ public void testGetDocValuesWithByteVectorDataType() {
directory.close();
}

/**
 * Checks that {@code KNNVectorFieldMapperUtil.validateFloatVectorValue} throws an
 * {@link IllegalArgumentException} with the expected message for NaN and for positive infinity.
 */
public void testFloatVectorValueValidations() {
    // NaN must be rejected
    final IllegalArgumentException nanFailure = expectThrows(
        IllegalArgumentException.class,
        () -> KNNVectorFieldMapperUtil.validateFloatVectorValue(Float.NaN)
    );
    final String nanMessage = nanFailure.getMessage();
    assertTrue(nanMessage.contains("KNN vector values cannot be NaN"));

    // Positive infinity must be rejected
    final IllegalArgumentException infinityFailure = expectThrows(
        IllegalArgumentException.class,
        () -> KNNVectorFieldMapperUtil.validateFloatVectorValue(Float.POSITIVE_INFINITY)
    );
    final String infinityMessage = infinityFailure.getMessage();
    assertTrue(infinityMessage.contains("KNN vector values cannot be infinity"));
}

/**
 * Checks that {@code KNNVectorFieldMapperUtil.validateByteVectorValue} throws an
 * {@link IllegalArgumentException} with the expected message for a value that has a fractional part
 * and for a value outside the byte range.
 */
public void testByteVectorValueValidations() {
    // A value with a fractional part is rejected as "floats instead of byte integers"
    final IllegalArgumentException fractionalFailure = expectThrows(
        IllegalArgumentException.class,
        () -> KNNVectorFieldMapperUtil.validateByteVectorValue(10.54f)
    );
    final String expectedFractionalMessage = String.format(
        Locale.ROOT,
        "[%s] field was set as [%s] in index mapping. But, KNN vector values are floats instead of byte integers",
        VECTOR_DATA_TYPE_FIELD,
        VectorDataType.BYTE.getValue()
    );
    assertTrue(fractionalFailure.getMessage().contains(expectedFractionalMessage));

    // A whole number beyond Byte.MAX_VALUE is rejected as out of the byte range
    final IllegalArgumentException outOfRangeFailure = expectThrows(
        IllegalArgumentException.class,
        () -> KNNVectorFieldMapperUtil.validateByteVectorValue(200f)
    );
    final String expectedOutOfRangeMessage = String.format(
        Locale.ROOT,
        "[%s] field was set as [%s] in index mapping. But, KNN vector values are not within in the byte range [%d, %d]",
        VECTOR_DATA_TYPE_FIELD,
        VectorDataType.BYTE.getValue(),
        Byte.MIN_VALUE,
        Byte.MAX_VALUE
    );
    assertTrue(outOfRangeFailure.getMessage().contains(expectedOutOfRangeMessage));
}

@SneakyThrows
private KNNVectorScriptDocValues getKNNFloatVectorScriptDocValues() {
directory = newDirectory();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

import com.google.common.collect.ImmutableMap;
import lombok.SneakyThrows;
import org.apache.lucene.document.FieldType;
import org.apache.lucene.document.KnnByteVectorField;
import org.apache.lucene.document.KnnVectorField;
import org.apache.lucene.index.DocValuesType;
import org.apache.lucene.index.IndexableField;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.BytesRef;
Expand Down Expand Up @@ -218,6 +220,7 @@ public void testBuilder_parse_fromKnnMethodContext_luceneEngine() throws IOExcep
.startObject()
.field(TYPE_FIELD_NAME, KNN_VECTOR_TYPE)
.field(DIMENSION_FIELD_NAME, dimension)
.field(VECTOR_DATA_TYPE_FIELD, VectorDataType.BYTE.getValue())
.startObject(KNN_METHOD)
.field(NAME, METHOD_HNSW)
.field(METHOD_PARAMETER_SPACE_TYPE, SpaceType.L2)
Expand All @@ -237,6 +240,7 @@ public void testBuilder_parse_fromKnnMethodContext_luceneEngine() throws IOExcep
builder.build(builderContext);

assertEquals(METHOD_HNSW, builder.knnMethodContext.get().getMethodComponent().getName());
assertEquals(VectorDataType.BYTE.getValue(), builder.vectorDataType.getValue().getValue());
assertEquals(
efConstruction,
builder.knnMethodContext.get().getMethodComponent().getParameters().get(METHOD_PARAMETER_EF_CONSTRUCTION)
Expand Down Expand Up @@ -871,6 +875,13 @@ public void testLuceneFieldMapper_parseCreateField_docValues_withBytes() {
assertArrayEquals(TEST_BYTE_VECTOR, knnByteVectorField.vectorValue());
}

/**
 * Checks the field type returned by {@code KNNVectorFieldMapperUtil.buildDocValuesFieldType} for the
 * Lucene engine: it is non-null, stores the engine name under the {@code KNN_ENGINE} attribute, and
 * uses binary doc values.
 */
public void testBuildDocValuesFieldType() {
    final FieldType docValuesFieldType = KNNVectorFieldMapperUtil.buildDocValuesFieldType(KNNEngine.LUCENE);
    assertNotNull(docValuesFieldType);

    // Engine name is recorded as a field attribute
    assertEquals(
        KNNEngine.LUCENE.getName(),
        docValuesFieldType.getAttributes().get(KNN_ENGINE)
    );
    // Vectors are stored as binary doc values
    assertEquals(DocValuesType.BINARY, docValuesFieldType.docValuesType());
}

private LuceneFieldMapper.CreateLuceneFieldMapperInput.CreateLuceneFieldMapperInputBuilder createLuceneFieldMapperInputBuilder(
VectorDataType vectorDataType
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.query;

import com.google.common.collect.ImmutableMap;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.opensearch.Version;
Expand Down Expand Up @@ -146,7 +147,7 @@ protected NamedWriteableRegistry writableRegistry() {
return new NamedWriteableRegistry(entries);
}

public void testDoToQuery_Normal() throws Exception {
public void testDoToQuery_Normal() {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K);
Index dummyIndex = new Index("dummy", "dummy");
Expand All @@ -162,6 +163,26 @@ public void testDoToQuery_Normal() throws Exception {
assertEquals(knnQueryBuilder.vector(), query.getQueryVector());
}

/**
 * Verifies that {@code doToQuery} yields a {@link KnnByteVectorQuery} when the mocked field type reports
 * the Lucene engine and the {@code BYTE} vector data type.
 */
public void testDoToQuery_Normal_ByteVectorDataType() {
    float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
    KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K);
    Index dummyIndex = new Index("dummy", "dummy");

    // Mock the shard context and field type so the field resolves to a 4-dimensional Lucene byte vector field
    QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
    KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
    KNNMethodContext mockKNNMethodContext = mock(KNNMethodContext.class);
    when(mockQueryShardContext.index()).thenReturn(dummyIndex);
    when(mockKNNVectorField.getDimension()).thenReturn(4);
    when(mockKNNVectorField.getKnnMethodContext()).thenReturn(mockKNNMethodContext);
    when(mockKNNVectorField.getVectorDataType()).thenReturn(VectorDataType.BYTE);
    when(mockKNNMethodContext.getKnnEngine()).thenReturn(KNNEngine.LUCENE);
    when(mockQueryShardContext.fieldMapper(anyString())).thenReturn(mockKNNVectorField);

    Query query = knnQueryBuilder.doToQuery(mockQueryShardContext);
    assertNotNull(query);
    // instanceof states the intent directly. Class#isAssignableFrom with the query's class as the receiver
    // asks the opposite question: whether KnnByteVectorQuery is a subtype of the query's runtime class.
    assertTrue(query instanceof KnnByteVectorQuery);
}

public void testDoToQuery_KnnQueryWithFilter() throws Exception {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K, TERM_QUERY);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.knn.index.query;

import org.apache.lucene.index.Term;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
Expand All @@ -15,6 +16,7 @@
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.knn.KNNTestCase;
import org.opensearch.knn.index.VectorDataType;
import org.opensearch.knn.index.util.KNNEngine;

import java.util.Arrays;
Expand Down Expand Up @@ -73,6 +75,24 @@ public void testCreateLuceneDefaultQuery() {
}
}

/**
 * Verifies that {@code KNNQueryFactory.create} returns a {@link KnnByteVectorQuery} for a Lucene-engine
 * request that carries a byte query vector and the {@code BYTE} vector data type.
 */
public void testCreateLuceneQueryByteVectorDataType() {
    byte[] byteQueryVector = { 1, 2, 3, 4 };
    QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
    // Only the byte vector is supplied; the float vector and the filter are deliberately left null
    KNNQueryFactory.CreateQueryRequest createQueryRequest = KNNQueryFactory.CreateQueryRequest.builder()
        .knnEngine(KNNEngine.LUCENE)
        .indexName(testIndexName)
        .fieldName(testFieldName)
        .vector(null)
        .byteVector(byteQueryVector)
        .vectorDataType(VectorDataType.BYTE)
        .k(testK)
        .filter(null)
        .context(mockQueryShardContext)
        .build();

    Query query = KNNQueryFactory.create(createQueryRequest);
    // instanceof states the intent directly. Class#isAssignableFrom with the query's class as the receiver
    // asks the opposite question: whether KnnByteVectorQuery is a subtype of the query's runtime class.
    assertTrue(query instanceof KnnByteVectorQuery);
}

public void testCreateLuceneQueryWithFilter() {
List<KNNEngine> luceneDefaultQueryEngineList = Arrays.stream(KNNEngine.values())
.filter(knnEngine -> !KNNEngine.getEnginesThatCreateCustomSegmentFiles().contains(knnEngine))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,27 @@ public void testParseKNNVectorQuery() {
String invalidObject = "invalidObject";
expectThrows(ClassCastException.class, () -> KNNScoringSpaceUtil.parseToFloatArray(invalidObject, 3, VectorDataType.FLOAT));
}

/**
 * Exercises {@code KNNScoringSpaceUtil.parseToFloatArray} with the {@code BYTE} vector data type:
 * whole numbers inside the byte range are converted to the matching float array, while fractional
 * values and values outside the byte range are expected to raise {@link IllegalArgumentException}.
 */
public void testParseKNNVectorQueryByteVectorDataType() {
    float[] arrayFloat = new float[] { 1.0f, 2.0f, 3.0f };
    List<Number> arrayListQueryObject = new ArrayList<>(Arrays.asList(1, 2, 3));
    // Query vector holds whole numbers within the byte range, so parsing should succeed
    assertArrayEquals(arrayFloat, KNNScoringSpaceUtil.parseToFloatArray(arrayListQueryObject, 3, VectorDataType.BYTE), 0.1f);

    // Query vector holds fractional values for the byte vector data type, so parsing should throw
    List<Double> arrayListQueryObject1 = new ArrayList<>(Arrays.asList(1.1, 2.56, 3.67));
    expectThrows(
        IllegalArgumentException.class,
        () -> KNNScoringSpaceUtil.parseToFloatArray(arrayListQueryObject1, 3, VectorDataType.BYTE)
    );

    // Query vector holds a value outside the byte range for the byte vector data type, so parsing should throw
    List<Number> arrayListQueryObject2 = new ArrayList<>(Arrays.asList(1000, 2, 3));
    expectThrows(
        IllegalArgumentException.class,
        () -> KNNScoringSpaceUtil.parseToFloatArray(arrayListQueryObject2, 3, VectorDataType.BYTE)
    );
}
}
Loading

0 comments on commit 1dea3a8

Please sign in to comment.