Skip to content

Commit

Permalink
Remove DocsWithFieldSet reference from NativeEngineFieldVectorsWriter (#2408)
Browse files Browse the repository at this point in the history

* Remove DocsWithFieldSet reference from NativeEngineFieldVectorsWriter

Signed-off-by: Wei Wang <[email protected]>

* fix typo error in test file

Signed-off-by: Wei Wang <[email protected]>

---------

Signed-off-by: Wei Wang <[email protected]>
Signed-off-by: Wei Wang <[email protected]>
  • Loading branch information
weiwang118 authored Jan 23, 2025
1 parent 1c4a7ca commit d58d133
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 21 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Use one formula to calculate cosine similarity (#2357)[https://github.com/opensearch-project/k-NN/pull/2357]
- Add WithFieldName implementation to KNNQueryBuilder (#2398)[https://github.com/opensearch-project/k-NN/pull/2398]
- Make the build work for M series MacOS without manual code changes and local JAVA_HOME config (#2397)[https://github.com/opensearch-project/k-NN/pull/2397]
- Remove DocsWithFieldSet reference from NativeEngineFieldVectorsWriter (#2408)[https://github.com/opensearch-project/k-NN/pull/2408]
### Bug Fixes
* Fixing the bug when a segment has no vector field present for disk based vector search (#2282)[https://github.com/opensearch-project/k-NN/pull/2282]
* Fixing the bug where search fails with "fields" parameter for an index with a knn_vector field (#2314)[https://github.com/opensearch-project/k-NN/pull/2314]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import lombok.Getter;
import org.apache.lucene.codecs.KnnFieldVectorsWriter;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.util.InfoStream;
import org.apache.lucene.util.RamUsageEstimator;
Expand Down Expand Up @@ -43,9 +42,8 @@ class NativeEngineFieldVectorsWriter<T> extends KnnFieldVectorsWriter<T> {
@Getter
private final Map<Integer, T> vectors;
private int lastDocID = -1;
@Getter
private final DocsWithFieldSet docsWithField;
private final InfoStream infoStream;
@Getter
private final FlatFieldVectorsWriter<T> flatFieldVectorsWriter;

@SuppressWarnings("unchecked")
Expand Down Expand Up @@ -75,7 +73,6 @@ private NativeEngineFieldVectorsWriter(
this.fieldInfo = fieldInfo;
this.infoStream = infoStream;
vectors = new HashMap<>();
this.docsWithField = new DocsWithFieldSet();
this.flatFieldVectorsWriter = flatFieldVectorsWriter;
}

Expand All @@ -101,7 +98,6 @@ public void addValue(int docID, T vectorValue) throws IOException {
// ensuring that vector is provided to flatFieldWriter.
flatFieldVectorsWriter.addValue(docID, vectorValue);
vectors.put(docID, vectorValue);
docsWithField.add(docID);
lastDocID = docID;
}

Expand All @@ -121,10 +117,9 @@ public T copyValue(T vectorValue) {
*/
@Override
public long ramBytesUsed() {
return SHALLOW_SIZE + docsWithField.ramBytesUsed() + (long) this.vectors.size() * (long) (RamUsageEstimator.NUM_BYTES_OBJECT_REF
+ RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + (long) this.vectors.size() * RamUsageEstimator.shallowSizeOfInstance(
Integer.class
) + (long) vectors.size() * fieldInfo.getVectorDimension() * fieldInfo.getVectorEncoding().byteSize + flatFieldVectorsWriter
.ramBytesUsed();
return SHALLOW_SIZE + flatFieldVectorsWriter.getDocsWithFieldSet().ramBytesUsed() + (long) this.vectors.size()
* (long) (RamUsageEstimator.NUM_BYTES_OBJECT_REF + RamUsageEstimator.NUM_BYTES_ARRAY_HEADER) + (long) this.vectors.size()
* RamUsageEstimator.shallowSizeOfInstance(Integer.class) + (long) vectors.size() * fieldInfo.getVectorDimension()
* fieldInfo.getVectorEncoding().byteSize + flatFieldVectorsWriter.ramBytesUsed();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ public void flush(int maxDoc, final Sorter.DocMap sortMap) throws IOException {
}
final Supplier<KNNVectorValues<?>> knnVectorValuesSupplier = () -> getVectorValues(
vectorDataType,
field.getDocsWithField(),
field.getFlatFieldVectorsWriter().getDocsWithFieldSet(),
field.getVectors()
);
final QuantizationState quantizationState = train(field.getFieldInfo(), knnVectorValuesSupplier, totalLiveDocs);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import lombok.SneakyThrows;
import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter;
import org.apache.lucene.index.DocsWithFieldSet;
import org.apache.lucene.index.FieldInfo;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.util.InfoStream;
Expand Down Expand Up @@ -115,6 +116,7 @@ public void testRamByteUsed_whenValidInput_thenSuccess() {
Mockito.when(fieldInfo.getVectorDimension()).thenReturn(2);
FlatFieldVectorsWriter<?> mockedFlatFieldVectorsWriter = Mockito.mock(FlatFieldVectorsWriter.class);
Mockito.when(mockedFlatFieldVectorsWriter.ramBytesUsed()).thenReturn(1L);
Mockito.when(mockedFlatFieldVectorsWriter.getDocsWithFieldSet()).thenReturn(new DocsWithFieldSet());
final NativeEngineFieldVectorsWriter<float[]> floatWriter = (NativeEngineFieldVectorsWriter<float[]>) NativeEngineFieldVectorsWriter
.create(fieldInfo, mockedFlatFieldVectorsWriter, InfoStream.getDefault());
// testing for value > 0 as we don't have a concrete way to find out expected bytes. This can be OS dependent too.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ public void testFlush() {
throw new RuntimeException(e);
}

DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
knnVectorValuesFactoryMockedStatic.when(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
).thenReturn(expectedVectorValues.get(i));
Expand Down Expand Up @@ -250,7 +250,7 @@ public void testFlush_WithQuantization() {
throw new RuntimeException(e);
}

DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
knnVectorValuesFactoryMockedStatic.when(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
).thenReturn(expectedVectorValues.get(i));
Expand Down Expand Up @@ -352,7 +352,7 @@ public void testFlush_whenThresholdIsNegative_thenNativeIndexWriterIsNeverCalled
throw new RuntimeException(e);
}

DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
knnVectorValuesFactoryMockedStatic.when(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
).thenReturn(expectedVectorValues.get(i));
Expand Down Expand Up @@ -429,7 +429,7 @@ public void testFlush_whenThresholdIsGreaterThanVectorSize_thenNativeIndexWriter
throw new RuntimeException(e);
}

DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
knnVectorValuesFactoryMockedStatic.when(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
).thenReturn(expectedVectorValues.get(i));
Expand Down Expand Up @@ -507,7 +507,7 @@ public void testFlush_whenThresholdIsEqualToMinNumberOfVectors_thenNativeIndexWr
throw new RuntimeException(e);
}

DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
knnVectorValuesFactoryMockedStatic.when(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
).thenReturn(expectedVectorValues.get(i));
Expand Down Expand Up @@ -593,7 +593,7 @@ public void testFlush_whenThresholdIsEqualToFixedValue_thenRelevantNativeIndexWr
throw new RuntimeException(e);
}

DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
knnVectorValuesFactoryMockedStatic.when(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
).thenReturn(expectedVectorValues.get(i));
Expand Down Expand Up @@ -683,7 +683,7 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres
throw new RuntimeException(e);
}

DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
knnVectorValuesFactoryMockedStatic.when(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
).thenReturn(expectedVectorValues.get(i));
Expand Down Expand Up @@ -786,7 +786,7 @@ public void testFlush_whenQuantizationIsProvided_whenBuildGraphDatStructureThres
throw new RuntimeException(e);
}

DocsWithFieldSet docsWithFieldSet = field.getDocsWithField();
DocsWithFieldSet docsWithFieldSet = field.getFlatFieldVectorsWriter().getDocsWithFieldSet();
knnVectorValuesFactoryMockedStatic.when(
() -> KNNVectorValuesFactory.getVectorValues(VectorDataType.FLOAT, docsWithFieldSet, vectorsPerField.get(i))
).thenReturn(expectedVectorValues.get(i));
Expand Down Expand Up @@ -848,11 +848,13 @@ private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map<

private <T> NativeEngineFieldVectorsWriter nativeEngineFieldVectorsWriter(FieldInfo fieldInfo, Map<Integer, T> vectors) {
NativeEngineFieldVectorsWriter fieldVectorsWriter = mock(NativeEngineFieldVectorsWriter.class);
FlatFieldVectorsWriter flatFieldVectorsWriter = mock(FlatFieldVectorsWriter.class);
DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet();
vectors.keySet().stream().sorted().forEach(docsWithFieldSet::add);
when(fieldVectorsWriter.getFieldInfo()).thenReturn(fieldInfo);
when(fieldVectorsWriter.getVectors()).thenReturn(vectors);
when(fieldVectorsWriter.getDocsWithField()).thenReturn(docsWithFieldSet);
when(fieldVectorsWriter.getFlatFieldVectorsWriter()).thenReturn(flatFieldVectorsWriter);
when(flatFieldVectorsWriter.getDocsWithFieldSet()).thenReturn(docsWithFieldSet);
return fieldVectorsWriter;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -370,11 +370,13 @@ private FieldInfo fieldInfo(int fieldNumber, VectorEncoding vectorEncoding, Map<

private <T> NativeEngineFieldVectorsWriter nativeEngineFieldVectorsWriter(FieldInfo fieldInfo, Map<Integer, T> vectors) {
NativeEngineFieldVectorsWriter fieldVectorsWriter = mock(NativeEngineFieldVectorsWriter.class);
FlatFieldVectorsWriter flatFieldVectorsWriter = mock(FlatFieldVectorsWriter.class);
DocsWithFieldSet docsWithFieldSet = new DocsWithFieldSet();
vectors.keySet().stream().sorted().forEach(docsWithFieldSet::add);
when(fieldVectorsWriter.getFieldInfo()).thenReturn(fieldInfo);
when(fieldVectorsWriter.getVectors()).thenReturn(vectors);
when(fieldVectorsWriter.getDocsWithField()).thenReturn(docsWithFieldSet);
when(fieldVectorsWriter.getFlatFieldVectorsWriter()).thenReturn(flatFieldVectorsWriter);
when(flatFieldVectorsWriter.getDocsWithFieldSet()).thenReturn(docsWithFieldSet);
return fieldVectorsWriter;
}
}

0 comments on commit d58d133

Please sign in to comment.