Skip to content

Commit

Permalink
Replace Map<String,Object> with IntObjectHashMap for KnnVectorsReader (
Browse files Browse the repository at this point in the history
  • Loading branch information
bugmakerrrrrr authored and jpountz committed Oct 31, 2024
1 parent cff28d5 commit 584387a
Show file tree
Hide file tree
Showing 11 changed files with 218 additions and 205 deletions.
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ Optimizations

* GITHUB#13958: Speed up advancing within a block. (Adrien Grand)

* GITHUB#13763: Replace Map<String,Object> with IntObjectHashMap for KnnVectorsReader (Pan Guixin)

Bug Fixes
---------------------
* GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import java.util.SplittableRandom;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
Expand All @@ -33,6 +31,7 @@
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
import org.apache.lucene.store.ChecksumIndexInput;
Expand All @@ -50,14 +49,16 @@
*/
public final class Lucene90HnswVectorsReader extends KnnVectorsReader {

private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final long checksumSeed;
private final FieldInfos fieldInfos;

Lucene90HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
long[] checksumRef = new long[1];
this.fieldInfos = state.fieldInfos;
boolean success = false;
try {
vectorData =
Expand Down Expand Up @@ -158,7 +159,7 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce

FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}

Expand Down Expand Up @@ -218,13 +219,18 @@ public void checkIntegrity() throws IOException {
CodecUtil.checksumEntireFile(vectorIndex);
}

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
private FieldEntry getFieldEntry(String field) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry;
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
return getOffHeapVectorValues(fieldEntry);
return fieldEntry;
}

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
return getOffHeapVectorValues(getFieldEntry(field));
}

@Override
Expand All @@ -235,8 +241,7 @@ public ByteVectorValues getByteVectorValues(String field) {
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

final FieldEntry fieldEntry = getFieldEntry(field);
if (fieldEntry.size() == 0) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.function.IntUnaryOperator;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
Expand All @@ -35,6 +33,7 @@
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.search.VectorScorer;
Expand All @@ -55,13 +54,15 @@
*/
public final class Lucene91HnswVectorsReader extends KnnVectorsReader {

private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
private final FieldInfos fieldInfos;

Lucene91HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
this.fieldInfos = state.fieldInfos;
boolean success = false;
try {
vectorData =
Expand Down Expand Up @@ -154,7 +155,7 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce
}
FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}

Expand Down Expand Up @@ -214,13 +215,18 @@ public void checkIntegrity() throws IOException {
CodecUtil.checksumEntireFile(vectorIndex);
}

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
private FieldEntry getFieldEntry(String field) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry;
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
return getOffHeapVectorValues(fieldEntry);
return fieldEntry;
}

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
return getOffHeapVectorValues(getFieldEntry(field));
}

@Override
Expand All @@ -231,8 +237,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException {
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

final FieldEntry fieldEntry = getFieldEntry(field);
if (fieldEntry.size() == 0) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
Expand All @@ -34,6 +32,7 @@
import org.apache.lucene.index.IndexFileNames;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
Expand All @@ -53,13 +52,15 @@
*/
public final class Lucene92HnswVectorsReader extends KnnVectorsReader {

private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
private final FieldInfos fieldInfos;

Lucene92HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
this.fieldInfos = state.fieldInfos;
boolean success = false;
try {
vectorData =
Expand Down Expand Up @@ -152,7 +153,7 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce
}
FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}

Expand Down Expand Up @@ -212,13 +213,18 @@ public void checkIntegrity() throws IOException {
CodecUtil.checksumEntireFile(vectorIndex);
}

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
private FieldEntry getFieldEntry(String field) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry;
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
return OffHeapFloatVectorValues.load(fieldEntry, vectorData);
return fieldEntry;
}

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
return OffHeapFloatVectorValues.load(getFieldEntry(field), vectorData);
}

@Override
Expand All @@ -229,8 +235,7 @@ public ByteVectorValues getByteVectorValues(String field) {
@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

final FieldEntry fieldEntry = getFieldEntry(field);
if (fieldEntry.size() == 0) {
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.codecs.CodecUtil;
import org.apache.lucene.codecs.KnnVectorsReader;
import org.apache.lucene.codecs.hnsw.DefaultFlatVectorScorer;
Expand All @@ -35,6 +33,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.VectorEncoding;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.internal.hppc.IntObjectHashMap;
import org.apache.lucene.search.KnnCollector;
import org.apache.lucene.store.ChecksumIndexInput;
import org.apache.lucene.store.DataInput;
Expand All @@ -54,13 +53,15 @@
*/
public final class Lucene94HnswVectorsReader extends KnnVectorsReader {

private final Map<String, FieldEntry> fields = new HashMap<>();
private final IntObjectHashMap<FieldEntry> fields = new IntObjectHashMap<>();
private final IndexInput vectorData;
private final IndexInput vectorIndex;
private final DefaultFlatVectorScorer defaultFlatVectorScorer = new DefaultFlatVectorScorer();
private final FieldInfos fieldInfos;

Lucene94HnswVectorsReader(SegmentReadState state) throws IOException {
int versionMeta = readMetadata(state);
this.fieldInfos = state.fieldInfos;
boolean success = false;
try {
vectorData =
Expand Down Expand Up @@ -153,7 +154,7 @@ private void readFields(ChecksumIndexInput meta, FieldInfos infos) throws IOExce
}
FieldEntry fieldEntry = readField(meta, info);
validateFieldEntry(info, fieldEntry);
fields.put(info.name, fieldEntry);
fields.put(info.number, fieldEntry);
}
}

Expand Down Expand Up @@ -230,48 +231,41 @@ public void checkIntegrity() throws IOException {
CodecUtil.checksumEntireFile(vectorIndex);
}

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
private FieldEntry getFieldEntry(String field, VectorEncoding expectedEncoding) {
final FieldInfo info = fieldInfos.fieldInfo(field);
final FieldEntry fieldEntry;
if (info == null || (fieldEntry = fields.get(info.number)) == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
if (fieldEntry.vectorEncoding != expectedEncoding) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.FLOAT32);
+ expectedEncoding);
}
return fieldEntry;
}

@Override
public FloatVectorValues getFloatVectorValues(String field) throws IOException {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
return OffHeapFloatVectorValues.load(fieldEntry, vectorData);
}

@Override
public ByteVectorValues getByteVectorValues(String field) throws IOException {
FieldEntry fieldEntry = fields.get(field);
if (fieldEntry == null) {
throw new IllegalArgumentException("field=\"" + field + "\" not found");
}
if (fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
throw new IllegalArgumentException(
"field=\""
+ field
+ "\" is encoded as: "
+ fieldEntry.vectorEncoding
+ " expected: "
+ VectorEncoding.BYTE);
}
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
return OffHeapByteVectorValues.load(fieldEntry, vectorData);
}

@Override
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.FLOAT32) {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.FLOAT32);
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
return;
}

Expand All @@ -289,9 +283,8 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits
@Override
public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs)
throws IOException {
FieldEntry fieldEntry = fields.get(field);

if (fieldEntry.size() == 0 || fieldEntry.vectorEncoding != VectorEncoding.BYTE) {
final FieldEntry fieldEntry = getFieldEntry(field, VectorEncoding.BYTE);
if (fieldEntry.size() == 0 || knnCollector.k() == 0) {
return;
}

Expand Down
Loading

0 comments on commit 584387a

Please sign in to comment.