Skip to content

Commit

Permalink
Add KNNVector DocValues Fields
Browse files Browse the repository at this point in the history
Signed-off-by: luyuncheng <[email protected]>
  • Loading branch information
luyuncheng committed Mar 27, 2024
1 parent e8c9ced commit 58cc73b
Show file tree
Hide file tree
Showing 4 changed files with 132 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import org.apache.lucene.index.BinaryDocValues;
import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.BytesRefBuilder;
import org.opensearch.index.fielddata.LeafFieldData;
import org.opensearch.index.fielddata.ScriptDocValues;
import org.opensearch.index.fielddata.SortedBinaryDocValues;
Expand Down Expand Up @@ -48,6 +50,44 @@ public ScriptDocValues<float[]> getScriptValues() {

/**
 * Exposes the vector stored for each document as a sequence of binary values, one per dimension.
 *
 * <p>Every dimension is rendered through {@link Float#toString(float)} and copied into a reusable
 * {@link BytesRefBuilder} (which encodes the text as UTF-8). Values are handed out in vector
 * order. This replaces the former behaviour of throwing
 * {@code UnsupportedOperationException("knn vector field '...' doesn't support sorting")}.
 *
 * <p>The {@link BytesRef} returned by {@code nextValue()} is backed by a single builder that is
 * cleared on every call, so callers that need to keep a value must copy it first.
 *
 * @return a view over the binary doc values of {@code fieldName}; never {@code null}
 * @throws IllegalStateException if the binary doc values of the field cannot be opened
 */
@Override
public SortedBinaryDocValues getBytesValues() {
    final BinaryDocValues binaryDocValues;
    try {
        binaryDocValues = DocValues.getBinary(reader, fieldName);
    } catch (IOException e) {
        // Surface the failure with its cause instead of printing the stack trace and returning null,
        // which would only move the error to a later NullPointerException in the caller.
        throw new IllegalStateException("Cannot load doc values for knn vector field: " + fieldName, e);
    }

    return new SortedBinaryDocValues() {

        /** Whether the last {@code advanceExact} call landed on a document that has a vector. */
        private boolean docExists = false;
        /** Decoded vector of the current document; only meaningful while {@code docExists} is true. */
        private float[] floats = null;
        /** Index of the next dimension to hand out from {@code floats}. */
        private int pos = 0;
        /** Scratch buffer reused for every value that is returned. */
        private final BytesRefBuilder bytesRefBuilder = new BytesRefBuilder();

        @Override
        public boolean advanceExact(int doc) throws IOException {
            docExists = binaryDocValues.advanceExact(doc);
            if (docExists) {
                floats = vectorDataType.getVectorFromDocValues(binaryDocValues.binaryValue());
                // Restart iteration over the dimensions for the new document.
                pos = 0;
            }
            return docExists;
        }

        @Override
        public int docValueCount() {
            return docExists ? floats.length : 0;
        }

        @Override
        public BytesRef nextValue() throws IOException {
            bytesRefBuilder.clear();
            // Float.toString(float) yields the same text as boxing to Float and calling toString().
            bytesRefBuilder.copyChars(Float.toString(floats[pos++]));
            return bytesRefBuilder.get();
        }
    };
}
}
81 changes: 81 additions & 0 deletions src/main/java/org/opensearch/knn/index/fetch/KNNFetchSubPhase.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*/

package org.opensearch.knn.index.fetch;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.index.LeafReaderContext;
import org.opensearch.common.document.DocumentField;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.MapperService;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.search.SearchHit;
import org.opensearch.search.fetch.FetchContext;
import org.opensearch.search.fetch.FetchSubPhase;
import org.opensearch.search.fetch.FetchSubPhaseProcessor;

import java.io.IOException;
import java.util.Map;

import static org.opensearch.knn.common.KNNConstants.BYTES_PER_KILOBYTES;


/**
 * Fetch sub phase whose processor copies the fetched values of k-NN vector fields into the
 * {@code _source} of a search hit.
 */
public class KNNFetchSubPhase implements FetchSubPhase {
    private static final Logger logger = LogManager.getLogger(KNNFetchSubPhase.class);

    /**
     * Returns the processor to run for the given fetch context.
     *
     * @param fetchContext context of the current fetch phase
     * @return currently always {@code null}
     * @throws IOException declared by the interface; not thrown by this implementation
     */
    @Override
    public FetchSubPhaseProcessor getProcessor(FetchContext fetchContext) throws IOException {
        // NOTE(review): nothing in this class instantiates KNNFetchSubPhaseProcessor, so returning null
        // here means the processor below never runs — presumably the wiring is still pending; confirm.
        return null;
    }

