Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace Map<String,Object> with IntObjectHashMap for KnnVectorsReader #13763

Merged
merged 7 commits into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.SplittableRandom;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
Expand All @@ -34,6 +32,7 @@
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
Expand All @@ -53,14 +52,16 @@
*/
public final class Lucene90HnswVectorsReader extends KnnVectorsReader {

private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final long checksumSeed;
private final FieldInfos fieldInfos;

Lucene90HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
long[] checksumRef = new long[1];
this.fieldInfos = state.fieldInfos;
boolean success = false;
try {
vectorData =
Expand Down Expand Up @@ -161,7 +162,7 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce

FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}

Expand Down Expand Up @@ -223,7 +224,8 @@ public void checkIntegrity() throws IOException {

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
final FieldInfo info = fieldInfos.fieldInfo(field);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't `final FieldInfo info` be null? We should check for null; see #13641, which should be merged soon.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is also what I wanted to discuss. Please refer to my comment below; if we all agree to check for null, I can fix this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my opinion, the most important thing is to remove the leniency we have in some places, like if (fieldEntry == null) return EMPTY_VECTOR_VALUES, which could hide bugs. Checking for null explicitly is not required, though it's obviously nice.

final FieldEntry fieldEntry = fields.get(info.number);
return getOffHeapVectorValues(fieldEntry);
}

Expand All @@ -235,7 +237,8 @@ public ByteVectorValues getByteVectorValues(String field) {
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry = fields.get(info.number);

if (fieldEntry.size() == 0) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.IntUnaryOperator;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
Expand All @@ -35,6 +33,7 @@
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
Expand All @@ -56,13 +55,15 @@
*/
public final class Lucene91HnswVectorsReader extends KnnVectorsReader {

private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
private final FieldInfos fieldInfos;

Lucene91HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
this.fieldInfos = state.fieldInfos;
boolean success = false;
try {
vectorData =
Expand Down Expand Up @@ -155,7 +156,7 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce
}
FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}

Expand Down Expand Up @@ -217,7 +218,8 @@ public void checkIntegrity() throws IOException {

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry = fields.get(info.number);
return getOffHeapVectorValues(fieldEntry);
}

Expand All @@ -229,7 +231,8 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry = fields.get(info.number);

if (fieldEntry.size() == 0) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
Expand All @@ -34,6 +32,7 @@
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
Expand All @@ -53,13 +52,15 @@
*/
public final class Lucene92HnswVectorsReader extends KnnVectorsReader {

private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
private final FieldInfos fieldInfos;

Lucene92HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
this.fieldInfos = state.fieldInfos;
boolean success = false;
try {
vectorData =
Expand Down Expand Up @@ -152,7 +153,7 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce
}
FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}

Expand Down Expand Up @@ -214,7 +215,8 @@ public void checkIntegrity() throws IOException {

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry = fields.get(info.number);
return OffHeapFloatVectorValues.load(fieldEntry, vectorData);
}

Expand All @@ -226,7 +228,8 @@ public ByteVectorValues getByteVectorValues(String field) {
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry = fields.get(info.number);

if (fieldEntry.size() == 0) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
Expand All @@ -35,6 +33,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
Expand All @@ -54,13 +53,15 @@
*/
public final class Lucene94HnswVectorsReader extends KnnVectorsReader {

private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
private final FieldInfos fieldInfos;

Lucene94HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
this.fieldInfos = state.fieldInfos;
boolean success = false;
try {
vectorData =
Expand Down Expand Up @@ -153,7 +154,7 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce
}
FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}

Expand Down Expand Up @@ -232,7 +233,8 @@ public void checkIntegrity() throws IOException {

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry = fields.get(info.number);
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
throw new IllegalArgumentException(
"field=\""
Expand All @@ -247,7 +249,8 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException {

@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry = fields.get(info.number);
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"field=\""
Expand All @@ -263,7 +266,8 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry = fields.get(info.number);

if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
return;
Expand All @@ -283,7 +287,8 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry = fields.get(info.number);

if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
Expand All @@ -39,6 +37,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
Expand All @@ -61,7 +60,7 @@
public final class Lucene95HnswVectorsReader extends KnnVectorsReader implements HnswGraphProvider {

private final FieldInfos fieldInfos;
private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
Expand Down Expand Up @@ -161,7 +160,7 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce
}
FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}

Expand Down Expand Up @@ -240,7 +239,8 @@ public void checkIntegrity() throws IOException {

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry = fields.get(info.number);
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
throw new IllegalArgumentException(
"field=\""
Expand All @@ -263,7 +263,8 @@ public FloatVectorValues getFloatVectorValues(String field) throws IOException {

@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry = fields.get(info.number);
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"field=\""
Expand All @@ -287,7 +288,8 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry = fields.get(info.number);

if (fieldEntry.size() == 0
|| knnCollector.k() == 0
Expand Down Expand Up @@ -318,7 +320,8 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry = fields.get(info.number);

if (fieldEntry.size() == 0
|| knnCollector.k() == 0
Expand Down Expand Up @@ -349,11 +352,11 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits
/** Get knn graph values; used for testing */
@Override
public HnswGraph getGraph(String field) throws IOException {
FieldInfo info = fieldInfos.fieldInfo(field);
final FieldInfo info = fieldInfos.fieldInfo(field);
if (info == null) {
throw new IllegalArgumentException("No such field '" + field + "'");
}
FieldEntry entry = fields.get(field);
final FieldEntry entry = fields.get(info.number);
if (entry != null && entry.vectorIndexLength > 0) {
return getGraph(entry);
} else {
Expand Down
Loading