    /**
     * Processor that injects the values of every fetched k-NN vector field into the hit's source map
     * and re-serializes the source as JSON.
     */
    @AllArgsConstructor
    @Getter
    class KNNFetchSubPhaseProcessor implements FetchSubPhaseProcessor {

        private final FetchContext fetchContext;

        @Override
        public void setNextReader(LeafReaderContext leafReaderContext) throws IOException {
            // Intentionally empty: this processor keeps no per-segment state.
        }

        /**
         * Adds each fetched field whose mapping is a {@link KNNVectorFieldMapper.KNNVectorFieldType}
         * to the hit's source and replaces the hit's source bytes with the updated map.
         *
         * @param hitContext the hit being processed
         * @throws IOException if the updated source cannot be serialized
         */
        @Override
        public void process(HitContext hitContext) throws IOException {
            SearchHit hit = hitContext.hit();
            Map<String, Object> source = hit.getSourceAsMap();
            if (source == null) {
                // The hit carries no _source to enrich; writing into a null map would throw.
                return;
            }

            MapperService mapperService = fetchContext.mapperService();
            boolean sourceChanged = false;
            for (Map.Entry<String, DocumentField> fieldEntry : hit.getFields().entrySet()) {
                String fieldName = fieldEntry.getKey();
                // instanceof is already false for null, so no separate null check is needed.
                if (mapperService.fieldType(fieldName) instanceof KNNVectorFieldMapper.KNNVectorFieldType) {
                    source.put(fieldName, fieldEntry.getValue());
                    sourceChanged = true;
                }
            }
            if (!sourceChanged) {
                // Leave the original source bytes untouched when there is nothing to inject.
                return;
            }

            //TODO process nested
            BytesStreamOutput streamOutput = new BytesStreamOutput(BYTES_PER_KILOBYTES);
            XContentBuilder builder = new XContentBuilder(XContentType.JSON.xContent(), streamOutput);
            builder.value(source);
            hit.sourceRef(BytesReference.bytes(builder));
        }
    }
}
7 changes: 7 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import org.opensearch.indices.SystemIndexDescriptor;
import org.opensearch.knn.index.KNNCircuitBreaker;
import org.opensearch.knn.index.KNNClusterUtil;
import org.opensearch.knn.index.fetch.KNNFetchSubPhase;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
Expand Down Expand Up @@ -95,6 +96,7 @@
import org.opensearch.script.ScriptContext;
import org.opensearch.script.ScriptEngine;
import org.opensearch.script.ScriptService;
import org.opensearch.search.fetch.FetchSubPhase;
import org.opensearch.threadpool.ExecutorBuilder;
import org.opensearch.threadpool.FixedExecutorBuilder;
import org.opensearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -175,6 +177,11 @@ public List<QuerySpec<?>> getQueries() {
return singletonList(new QuerySpec<>(KNNQueryBuilder.NAME, KNNQueryBuilder::new, KNNQueryBuilder::fromXContent));
}

/**
 * Registers the fetch sub phases contributed by this plugin.
 *
 * @param context construction context supplied by the search module; not used here
 * @return a single-element list holding a new {@link KNNFetchSubPhase}
 */
@Override
public List<FetchSubPhase> getFetchSubPhases(FetchPhaseConstructionContext context) {
    final FetchSubPhase knnFetchSubPhase = new KNNFetchSubPhase();
    return singletonList(knnFetchSubPhase);
}

@Override
public Collection<Object> createComponents(
Client client,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.index;

import org.opensearch.index.fielddata.SortedBinaryDocValues;
import org.opensearch.knn.KNNTestCase;
import org.apache.lucene.tests.analysis.MockAnalyzer;
import org.apache.lucene.document.BinaryDocValuesField;
Expand Down Expand Up @@ -94,6 +95,7 @@ public void testRamBytesUsed() {

/**
 * Verifies that {@code getBytesValues()} on a float vector field returns a non-null
 * {@link SortedBinaryDocValues} view.
 */
public void testGetBytesValues() {
    KNNVectorDVLeafFieldData leafFieldData = new KNNVectorDVLeafFieldData(leafReaderContext.reader(), "", VectorDataType.FLOAT);

    // The former expectation, expectThrows(UnsupportedOperationException.class, ...), is dropped:
    // it cannot hold together with the non-null assertion below.
    SortedBinaryDocValues bytesValues = leafFieldData.getBytesValues();
    assertNotNull(bytesValues);
}
}

0 comments on commit 58cc73b

Please sign in to comment.