From 60e3fabb2f6117add792ffdb8938248a7c3028c4 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Tue, 6 Aug 2024 11:26:46 +0100 Subject: [PATCH 01/41] implement seeded knn queries --- .../.gradle-wrapper8297076614091580458.tmp | 0 hsperfdata_root/77 | Bin 0 -> 32768 bytes .../lucene90/Lucene90HnswVectorsReader.java | 14 +- .../lucene91/Lucene91HnswVectorsReader.java | 17 ++- .../lucene92/Lucene92HnswVectorsReader.java | 15 +- .../lucene94/Lucene94HnswVectorsReader.java | 21 ++- .../lucene95/Lucene95HnswVectorsReader.java | 21 ++- .../SimpleTextKnnVectorsReader.java | 14 +- .../bitvectors/TestHnswBitVectorsFormat.java | 2 +- .../lucene/codecs/KnnVectorsFormat.java | 13 +- .../lucene/codecs/KnnVectorsReader.java | 19 ++- .../lucene/codecs/hnsw/FlatVectorsReader.java | 15 +- .../lucene99/Lucene99HnswVectorsReader.java | 15 +- .../perfield/PerFieldKnnVectorsFormat.java | 19 ++- .../org/apache/lucene/index/CheckIndex.java | 6 +- .../org/apache/lucene/index/CodecReader.java | 19 ++- .../lucene/index/DocValuesLeafReader.java | 15 +- .../lucene/index/ExitableDirectoryReader.java | 16 +- .../apache/lucene/index/FilterLeafReader.java | 19 ++- .../org/apache/lucene/index/LeafReader.java | 39 ++++- .../lucene/index/ParallelLeafReader.java | 17 ++- .../lucene/index/SlowCodecReaderWrapper.java | 19 ++- .../SlowCompositeCodecReaderWrapper.java | 14 +- .../lucene/index/SortingCodecReader.java | 15 +- .../lucene/search/AbstractKnnVectorQuery.java | 134 ++++++++++++++++- .../search/ByteVectorSimilarityQuery.java | 2 +- .../search/FloatVectorSimilarityQuery.java | 2 +- .../lucene/search/KnnByteVectorQuery.java | 24 ++- .../lucene/search/KnnFloatVectorQuery.java | 24 ++- .../lucene/util/hnsw/HnswGraphSearcher.java | 39 +++++ ...estLucene99HnswQuantizedVectorsFormat.java | 2 +- ...stLucene99ScalarQuantizedVectorScorer.java | 2 +- ...tLucene99ScalarQuantizedVectorsFormat.java | 2 +- .../TestPerFieldKnnVectorsFormat.java | 7 +- .../lucene/document/TestManyKnnDocs.java | 138 
++++++++++++++++-- .../index/TestExitableDirectoryReader.java | 4 + .../org/apache/lucene/index/TestKnnGraph.java | 2 +- .../index/TestSegmentToThreadMapping.java | 13 +- .../search/BaseKnnVectorQueryTestCase.java | 90 ++++++++++++ .../lucene/search/TestKnnByteVectorQuery.java | 18 ++- .../search/TestKnnFloatVectorQuery.java | 18 ++- .../highlight/TermVectorLeafReader.java | 13 +- ...iversifyingChildrenByteKnnVectorQuery.java | 3 +- ...versifyingChildrenFloatKnnVectorQuery.java | 3 +- .../lucene/index/memory/MemoryIndex.java | 13 +- .../asserting/AssertingKnnVectorsFormat.java | 19 ++- .../index/BaseKnnVectorsFormatTestCase.java | 28 +++- .../tests/index/MergeReaderWrapper.java | 19 ++- .../tests/index/MismatchedLeafReader.java | 19 ++- .../lucene/tests/search/QueryUtils.java | 12 +- 50 files changed, 886 insertions(+), 128 deletions(-) create mode 100644 gradle/wrapper/.gradle-wrapper8297076614091580458.tmp create mode 100644 hsperfdata_root/77 diff --git a/gradle/wrapper/.gradle-wrapper8297076614091580458.tmp b/gradle/wrapper/.gradle-wrapper8297076614091580458.tmp new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/hsperfdata_root/77 b/hsperfdata_root/77 new file mode 100644 index 0000000000000000000000000000000000000000..784b5f536d259468dc60b40f01c4e4dd04441db9 GIT binary patch literal 32768 zcmeHOYm6nwRqi{p4jZxpn*=Zpfwr9x$Lr~tJGF#@X#scBUI8umEERdX+ga848KM=?tV+r9O2wRR2L_&%skU%mIEBt^TAjlu! 
zJEyB3)wl0sypbiOYHdHJx4v^ueRb-b>Z-c0|IaJ0wC0S~$&VR)y!HHj-%AfB&cDNX z0q6U0(y?Wnm|Ma<8H84xT2CiOq5Jfea4&FDKME5r^De&5fLj*|PQnZu4^@4nU(vS~ zg{cTV;cZ9m!AhEnUO$EEcS`|N?Y%rm{XjhF9Ed?5xR%6Gy{dX2s_Wfx;sca_U{*bk zN2$N(yE!$v?3#V^BYa+s!i~e!2~(-4QD3Wyb7SG8;zn;(bp7zU-xCD+;;T{|GJm<{3>_=ohnZetC8-mq!pm#-6Y@df zztx&2{uNx{52<1wlI4|pp*-P7%I^s0K$f3bkLQS96dHFm=`rT3{+~E|q96HT+Ce|a zskrcRW}#9a^qY!56*%fKA3TPhTn9l;eg2o1mQqh+uJB7LpI0+OWIC3b_^IKl)y%P9 zN*`eXlk{b{ym~oO>cy?#)tKXcO@2dm_C&fLd0F}Q|NL_dPZ#0{@vxU);*VU#(qKT)?I(-Pic<%*`g^hD5DjQHHwxCD(GY9|Tc!Fz9!V5>{C> z4bYQunO`%A`aT2y`3vkmH9J+mOQfa(1U>t&zsy4;l}{CifUPP6?zeyIS~TFH$LBsmS|{`4i9J6QE7IyL5knr&hwq;bSTskbXa5%YN2TCze1CEi^av&BxRS41TH97Os0Yy=W#iqX`cRIlH4KR<{E7TzITg@4eBS## zq`y{=DSWmO3>4>tgs+j9nf)u{vu)*k8qfYP<7;+I;j_&-|IYbbmVCDBJWSEQ{9WR6 zxq;7@-Ugq`C7+BQY1gUsndYC9wrZ*mV&OVn6I& z$6Qph&#nLcQg01ksSW^JEs3M;60J}m__Ie;-No`;toEWSoceORrQ}}z;}6$ zY1{Al>BgZekmHL^4+(iYkudU?%J@jyPlJL&DN1Gecsi~hN4=;5{CBSMblefXB)XMU z(dXu|p}A})#jmBE&aj(tB)pYj7sdGqyleu)#$Pkz*TpbH8^(*9G-Fl933{&et5N+s zjt7?6iuT9;h#*Crq311QzM-99|5ZPX;vE`77ayDR1i)eAEbX6!zk+%0c<2O~h(J2p zbUY$-T{(2-`qg%uZ}1bwZ@Ozyk^-0e_YAl}>Yz7E+%B+msJJG-Ot?c7&Me0z({b#V z#Z2x!p<XgIA zU#0#gd>LoBCY>PaBcc7v|7OPwc;^;H3mybutwG{*BWA!;oEwXu9zT%&cdDsv*1QJX zMm3ernn!`FWJApOE8|wG_t<>Ib|TbqCdwH4;v4t#gNM^_?|G+Ys$B@klLYHGKQGFhwr8bRbLZsFBTP?VZ5ZwPmzqV z{=3Rysve_QuMJ$uQ>TTi+ z{na>$YV%>^I=TJ{zN-W>Vduc__tn7zu+#B-i0_vXW~18)eWH;l;POLl_i}##=AYBR z*|jiZ9l7+qgtdFZaL~ei31?|{wq7suH{plrj62|%OUNvvbwQT&K^k+Bf^)X+7 ziD;J=lqXlcCa0PY8|5ezK&N4RmaoT^(?E$$g6J}a&qjN4YMy=bGX$^UE9+an6@pM& zzaO{pYpjGvxAz4tkFCq~Xwo;P-VL};Drhqrj=%Y{ zrnodt@0EO+$9U?!3^#UgBfa+)eR0wYY|t zSdUo682g}u$x~t8j(vg;a@0x3OdPOAH=!Zx~p{wcaHN|4nQJFk1YOc z*ux4pCOGpQ_}MDu7B-@rH~csqIDvVU`<^g6)pA9BXwxOhr&XsW`ErAF`#c6e7p!)X zCn-)6Zg~u@XJHwD_g3f7)8wMXJT`3>ous#y)dWd1+**IkJa=|s)lbb1R*FPy-nd%G zxkt(&;Z63F9ibP+#)cdC{Y026Uer(hK@a8W$lK>6p2;4H3pF|3f#r$aLDz^yATYy2 zutGomh=j&X^!-o$uV^0nV;A$)R^ac(PJCqbopj$=B-x92MKb8f$M5uUM*3~sYx?O| z1taZTvI_u?n>fi|Xi!0NehT;5sBWc2() 
z;h8<8l1EG-zcd!WJeGAHF=?MEvH;8N<4X1tZzT6D?hW>;>F|7z9u5!B_!pe;Y=8H{xg$|2NT%g~2F@TvmW)XN`f&`KazCWavt(a` zACe#JN7ydN9H1N>G@he+Q9LBv2XSW}5nP77%YoC)hC(;tU{vm-xKp_Ym2xXFiY?r{ zdGWfBuR9MZx{s+yF%-e2wn^_#Kaq%h_`kWvuU6a?GnSZg$3Db_Rw%HcmP#*RF?m58 z!jIDA(k*lz`u;|CjTiA?QMX8=2)nAuO4oAw{-u@e?S&^Ee|+WY#`qZ{J(WNIk z3oD(qD_8Izm<@;_1fCop;eGyw_!%JHToPgaYKTxqKVJJ3`VsZ3b2upv|!HhBcS&OgjD5gH(pcR!N+&q(>Fai=~6`S}?8 zhD@~Ma;5oDTW0Um?T?syQEcLe34_B0r{|qJYi@mVLn-ixxjXQK)Wqz_%Ajj5pEZ}4 zmd}_=OBeZ{xqDq6{0zWXn} zMBn^*{-bx;?>ukgPW9tJ%)era5ss-v*O7g@yluMtYBh{+{l~`Lgw5k8J(C==rY?sj zq}aT{&HQUs=vTh6hHjD@o~bBvwsi=tcrh36T|Q zl0)`VJJUh-Q(PzepYi*2V$C4?ru_oqHyz{ZTPbf*$IVs!`9H-Hnc?+SGoTsJ3}^;4 z1DXNNfM!55pc&8%Xa+O`ngPv#WGy|Fe&46Y= zGoTsJ3}^;41DXNNfM!55pc&8%Xa+O`ngPv#W VGy|Fe&46Y=GoTsJ47^he{2!^$TUh`A literal 0 HcmV?d00001 diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index ab2486f4518b..ea49f49e7c44 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -233,7 +233,12 @@ public ByteVectorValues getByteVectorValues(String field) { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); @@ -265,7 +270,12 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator 
seedDocs) throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index 048280466d43..3ccdb0e78361 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -227,7 +227,12 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); @@ -243,11 +248,17 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), - getAcceptOrds(acceptDocs, fieldEntry)); + getAcceptOrds(acceptDocs, fieldEntry), + seedDocs); } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocsBits, + DocIdSetIterator seedDocs) throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java index 833efdf80259..8a690d291887 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java +++ 
b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java @@ -34,6 +34,7 @@ import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; @@ -224,7 +225,12 @@ public ByteVectorValues getByteVectorValues(String field) { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); @@ -244,7 +250,12 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java index a948ab7bee3f..9517edbb1b20 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java @@ -35,6 +35,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; import 
org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; @@ -261,7 +262,12 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); @@ -277,11 +283,17 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs)); + vectorValues.getAcceptOrds(acceptDocs), + seedDocs); } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); @@ -297,7 +309,8 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs)); + vectorValues.getAcceptOrds(acceptDocs), + seedDocs); } private HnswGraph getGraph(FieldEntry entry) throws IOException { diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java index 1b74ff94c18c..da4ef3182bea 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java +++ 
b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java @@ -39,6 +39,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; @@ -285,7 +286,12 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); @@ -312,11 +318,17 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs)); + vectorValues.getAcceptOrds(acceptDocs), + seedDocs); } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); @@ -343,7 +355,8 @@ public void search(String field, byte[] target, KnnCollector knnCollector, Bits scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs)); + vectorValues.getAcceptOrds(acceptDocs), + seedDocs); } /** Get knn graph values; used for testing */ diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java 
b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index faba629715b7..74f3a15b4db6 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -180,7 +180,12 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { FloatVectorValues values = getFloatVectorValues(field); if (target.length != values.dimension()) { @@ -210,7 +215,12 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { ByteVectorValues values = getByteVectorValues(field); if (target.length != values.dimension()) { diff --git a/lucene/codecs/src/test/org/apache/lucene/codecs/bitvectors/TestHnswBitVectorsFormat.java b/lucene/codecs/src/test/org/apache/lucene/codecs/bitvectors/TestHnswBitVectorsFormat.java index ab20ee67c8c9..ec2bdecb1eb9 100644 --- a/lucene/codecs/src/test/org/apache/lucene/codecs/bitvectors/TestHnswBitVectorsFormat.java +++ b/lucene/codecs/src/test/org/apache/lucene/codecs/bitvectors/TestHnswBitVectorsFormat.java @@ -89,7 +89,7 @@ public void testIndexAndSearchBitVectors() throws IOException { try (IndexReader reader = DirectoryReader.open(w)) { LeafReader r = getOnlyLeafReader(reader); TopKnnCollector collector = new TopKnnCollector(3, Integer.MAX_VALUE); - r.searchNearestVectors("v1", vectors[0], collector, null); + 
r.searchNearestVectors("v1", vectors[0], collector, null, null); TopDocs topDocs = collector.topDocs(); assertEquals(3, topDocs.scoreDocs.length); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java index ad6e4aba607c..51dd9a83f424 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java @@ -23,6 +23,7 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; import org.apache.lucene.util.NamedSPILoader; @@ -138,13 +139,21 @@ public ByteVectorValues getByteVectorValues(String field) { @Override public void search( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) { + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) { throw new UnsupportedOperationException(); } @Override public void search( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) { + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java index e054ebeb2bb1..375f547e5182 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java @@ -22,6 +22,7 @@ import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.search.DocIdSetIterator; import 
org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; @@ -82,9 +83,16 @@ protected KnnVectorsReader() {} * @param knnCollector a KnnResults collector and relevant settings for gathering vector results * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} * if they are all allowed to match. + * @param seedDocs {@link DocIdSetIterator} over the documents used to seed the search, or {@code + * null} to perform a search without seeds. */ public abstract void search( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException; + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) + throws IOException; /** * Return the k nearest neighbor documents as determined by comparison of their vector values for @@ -110,9 +118,16 @@ public abstract void search( * @param knnCollector a KnnResults collector and relevant settings for gathering vector results * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} * if they are all allowed to match. + * @param seedDocs {@link DocIdSetIterator} over the documents used to seed the search, or {@code + * null} to perform a search without seeds. */ public abstract void search( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException; + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) + throws IOException; /** * Returns an instance optimized for merging.
This instance may only be consumed in the thread diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsReader.java index 9d776567883e..09243897efd5 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsReader.java @@ -19,6 +19,7 @@ import java.io.IOException; import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Accountable; import org.apache.lucene.util.Bits; @@ -56,13 +57,23 @@ public FlatVectorsScorer getFlatVectorScorer() { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { // don't scan stored field data. If we didn't index it, produce no search results } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDoc, + DocIdSetIterator seedDocs) throws IOException { // don't scan stored field data. 
If we didn't index it, produce no search results } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java index 35bc38571a6a..7795517776a0 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java @@ -37,6 +37,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; @@ -247,7 +248,12 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seeds) throws IOException { search( fields.get(field), @@ -258,7 +264,12 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seeds) throws IOException { search( fields.get(field), diff --git a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java index e665528652ca..2d938128a4e8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java 
@@ -33,6 +33,7 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; import org.apache.lucene.util.IOUtils; @@ -270,15 +271,25 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { - fields.get(field).search(field, target, knnCollector, acceptDocs); + fields.get(field).search(field, target, knnCollector, acceptDocs, seedDocs); } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { - fields.get(field).search(field, target, knnCollector, acceptDocs); + fields.get(field).search(field, target, knnCollector, acceptDocs, seedDocs); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java index 65ac2fcd2607..159e65938ce5 100644 --- a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java +++ b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java @@ -2769,7 +2769,7 @@ private static void checkFloatVectorValues( if (vectorsReaderSupportsSearch(codecReader, fieldInfo.name)) { codecReader .getVectorReader() - .search(fieldInfo.name, values.vectorValue(), collector, null); + .search(fieldInfo.name, values.vectorValue(), collector, null, null); TopDocs docs = collector.topDocs(); if (docs.scoreDocs.length == 0) { throw new CheckIndexException( @@ -2815,7 +2815,9 @@ private 
static void checkByteVectorValues( // search the first maxNumSearches vectors to exercise the graph if (supportsSearch && values.docID() % everyNdoc == 0) { KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE); - codecReader.getVectorReader().search(fieldInfo.name, values.vectorValue(), collector, null); + codecReader + .getVectorReader() + .search(fieldInfo.name, values.vectorValue(), collector, null, null); TopDocs docs = collector.topDocs(); if (docs.scoreDocs.length == 0) { throw new CheckIndexException( diff --git a/lucene/core/src/java/org/apache/lucene/index/CodecReader.java b/lucene/core/src/java/org/apache/lucene/index/CodecReader.java index bec27c5176e8..7bd13e20551e 100644 --- a/lucene/core/src/java/org/apache/lucene/index/CodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/CodecReader.java @@ -25,6 +25,7 @@ import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.TermVectorsReader; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; @@ -260,7 +261,12 @@ public final ByteVectorValues getByteVectorValues(String field) throws IOExcepti @Override public final void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) + throws IOException { ensureOpen(); FieldInfo fi = getFieldInfos().fieldInfo(field); if (fi == null @@ -269,12 +275,17 @@ public final void searchNearestVectors( // Field does not exist or does not index vectors return; } - getVectorReader().search(field, target, knnCollector, acceptDocs); + getVectorReader().search(field, target, knnCollector, acceptDocs, seedDocs); } @Override public final void searchNearestVectors( - String field, byte[] target, KnnCollector knnCollector, Bits 
acceptDocs) throws IOException { + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) + throws IOException { ensureOpen(); FieldInfo fi = getFieldInfos().fieldInfo(field); if (fi == null @@ -283,7 +294,7 @@ public final void searchNearestVectors( // Field does not exist or does not index vectors return; } - getVectorReader().search(field, target, knnCollector, acceptDocs); + getVectorReader().search(field, target, knnCollector, acceptDocs, seedDocs); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java index 3504c7429a5e..717c0e64d0ba 100644 --- a/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java @@ -18,6 +18,7 @@ package org.apache.lucene.index; import java.io.IOException; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; @@ -59,13 +60,23 @@ public final ByteVectorValues getByteVectorValues(String field) throws IOExcepti @Override public void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) + throws IOException { throw new UnsupportedOperationException(); } @Override public void searchNearestVectors( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) + throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java index 
ca2cb1a27d45..53024ba41658 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java @@ -333,7 +333,11 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { @Override public void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { // when acceptDocs is null due to no doc deleted, we will instantiate a new one that would @@ -361,12 +365,16 @@ public int length() { } }; - in.searchNearestVectors(field, target, knnCollector, timeoutCheckingAcceptDocs); + in.searchNearestVectors(field, target, knnCollector, timeoutCheckingAcceptDocs, seedDocs); } @Override public void searchNearestVectors( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { // when acceptDocs is null due to no doc deleted, we will instantiate a new one that would // match all docs to allow timeout checking. 
@@ -393,7 +401,7 @@ public int length() { } }; - in.searchNearestVectors(field, target, knnCollector, timeoutCheckingAcceptDocs); + in.searchNearestVectors(field, target, knnCollector, timeoutCheckingAcceptDocs, seedDocs); } private void checkAndThrowForSearchVectors() { diff --git a/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java index 87d62f22d041..c6996cb96259 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.util.Iterator; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.AttributeSource; import org.apache.lucene.util.Bits; @@ -365,14 +366,24 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { @Override public void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - in.searchNearestVectors(field, target, knnCollector, acceptDocs); + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) + throws IOException { + in.searchNearestVectors(field, target, knnCollector, acceptDocs, seedDocs); } @Override public void searchNearestVectors( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - in.searchNearestVectors(field, target, knnCollector, acceptDocs); + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) + throws IOException { + in.searchNearestVectors(field, target, knnCollector, acceptDocs, seedDocs); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/LeafReader.java b/lucene/core/src/java/org/apache/lucene/index/LeafReader.java index 0f39d1ae1e8d..0988c7d289af 100644 --- 
a/lucene/core/src/java/org/apache/lucene/index/LeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/LeafReader.java @@ -17,6 +17,7 @@ package org.apache.lucene.index; import java.io.IOException; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; @@ -251,7 +252,13 @@ public final PostingsEnum postings(Term term) throws IOException { * @lucene.experimental */ public final TopDocs searchNearestVectors( - String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { + String field, + float[] target, + int k, + Bits acceptDocs, + DocIdSetIterator seedDocs, + int visitedLimit) + throws IOException { FieldInfo fi = getFieldInfos().fieldInfo(field); if (fi == null || fi.getVectorDimension() == 0) { return TopDocsCollector.EMPTY_TOPDOCS; @@ -265,7 +272,7 @@ public final TopDocs searchNearestVectors( return TopDocsCollector.EMPTY_TOPDOCS; } KnnCollector collector = new TopKnnCollector(k, visitedLimit); - searchNearestVectors(field, target, collector, acceptDocs); + searchNearestVectors(field, target, collector, acceptDocs, seedDocs); return collector.topDocs(); } @@ -295,7 +302,13 @@ public final TopDocs searchNearestVectors( * @lucene.experimental */ public final TopDocs searchNearestVectors( - String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { + String field, + byte[] target, + int k, + Bits acceptDocs, + DocIdSetIterator seedDocs, + int visitedLimit) + throws IOException { FieldInfo fi = getFieldInfos().fieldInfo(field); if (fi == null || fi.getVectorDimension() == 0) { return TopDocsCollector.EMPTY_TOPDOCS; @@ -309,7 +322,7 @@ public final TopDocs searchNearestVectors( return TopDocsCollector.EMPTY_TOPDOCS; } KnnCollector collector = new TopKnnCollector(k, visitedLimit); - searchNearestVectors(field, target, collector, acceptDocs); + 
searchNearestVectors(field, target, collector, acceptDocs, seedDocs); return collector.topDocs(); } @@ -337,10 +350,17 @@ public final TopDocs searchNearestVectors( * @param knnCollector collector with settings for gathering the vector results. * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} * if they are all allowed to match. + * @param seedDocs {@link DocIdSetIterator} over an initial set of documents to seed the search, + * or {@code null} if a full search is to be conducted. * @lucene.experimental */ public abstract void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException; + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) + throws IOException; /** * Return the k nearest neighbor documents as determined by comparison of their vector values for @@ -366,10 +386,17 @@ public abstract void searchNearestVectors( * @param knnCollector collector with settings for gathering the vector results. * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} * if they are all allowed to match. + * @param seedDocs {@link DocIdSetIterator} over an initial set of documents to seed the search, + * or {@code null} if a full search is to be conducted. * @lucene.experimental */ public abstract void searchNearestVectors( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException; + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) + throws IOException; /** * Get the {@link FieldInfos} describing all fields in this reader. 
diff --git a/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java index 1f1e2dba9c12..09b2d665381a 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java @@ -26,6 +26,7 @@ import java.util.Set; import java.util.SortedMap; import java.util.TreeMap; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.Sort; import org.apache.lucene.util.Bits; @@ -437,23 +438,31 @@ public ByteVectorValues getByteVectorValues(String fieldName) throws IOException @Override public void searchNearestVectors( - String fieldName, float[] target, KnnCollector knnCollector, Bits acceptDocs) + String fieldName, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { ensureOpen(); LeafReader reader = fieldToReader.get(fieldName); if (reader != null) { - reader.searchNearestVectors(fieldName, target, knnCollector, acceptDocs); + reader.searchNearestVectors(fieldName, target, knnCollector, acceptDocs, seedDocs); } } @Override public void searchNearestVectors( - String fieldName, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + String fieldName, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { ensureOpen(); LeafReader reader = fieldToReader.get(fieldName); if (reader != null) { - reader.searchNearestVectors(fieldName, target, knnCollector, acceptDocs); + reader.searchNearestVectors(fieldName, target, knnCollector, acceptDocs, seedDocs); } } diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java index 4d05d241e699..79ff6acb0d31 100644 --- 
a/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java @@ -28,6 +28,7 @@ import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.TermVectorsReader; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; @@ -173,15 +174,25 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { - reader.searchNearestVectors(field, target, knnCollector, acceptDocs); + reader.searchNearestVectors(field, target, knnCollector, acceptDocs, seedDocs); } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { - reader.searchNearestVectors(field, target, knnCollector, acceptDocs); + reader.searchNearestVectors(field, target, knnCollector, acceptDocs, seedDocs); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java index fc6c1d9b2941..e72cde5c74b7 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java @@ -963,13 +963,23 @@ public VectorScorer scorer(byte[] target) { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String 
field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seeds) throws IOException { throw new UnsupportedOperationException(); } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seeds) throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index ff88e30de4a5..1868a6fc6f9c 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -32,6 +32,7 @@ import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.TermVectorsReader; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; @@ -523,12 +524,22 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) { + public void search( + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) { throw new UnsupportedOperationException(); } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) { + public void search( + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index 
c0ce4eea3c6b..db939db0d324 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -21,6 +21,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.Objects; @@ -48,6 +49,10 @@ *
  • Otherwise run a kNN search subject to the filter *
  • If the kNN search visits too many vectors without completing, stop and run an exact search * + * + *

    When a seed query is provided, this query is executed first to seed the kNN search (subject to + * the same rules about the filter). If the seed query fails to identify any documents, it falls + * back on the strategy above. */ abstract class AbstractKnnVectorQuery extends Query { @@ -56,14 +61,20 @@ abstract class AbstractKnnVectorQuery extends Query { protected final String field; protected final int k; private final Query filter; + private final Query seed; public AbstractKnnVectorQuery(String field, int k, Query filter) { + this(field, k, filter, null); + } + + public AbstractKnnVectorQuery(String field, int k, Query filter, Query seed) { this.field = Objects.requireNonNull(field, "field"); this.k = k; if (k < 1) { throw new IllegalArgumentException("k must be at least 1, got: " + k); } this.filter = filter; + this.seed = seed; } @Override @@ -83,6 +94,14 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { filterWeight = null; } + final Weight seedWeight; + if (seed != null) { + Query seedRewritten = indexSearcher.rewrite(seed); + seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f); + } else { + seedWeight = null; + } + TimeLimitingKnnCollectorManager knnCollectorManager = new TimeLimitingKnnCollectorManager( getKnnCollectorManager(k, indexSearcher), indexSearcher.getTimeout()); @@ -90,7 +109,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { List leafReaderContexts = reader.leaves(); List> tasks = new ArrayList<>(leafReaderContexts.size()); for (LeafReaderContext context : leafReaderContexts) { - tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager)); + tasks.add(() -> searchLeaf(context, filterWeight, seedWeight, knnCollectorManager)); } TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new); @@ -105,9 +124,11 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { private TopDocs searchLeaf( LeafReaderContext ctx, 
Weight filterWeight, + Weight seedWeight, TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) throws IOException { - TopDocs results = getLeafResults(ctx, filterWeight, timeLimitingKnnCollectorManager); + TopDocs results = + getLeafResults(ctx, filterWeight, seedWeight, timeLimitingKnnCollectorManager); if (ctx.docBase > 0) { for (ScoreDoc scoreDoc : results.scoreDocs) { scoreDoc.doc += ctx.docBase; @@ -119,13 +140,19 @@ private TopDocs searchLeaf( private TopDocs getLeafResults( LeafReaderContext ctx, Weight filterWeight, + Weight seedWeight, TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) throws IOException { final LeafReader reader = ctx.reader(); final Bits liveDocs = reader.getLiveDocs(); if (filterWeight == null) { - return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager); + return approximateSearch( + ctx, + liveDocs, + executeSeedQuery(ctx, seedWeight), + Integer.MAX_VALUE, + timeLimitingKnnCollectorManager); } Scorer scorer = filterWeight.scorer(ctx); @@ -145,7 +172,13 @@ private TopDocs getLeafResults( // Perform the approximate kNN search // We pass cost + 1 here to account for the edge case when we explore exactly cost vectors - TopDocs results = approximateSearch(ctx, acceptDocs, cost + 1, timeLimitingKnnCollectorManager); + TopDocs results = + approximateSearch( + ctx, + acceptDocs, + executeSeedQuery(ctx, seedWeight), + cost + 1, + timeLimitingKnnCollectorManager); if (results.totalHits.relation == TotalHits.Relation.EQUAL_TO // Return partial results only when timeout is met || (queryTimeout != null && queryTimeout.shouldExit())) { @@ -156,6 +189,44 @@ private TopDocs getLeafResults( } } + private DocIdSetIterator executeSeedQuery(LeafReaderContext ctx, Weight seedWeight) + throws IOException { + if (seedWeight != null) { + // Execute the seed query + TopScoreDocCollector seedCollector = + new TopScoreDocCollectorManager(k, Integer.MAX_VALUE).newCollector(); + LeafCollector 
leafCollector; + try { + leafCollector = seedCollector.getLeafCollector(ctx); + } catch ( + @SuppressWarnings("unused") + CollectionTerminatedException e) { + // there is no doc of interest in this reader context + // continue with the following leaf + leafCollector = null; + } + if (leafCollector != null) { + BulkScorer scorer = seedWeight.bulkScorer(ctx); + if (scorer != null) { + try { + scorer.score(leafCollector, ctx.reader().getLiveDocs()); + } catch ( + @SuppressWarnings("unused") + CollectionTerminatedException e) { + // collection was terminated prematurely + // continue with the following leaf + } + } + leafCollector.finish(); + } + + TopDocs seedTopDocs = seedCollector.topDocs(); + return new TopDocsDISI(seedTopDocs); + } else { + return null; + } + } + private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc) throws IOException { if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) { @@ -181,6 +252,7 @@ protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher search protected abstract TopDocs approximateSearch( LeafReaderContext context, Bits acceptDocs, + DocIdSetIterator seedDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException; @@ -302,12 +374,15 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; AbstractKnnVectorQuery that = (AbstractKnnVectorQuery) o; - return k == that.k && Objects.equals(field, that.field) && Objects.equals(filter, that.filter); + return k == that.k + && Objects.equals(field, that.field) + && Objects.equals(filter, that.filter) + && Objects.equals(seed, that.seed); } @Override public int hashCode() { - return Objects.hash(field, k, filter); + return Objects.hash(field, k, filter, seed); } /** @@ -332,6 +407,13 @@ public Query getFilter() { return filter; } + /** + * @return the query that seeds the kNN search. 
+ */ + public Query getSeed() { + return seed; + } + /** Caches the results of a KnnVector search: a list of docs and their scores */ static class DocAndScoreQuery extends Query { @@ -491,4 +573,44 @@ public int hashCode() { classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores)); } } + + private class TopDocsDISI extends DocIdSetIterator { + int idx = -1; + List sortedDocIdList; + + public TopDocsDISI(TopDocs topDocs) { + sortedDocIdList = new ArrayList(); + for (int i = 0; i < topDocs.scoreDocs.length; i++) { + sortedDocIdList.add(topDocs.scoreDocs[i].doc); + } + Collections.sort(sortedDocIdList); + } + + @Override + public int advance(int target) throws IOException { + return slowAdvance(target); + } + + @Override + public long cost() { + return sortedDocIdList.size(); + } + + @Override + public int docID() { + if (idx == -1) { + return -1; + } else if (idx >= sortedDocIdList.size()) { + return DocIdSetIterator.NO_MORE_DOCS; + } else { + return sortedDocIdList.get(idx); + } + } + + @Override + public int nextDoc() throws IOException { + idx += 1; + return docID(); + } + } } diff --git a/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java index bd2190121abc..80786131309c 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java @@ -110,7 +110,7 @@ protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, throws IOException { KnnCollector collector = new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitLimit); - context.reader().searchNearestVectors(field, target, collector, acceptDocs); + context.reader().searchNearestVectors(field, target, collector, acceptDocs, null); return collector.topDocs(); } diff --git 
a/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java index 3dc92482a77d..726a34aec1a2 100644 --- a/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java @@ -112,7 +112,7 @@ protected TopDocs approximateSearch(LeafReaderContext context, Bits acceptDocs, throws IOException { KnnCollector collector = new VectorSimilarityCollector(traversalSimilarity, resultSimilarity, visitLimit); - context.reader().searchNearestVectors(field, target, collector, acceptDocs); + context.reader().searchNearestVectors(field, target, collector, acceptDocs, null); return collector.topDocs(); } diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index db5ae4a0d9d2..8e993cd93271 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -30,8 +30,8 @@ import org.apache.lucene.util.Bits; /** - * Uses {@link KnnVectorsReader#search(String, byte[], KnnCollector, Bits)} to perform nearest - * neighbour search. + * Uses {@link KnnVectorsReader#search(String, byte[], KnnCollector, Bits, DocIdSetIterator)} to + * perform nearest neighbour search. * *

    This query also allows for performing a kNN search subject to a filter. In this case, it first * executes the filter for each leaf, then chooses a strategy dynamically: @@ -72,7 +72,22 @@ public KnnByteVectorQuery(String field, byte[] target, int k) { * @throws IllegalArgumentException if k is less than 1 */ public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) { - super(field, k, filter); + this(field, target, k, filter, null); + } + + /** + * Find the k nearest documents to the target vector according to the vectors in the + * given field, optionally seeding the search with the results of another query. + * + * @param field a field that has been indexed as a {@link KnnFloatVectorField}. + * @param target the target of the search + * @param k the number of documents to find + * @param filter a filter applied before the vector search + * @param seed a query that is executed to seed the vector search + * @throws IllegalArgumentException if k is less than 1 + */ + public KnnByteVectorQuery(String field, byte[] target, int k, Query filter, Query seed) { + super(field, k, filter, seed); this.target = Objects.requireNonNull(target, "target"); } @@ -80,6 +95,7 @@ public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) { protected TopDocs approximateSearch( LeafReaderContext context, Bits acceptDocs, + DocIdSetIterator seedDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { @@ -93,7 +109,7 @@ protected TopDocs approximateSearch( if (Math.min(knnCollector.k(), byteVectorValues.size()) == 0) { return NO_RESULTS; } - reader.searchNearestVectors(field, target, knnCollector, acceptDocs); + reader.searchNearestVectors(field, target, knnCollector, acceptDocs, seedDocs); TopDocs results = knnCollector.topDocs(); return results != null ? 
results : NO_RESULTS; } diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index 585893fa3c2a..a1f6a680ea0a 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -31,8 +31,8 @@ import org.apache.lucene.util.VectorUtil; /** - * Uses {@link KnnVectorsReader#search(String, float[], KnnCollector, Bits)} to perform nearest - * neighbour search. + * Uses {@link KnnVectorsReader#search(String, float[], KnnCollector, Bits, DocIdSetIterator)} to + * perform nearest neighbour search. * *

    This query also allows for performing a kNN search subject to a filter. In this case, it first * executes the filter for each leaf, then chooses a strategy dynamically: @@ -73,7 +73,22 @@ public KnnFloatVectorQuery(String field, float[] target, int k) { * @throws IllegalArgumentException if k is less than 1 */ public KnnFloatVectorQuery(String field, float[] target, int k, Query filter) { - super(field, k, filter); + this(field, target, k, filter, null); + } + + /** + * Find the k nearest documents to the target vector according to the vectors in the + * given field, optionally seeding the search with the results of another query. + * + * @param field a field that has been indexed as a {@link KnnFloatVectorField}. + * @param target the target of the search + * @param k the number of documents to find + * @param filter a filter applied before the vector search + * @param seed a query that is executed to seed the vector search + * @throws IllegalArgumentException if k is less than 1 + */ + public KnnFloatVectorQuery(String field, float[] target, int k, Query filter, Query seed) { + super(field, k, filter, seed); this.target = VectorUtil.checkFinite(Objects.requireNonNull(target, "target")); } @@ -81,6 +96,7 @@ public KnnFloatVectorQuery(String field, float[] target, int k, Query filter) { protected TopDocs approximateSearch( LeafReaderContext context, Bits acceptDocs, + DocIdSetIterator seedDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { @@ -94,7 +110,7 @@ protected TopDocs approximateSearch( if (Math.min(knnCollector.k(), floatVectorValues.size()) == 0) { return NO_RESULTS; } - reader.searchNearestVectors(field, target, knnCollector, acceptDocs); + reader.searchNearestVectors(field, target, knnCollector, acceptDocs, seedDocs); TopDocs results = knnCollector.topDocs(); return results != null ? 
results : NO_RESULTS; } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 46d6c93d52c3..46df6d0ba4ad 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -20,6 +20,8 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; +import java.util.ArrayList; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopKnnCollector; import org.apache.lucene.util.BitSet; @@ -70,6 +72,43 @@ public static void search( search(scorer, knnCollector, graph, graphSearcher, acceptOrds); } + /** + * Searches the HNSW graph for the nearest neighbors of a query vector, starting from the + * provided entry points. + * + * @param scorer the scorer to compare the query with the nodes + * @param knnCollector a collector of top knn results to be returned + * @param graph the graph values. May represent the entire graph, or a level in a hierarchical + * graph. + * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or + * {@code null} if they are all allowed to match. + * @param entryPointOrds the entry points for search. 
+ */ + public static void search( + RandomVectorScorer scorer, + KnnCollector knnCollector, + HnswGraph graph, + Bits acceptOrds, + DocIdSetIterator entryPointOrds) + throws IOException { + ArrayList entryPointOrdInts = new ArrayList(); + if (entryPointOrds != null) { + int entryPointOrdInt; + while ((entryPointOrdInt = entryPointOrds.nextDoc()) != NO_MORE_DOCS) { + entryPointOrdInts.add(entryPointOrdInt); + } + } + if (entryPointOrdInts.size() == 0) { + search(scorer, knnCollector, graph, acceptOrds); + } else { + HnswGraphSearcher graphSearcher = + new HnswGraphSearcher( + new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(graph.size())); + int[] entryPointOrdIntsArr = entryPointOrdInts.stream().mapToInt(Integer::intValue).toArray(); + graphSearcher.searchLevel(knnCollector, scorer, 0, entryPointOrdIntsArr, graph, acceptOrds); + } + } + /** * Search {@link OnHeapHnswGraph}, this method is thread safe. * diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java index 8e69e833b989..c51f8270940b 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java @@ -113,7 +113,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { try (IndexReader reader = DirectoryReader.open(w)) { LeafReader r = getOnlyLeafReader(reader); TopKnnCollector topKnnCollector = new TopKnnCollector(5, Integer.MAX_VALUE); - r.searchNearestVectors("f", new float[] {0.6f, 0.8f}, topKnnCollector, null); + r.searchNearestVectors("f", new float[] {0.6f, 0.8f}, topKnnCollector, null, null); TopDocs topDocs = topKnnCollector.topDocs(); assertEquals(3, topDocs.totalHits.value); for (ScoreDoc scoreDoc : topDocs.scoreDocs) { diff --git 
a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java index 58e6c27e326a..966b30105c50 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java @@ -299,7 +299,7 @@ private void testSingleVectorPerSegment(VectorSimilarityFunction sim) throws IOE LeafReader leafReader = getOnlyLeafReader(reader); StoredFields storedFields = reader.storedFields(); float[] queryVector = new float[] {0.6f, 0.8f}; - var hits = leafReader.searchNearestVectors("field", queryVector, 3, null, 100); + var hits = leafReader.searchNearestVectors("field", queryVector, 3, null, null, 100); assertEquals(hits.scoreDocs.length, 3); assertEquals("B", storedFields.document(hits.scoreDocs[0].doc).get("id")); assertEquals("A", storedFields.document(hits.scoreDocs[1].doc).get("id")); diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java index b221cb19dde6..66cdda4b2a12 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java @@ -92,7 +92,7 @@ public void testSearch() throws Exception { KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); // if this search found any results it would raise NPE attempting to collect them in our // null collector - knnVectorsReader.search("f", new float[] {1, 0}, null, null); + knnVectorsReader.search("f", new float[] {1, 0}, null, null, null); } else { fail("reader is not CodecReader"); } diff --git 
a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java index 45814144d10a..0f40b22f9cd7 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java @@ -95,11 +95,12 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { new float[] {1, 2, 3}, 10, reader.getLiveDocs(), + null, Integer.MAX_VALUE); assertEquals(0, hits.scoreDocs.length); hits = reader.searchNearestVectors( - "id", new float[] {1, 2, 3}, 10, reader.getLiveDocs(), Integer.MAX_VALUE); + "id", new float[] {1, 2, 3}, 10, reader.getLiveDocs(), null, Integer.MAX_VALUE); assertEquals(0, hits.scoreDocs.length); } } @@ -146,12 +147,12 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { LeafReader reader = ireader.leaves().get(0).reader(); TopDocs hits1 = reader.searchNearestVectors( - "field1", new float[] {1, 2, 3}, 10, reader.getLiveDocs(), Integer.MAX_VALUE); + "field1", new float[] {1, 2, 3}, 10, reader.getLiveDocs(), null, Integer.MAX_VALUE); assertEquals(1, hits1.scoreDocs.length); TopDocs hits2 = reader.searchNearestVectors( - "field2", new float[] {1, 2, 3}, 10, reader.getLiveDocs(), Integer.MAX_VALUE); + "field2", new float[] {1, 2, 3}, 10, reader.getLiveDocs(), null, Integer.MAX_VALUE); assertEquals(1, hits2.scoreDocs.length); } } diff --git a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java index 2023ee73391d..db08a058998e 100644 --- a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java +++ b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java @@ -17,6 +17,7 @@ package org.apache.lucene.document; import com.carrotsearch.randomizedtesting.annotations.TimeoutSuite; +import 
java.nio.file.Path; import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.IndexWriterConfig; @@ -24,19 +25,26 @@ import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.search.MatchNoDocsQuery; +import org.apache.lucene.search.Query; import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; import org.apache.lucene.tests.codecs.vector.ConfigurableMCodec; import org.apache.lucene.tests.util.LuceneTestCase; import org.apache.lucene.tests.util.LuceneTestCase.Monster; +import org.junit.BeforeClass; @TimeoutSuite(millis = 86_400_000) // 24 hour timeout @Monster("takes ~10 minutes and needs extra heap, disk space, file handles") public class TestManyKnnDocs extends LuceneTestCase { // gradlew -p lucene/core test --tests TestManyKnnDocs -Ptests.heapsize=16g -Dtests.monster=true - public void testLargeSegment() throws Exception { + private static Path testDir; + + @BeforeClass + public static void init_index() throws Exception { IndexWriterConfig iwc = new IndexWriterConfig(); iwc.setCodec( new ConfigurableMCodec( @@ -46,27 +54,139 @@ public void testLargeSegment() throws Exception { mp.setMaxMergeAtOnce(256); // avoid intermediate merges (waste of time with HNSW?) 
mp.setSegmentsPerTier(256); // only merge once at the end when we ask iwc.setMergePolicy(mp); - String fieldName = "field"; - VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; + VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.EUCLIDEAN; - try (Directory dir = FSDirectory.open(createTempDir("ManyKnnVectorDocs")); + try (Directory dir = FSDirectory.open(testDir = createTempDir("ManyKnnVectorDocs")); IndexWriter iw = new IndexWriter(dir, iwc)) { int numVectors = 2088992; - float[] vector = new float[1]; - Document doc = new Document(); - doc.add(new KnnFloatVectorField(fieldName, vector, similarityFunction)); for (int i = 0; i < numVectors; i++) { + float[] vector = new float[128]; + Document doc = new Document(); vector[0] = (i % 256); + vector[1] = (i / 256); + doc.add(new KnnFloatVectorField("field", vector, similarityFunction)); + doc.add(new KeywordField("int", "" + i, org.apache.lucene.document.Field.Store.YES)); + doc.add(new StoredField("intValue", i)); iw.addDocument(doc); } // merge to single segment and then verify iw.forceMerge(1); iw.commit(); + } + } + + public void testLargeSegmentKnn() throws Exception { + try (Directory dir = FSDirectory.open(testDir)) { IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir)); - TopDocs docs = searcher.search(new KnnFloatVectorQuery("field", new float[] {120}, 10), 5); - assertEquals(5, docs.scoreDocs.length); + for (int i = 0; i < 256; i++) { + Query filterQuery = new MatchAllDocsQuery(); + float[] vector = new float[128]; + vector[0] = i; + vector[1] = 1; + TopDocs docs = + searcher.search(new KnnFloatVectorQuery("field", vector, 10, filterQuery), 5); + assertEquals(5, docs.scoreDocs.length); + Document d = searcher.storedFields().document(docs.scoreDocs[0].doc); + String s = ""; + for (int j = 0; j < docs.scoreDocs.length - 1; j++) { + s += docs.scoreDocs[j].doc + " " + docs.scoreDocs[j].score + "\n"; + } + assertEquals(s, i + 256, 
d.getField("intValue").numericValue()); + } + } + } + + public void testLargeSegmentSeededExact() throws Exception { + try (Directory dir = FSDirectory.open(testDir)) { + IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir)); + for (int i = 0; i < 256; i++) { + Query seedQuery = KeywordField.newExactQuery("int", "" + (i + 256)); + Query filterQuery = new MatchAllDocsQuery(); + float[] vector = new float[128]; + vector[0] = i; + vector[1] = 1; + TopDocs docs = + searcher.search( + new KnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); + assertEquals(5, docs.scoreDocs.length); + String s = ""; + for (int j = 0; j < docs.scoreDocs.length - 1; j++) { + s += docs.scoreDocs[j].doc + " " + docs.scoreDocs[j].score + "\n"; + } + Document d = searcher.storedFields().document(docs.scoreDocs[0].doc); + assertEquals(s, i + 256, d.getField("intValue").numericValue()); + } + } + } + + public void testLargeSegmentSeededNearby() throws Exception { + try (Directory dir = FSDirectory.open(testDir)) { + IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir)); + for (int i = 0; i < 256; i++) { + Query seedQuery = KeywordField.newExactQuery("int", "" + i); + Query filterQuery = new MatchAllDocsQuery(); + float[] vector = new float[128]; + vector[0] = i; + vector[1] = 1; + TopDocs docs = + searcher.search( + new KnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); + assertEquals(5, docs.scoreDocs.length); + String s = ""; + for (int j = 0; j < docs.scoreDocs.length - 1; j++) { + s += docs.scoreDocs[j].doc + " " + docs.scoreDocs[j].score + "\n"; + } + Document d = searcher.storedFields().document(docs.scoreDocs[0].doc); + assertEquals(s, i + 256, d.getField("intValue").numericValue()); + } + } + } + + public void testLargeSegmentSeededDistant() throws Exception { + try (Directory dir = FSDirectory.open(testDir)) { + IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir)); + for (int i = 0; i < 256; i++) { + 
Query seedQuery = KeywordField.newExactQuery("int", "" + (i + 128)); + Query filterQuery = new MatchAllDocsQuery(); + float[] vector = new float[128]; + vector[0] = i; + vector[1] = 1; + TopDocs docs = + searcher.search( + new KnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); + assertEquals(5, docs.scoreDocs.length); + Document d = searcher.storedFields().document(docs.scoreDocs[0].doc); + String s = ""; + for (int j = 0; j < docs.scoreDocs.length - 1; j++) { + s += docs.scoreDocs[j].doc + " " + docs.scoreDocs[j].score + "\n"; + } + assertEquals(s, i + 256, d.getField("intValue").numericValue()); + } + } + } + + public void testLargeSegmentSeededNone() throws Exception { + try (Directory dir = FSDirectory.open(testDir)) { + IndexSearcher searcher = new IndexSearcher(DirectoryReader.open(dir)); + for (int i = 0; i < 256; i++) { + Query seedQuery = new MatchNoDocsQuery(); + Query filterQuery = new MatchAllDocsQuery(); + float[] vector = new float[128]; + vector[0] = i; + vector[1] = 1; + TopDocs docs = + searcher.search( + new KnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); + assertEquals(5, docs.scoreDocs.length); + Document d = searcher.storedFields().document(docs.scoreDocs[0].doc); + String s = ""; + for (int j = 0; j < docs.scoreDocs.length - 1; j++) { + s += docs.scoreDocs[j].doc + " " + docs.scoreDocs[j].score + "\n"; + } + assertEquals(s, i + 256, d.getField("intValue").numericValue()); + } } } } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java index 3c82cd6b33e4..ddfb2ddcc438 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java @@ -471,6 +471,7 @@ public void testFloatVectorValues() throws IOException { TestVectorUtil.randomVector(dimension), 5, leaf.getLiveDocs(), + null, 
Integer.MAX_VALUE)); } else { DocIdSetIterator iter = leaf.getFloatVectorValues("vector"); @@ -481,6 +482,7 @@ public void testFloatVectorValues() throws IOException { TestVectorUtil.randomVector(dimension), 5, leaf.getLiveDocs(), + null, Integer.MAX_VALUE); } @@ -546,6 +548,7 @@ public void testByteVectorValues() throws IOException { TestVectorUtil.randomVectorBytes(dimension), 5, leaf.getLiveDocs(), + null, Integer.MAX_VALUE)); } else { @@ -557,6 +560,7 @@ public void testByteVectorValues() throws IOException { TestVectorUtil.randomVectorBytes(dimension), 5, leaf.getLiveDocs(), + null, Integer.MAX_VALUE); } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java index 72be0bd929fa..553e72c66b10 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java @@ -359,7 +359,7 @@ private static TopDocs doKnnSearch(IndexReader reader, float[] vector, int k) th Bits liveDocs = ctx.reader().getLiveDocs(); results[ctx.ord] = ctx.reader() - .searchNearestVectors(KNN_GRAPH_FIELD, vector, k, liveDocs, Integer.MAX_VALUE); + .searchNearestVectors(KNN_GRAPH_FIELD, vector, k, liveDocs, null, Integer.MAX_VALUE); if (ctx.docBase > 0) { for (ScoreDoc doc : results[ctx.ord].scoreDocs) { doc.doc += ctx.docBase; diff --git a/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java b/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java index 609dd0359ab5..770eea0d3711 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java @@ -25,6 +25,7 @@ import java.util.concurrent.ThreadPoolExecutor; import java.util.concurrent.TimeUnit; import org.apache.lucene.document.Document; +import org.apache.lucene.search.DocIdSetIterator; import 
org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.MatchAllDocsQuery; @@ -123,11 +124,19 @@ public ByteVectorValues getByteVectorValues(String field) { @Override public void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {} + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) {} @Override public void searchNearestVectors( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) {} + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) {} @Override protected void doClose() {} diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index 7520003ab4f9..b809fb06a757 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -65,9 +65,15 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase { abstract AbstractKnnVectorQuery getKnnVectorQuery( String field, float[] query, int k, Query queryFilter); + abstract AbstractKnnVectorQuery getKnnVectorQuery( + String field, float[] query, int k, Query queryFilter, Query seedQuery); + abstract AbstractKnnVectorQuery getThrowingKnnVectorQuery( String field, float[] query, int k, Query queryFilter); + abstract AbstractKnnVectorQuery getThrowingKnnVectorQuery( + String field, float[] query, int k, Query queryFilter, Query seedQuery); + AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k) { return getKnnVectorQuery(field, query, k, null); } @@ -608,6 +614,90 @@ public void testRandomWithFilter() throws IOException { } } + /** Tests with random vectors and a random seed. Uses RandomIndexWriter. 
*/ + public void testRandomWithSeed() throws IOException { + int numDocs = 1000; + int dimension = atLeast(5); + int numIters = atLeast(10); + try (Directory d = newDirectory()) { + // Always use the default kNN format to have predictable behavior around when it hits + // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN + // format + // implementation. + IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()); + RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc); + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + doc.add(getKnnVectorField("field", randomVector(dimension))); + doc.add(new NumericDocValuesField("tag", i)); + doc.add(new IntPoint("tag", i)); + w.addDocument(doc); + } + w.forceMerge(1); + w.close(); + + try (IndexReader reader = DirectoryReader.open(d)) { + IndexSearcher searcher = newSearcher(reader); + for (int i = 0; i < numIters; i++) { + int k = random().nextInt(80) + 1; + int n = random().nextInt(100) + 1; + + // All documents as seeds + Query seed1 = new MatchAllDocsQuery(); + AbstractKnnVectorQuery query = + getKnnVectorQuery("field", randomVector(dimension), k, null, seed1); + TopDocs results = searcher.search(query, n); + int expected = Math.min(Math.min(n, k), reader.numDocs()); + // we may get fewer results than requested if there are deletions, but this test doesn't + // test that + assert reader.hasDeletions() == false; + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value >= results.scoreDocs.length); + // verify the results are in descending score order + float last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + + // Restrictive seed query -- 6 documents + Query seed2 = IntPoint.newRangeQuery("tag", 1, 6); + query = getKnnVectorQuery("field", randomVector(dimension), k, null, seed2); + results = 
searcher.search(query, n); + expected = Math.min(Math.min(n, k), reader.numDocs()); + // we may get fewer results than requested if there are deletions, but this test doesn't + // test that + assert reader.hasDeletions() == false; + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value >= results.scoreDocs.length); + // verify the results are in descending score order + last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + + // No seed documents -- falls back on full approx search + Query seed3 = new BooleanQuery.Builder().build(); + query = getKnnVectorQuery("field", randomVector(dimension), k, null, seed3); + results = searcher.search(query, n); + expected = Math.min(Math.min(n, k), reader.numDocs()); + // we may get fewer results than requested if there are deletions, but this test doesn't + // test that + assert reader.hasDeletions() == false; + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value >= results.scoreDocs.length); + // verify the results are in descending score order + last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + } + } + } + } + /** Tests filtering when all vectors have the same score. 
*/ @AwaitsFix(bugUrl = "https://github.com/apache/lucene/issues/11787") public void testFilterWithSameScore() throws IOException { diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java index 4dc3d385b087..f9ac27380ce8 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java @@ -33,9 +33,21 @@ AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Que return new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter); } + @Override + AbstractKnnVectorQuery getKnnVectorQuery( + String field, float[] query, int k, Query queryFilter, Query seedQuery) { + return new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter, seedQuery); + } + @Override AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) { - return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query); + return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query, null); + } + + @Override + AbstractKnnVectorQuery getThrowingKnnVectorQuery( + String field, float[] vec, int k, Query query, Query seedQuery) { + return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query, seedQuery); } @Override @@ -105,8 +117,8 @@ public void testVectorEncodingMismatch() throws IOException { private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery { - public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) { - super(field, target, k, filter); + public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter, Query seed) { + super(field, target, k, filter, seed); } @Override diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java index 
f2e5a3e274ab..71ffbc830670 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java @@ -48,9 +48,21 @@ KnnFloatVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query return new KnnFloatVectorQuery(field, query, k, queryFilter); } + @Override + KnnFloatVectorQuery getKnnVectorQuery( + String field, float[] query, int k, Query queryFilter, Query seedQuery) { + return new KnnFloatVectorQuery(field, query, k, queryFilter, seedQuery); + } + @Override AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) { - return new ThrowingKnnVectorQuery(field, vec, k, query); + return new ThrowingKnnVectorQuery(field, vec, k, query, null); + } + + @Override + AbstractKnnVectorQuery getThrowingKnnVectorQuery( + String field, float[] vec, int k, Query query, Query seedQuery) { + return new ThrowingKnnVectorQuery(field, vec, k, query, seedQuery); } @Override @@ -254,8 +266,8 @@ public void testDocAndScoreQueryBasics() throws IOException { private static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery { - public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter) { - super(field, target, k, filter); + public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter, Query seed) { + super(field, target, k, filter, seed); } @Override diff --git a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java index f60c7966f984..e48d2aa46c3c 100644 --- a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java +++ b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java @@ -41,6 +41,7 @@ import org.apache.lucene.index.Terms; import org.apache.lucene.index.VectorEncoding; import 
org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; import org.apache.lucene.util.Version; @@ -180,11 +181,19 @@ public ByteVectorValues getByteVectorValues(String fieldName) { @Override public void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {} + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) {} @Override public void searchNearestVectors( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) {} + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) {} @Override public void checkIntegrity() throws IOException {} diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java index 456a885b49a0..cc0539c68aea 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java @@ -140,6 +140,7 @@ protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher search protected TopDocs approximateSearch( LeafReaderContext context, Bits acceptDocs, + DocIdSetIterator seedDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { @@ -148,7 +149,7 @@ protected TopDocs approximateSearch( if (collector == null) { return NO_RESULTS; } - context.reader().searchNearestVectors(field, query, collector, acceptDocs); + context.reader().searchNearestVectors(field, query, collector, acceptDocs, seedDocs); return collector.topDocs(); } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java 
b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java index 7b5a656d1414..4bb584a9e9a5 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java @@ -139,6 +139,7 @@ protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher search protected TopDocs approximateSearch( LeafReaderContext context, Bits acceptDocs, + DocIdSetIterator seedDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { @@ -147,7 +148,7 @@ protected TopDocs approximateSearch( if (collector == null) { return NO_RESULTS; } - context.reader().searchNearestVectors(field, query, collector, acceptDocs); + context.reader().searchNearestVectors(field, query, collector, acceptDocs, seedDocs); return collector.topDocs(); } diff --git a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java index 08fb1cf6b5bd..f31cb72d3ccc 100644 --- a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java +++ b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java @@ -39,6 +39,7 @@ import org.apache.lucene.index.*; import org.apache.lucene.search.Collector; import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.Query; @@ -1651,11 +1652,19 @@ public ByteVectorValues getByteVectorValues(String fieldName) { @Override public void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {} + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) {} @Override public void searchNearestVectors( - String field, 
byte[] target, KnnCollector knnCollector, Bits acceptDocs) {} + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) {} @Override public void checkIntegrity() throws IOException { diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java index 501e2e5616f0..b48135d52571 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java @@ -32,6 +32,7 @@ import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.util.Bits; @@ -146,23 +147,33 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { FieldInfo fi = fis.fieldInfo(field); assert fi != null && fi.getVectorDimension() > 0 && fi.getVectorEncoding() == VectorEncoding.FLOAT32; - delegate.search(field, target, knnCollector, acceptDocs); + delegate.search(field, target, knnCollector, acceptDocs, seedDocs); } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + public void search( + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) throws IOException { FieldInfo fi = fis.fieldInfo(field); assert fi != null && 
fi.getVectorDimension() > 0 && fi.getVectorEncoding() == VectorEncoding.BYTE; - delegate.search(field, target, knnCollector, acceptDocs); + delegate.search(field, target, knnCollector, acceptDocs, seedDocs); } @Override diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index a10d26423494..0febf47b2234 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -693,7 +693,12 @@ public void testDeleteAllVectorDocs() throws Exception { // assert that knn search doesn't fail on a field with all deleted docs TopDocs results = leafReader.searchNearestVectors( - "v", randomNormalizedVector(4), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE); + "v", + randomNormalizedVector(4), + 1, + leafReader.getLiveDocs(), + null, + Integer.MAX_VALUE); assertEquals(0, results.scoreDocs.length); } } @@ -1330,7 +1335,12 @@ public void testSearchWithVisitedLimit() throws Exception { TopDocs results = ctx.reader() .searchNearestVectors( - fieldName, randomNormalizedVector(dimension), k, liveDocs, visitedLimit); + fieldName, + randomNormalizedVector(dimension), + k, + liveDocs, + null, + visitedLimit); assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, results.totalHits.relation); assertEquals(visitedLimit, results.totalHits.value); @@ -1340,7 +1350,12 @@ public void testSearchWithVisitedLimit() throws Exception { results = ctx.reader() .searchNearestVectors( - fieldName, randomNormalizedVector(dimension), k, liveDocs, visitedLimit); + fieldName, + randomNormalizedVector(dimension), + k, + liveDocs, + null, + visitedLimit); assertEquals(TotalHits.Relation.EQUAL_TO, results.totalHits.relation); assertTrue(results.totalHits.value <= visitedLimit); } @@ -1419,7 +1434,12 @@ public void 
testRandomWithUpdatesAndGraph() throws Exception { TopDocs results = ctx.reader() .searchNearestVectors( - fieldName, randomNormalizedVector(dimension), k, liveDocs, Integer.MAX_VALUE); + fieldName, + randomNormalizedVector(dimension), + k, + liveDocs, + null, + Integer.MAX_VALUE); assertEquals(Math.min(k, size), results.scoreDocs.length); for (int i = 0; i < k - 1; i++) { assertTrue(results.scoreDocs[i].score >= results.scoreDocs[i + 1].score); diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java index 3fee110f7836..b2d215280ff5 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java @@ -41,6 +41,7 @@ import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.TermVectors; import org.apache.lucene.index.Terms; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; @@ -241,14 +242,24 @@ public ByteVectorValues getByteVectorValues(String fieldName) throws IOException @Override public void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - in.searchNearestVectors(field, target, knnCollector, acceptDocs); + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) + throws IOException { + in.searchNearestVectors(field, target, knnCollector, acceptDocs, seedDocs); } @Override public void searchNearestVectors( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - in.searchNearestVectors(field, target, knnCollector, acceptDocs); + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) + throws IOException 
{ + in.searchNearestVectors(field, target, knnCollector, acceptDocs, seedDocs); } @Override diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/MismatchedLeafReader.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/MismatchedLeafReader.java index ab907b768023..d7a56a33b1a7 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/MismatchedLeafReader.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/MismatchedLeafReader.java @@ -28,6 +28,7 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.StoredFieldVisitor; import org.apache.lucene.index.StoredFields; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; @@ -72,14 +73,24 @@ public CacheHelper getReaderCacheHelper() { @Override public void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - in.searchNearestVectors(field, target, knnCollector, acceptDocs); + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) + throws IOException { + in.searchNearestVectors(field, target, knnCollector, acceptDocs, seedDocs); } @Override public void searchNearestVectors( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - in.searchNearestVectors(field, target, knnCollector, acceptDocs); + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) + throws IOException { + in.searchNearestVectors(field, target, knnCollector, acceptDocs, seedDocs); } static FieldInfos shuffleInfos(FieldInfos infos, Random random) { diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java index efd13121d930..36030f072e56 100644 --- 
a/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java @@ -245,11 +245,19 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { @Override public void searchNearestVectors( - String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {} + String field, + float[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) {} @Override public void searchNearestVectors( - String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) {} + String field, + byte[] target, + KnnCollector knnCollector, + Bits acceptDocs, + DocIdSetIterator seedDocs) {} @Override public FieldInfos getFieldInfos() { From 82e705323c3a9a8e96d14d3f0e4a9416fe173559 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Tue, 6 Aug 2024 12:09:35 +0100 Subject: [PATCH 02/41] cleanup --- .../.gradle-wrapper8297076614091580458.tmp | 0 hsperfdata_root/77 | Bin 32768 -> 0 bytes 2 files changed, 0 insertions(+), 0 deletions(-) delete mode 100644 gradle/wrapper/.gradle-wrapper8297076614091580458.tmp delete mode 100644 hsperfdata_root/77 diff --git a/gradle/wrapper/.gradle-wrapper8297076614091580458.tmp b/gradle/wrapper/.gradle-wrapper8297076614091580458.tmp deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/hsperfdata_root/77 b/hsperfdata_root/77 deleted file mode 100644 index 784b5f536d259468dc60b40f01c4e4dd04441db9..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 32768 zcmeHOYm6nwRqi{p4jZxpn*=Zpfwr9x$Lr~tJGF#@X#scBUI8umEERdX+ga848KM=?tV+r9O2wRR2L_&%skU%mIEBt^TAjlu! 
zJEyB3)wl0sypbiOYHdHJx4v^ueRb-b>Z-c0|IaJ0wC0S~$&VR)y!HHj-%AfB&cDNX z0q6U0(y?Wnm|Ma<8H84xT2CiOq5Jfea4&FDKME5r^De&5fLj*|PQnZu4^@4nU(vS~ zg{cTV;cZ9m!AhEnUO$EEcS`|N?Y%rm{XjhF9Ed?5xR%6Gy{dX2s_Wfx;sca_U{*bk zN2$N(yE!$v?3#V^BYa+s!i~e!2~(-4QD3Wyb7SG8;zn;(bp7zU-xCD+;;T{|GJm<{3>_=ohnZetC8-mq!pm#-6Y@df zztx&2{uNx{52<1wlI4|pp*-P7%I^s0K$f3bkLQS96dHFm=`rT3{+~E|q96HT+Ce|a zskrcRW}#9a^qY!56*%fKA3TPhTn9l;eg2o1mQqh+uJB7LpI0+OWIC3b_^IKl)y%P9 zN*`eXlk{b{ym~oO>cy?#)tKXcO@2dm_C&fLd0F}Q|NL_dPZ#0{@vxU);*VU#(qKT)?I(-Pic<%*`g^hD5DjQHHwxCD(GY9|Tc!Fz9!V5>{C> z4bYQunO`%A`aT2y`3vkmH9J+mOQfa(1U>t&zsy4;l}{CifUPP6?zeyIS~TFH$LBsmS|{`4i9J6QE7IyL5knr&hwq;bSTskbXa5%YN2TCze1CEi^av&BxRS41TH97Os0Yy=W#iqX`cRIlH4KR<{E7TzITg@4eBS## zq`y{=DSWmO3>4>tgs+j9nf)u{vu)*k8qfYP<7;+I;j_&-|IYbbmVCDBJWSEQ{9WR6 zxq;7@-Ugq`C7+BQY1gUsndYC9wrZ*mV&OVn6I& z$6Qph&#nLcQg01ksSW^JEs3M;60J}m__Ie;-No`;toEWSoceORrQ}}z;}6$ zY1{Al>BgZekmHL^4+(iYkudU?%J@jyPlJL&DN1Gecsi~hN4=;5{CBSMblefXB)XMU z(dXu|p}A})#jmBE&aj(tB)pYj7sdGqyleu)#$Pkz*TpbH8^(*9G-Fl933{&et5N+s zjt7?6iuT9;h#*Crq311QzM-99|5ZPX;vE`77ayDR1i)eAEbX6!zk+%0c<2O~h(J2p zbUY$-T{(2-`qg%uZ}1bwZ@Ozyk^-0e_YAl}>Yz7E+%B+msJJG-Ot?c7&Me0z({b#V z#Z2x!p<XgIA zU#0#gd>LoBCY>PaBcc7v|7OPwc;^;H3mybutwG{*BWA!;oEwXu9zT%&cdDsv*1QJX zMm3ernn!`FWJApOE8|wG_t<>Ib|TbqCdwH4;v4t#gNM^_?|G+Ys$B@klLYHGKQGFhwr8bRbLZsFBTP?VZ5ZwPmzqV z{=3Rysve_QuMJ$uQ>TTi+ z{na>$YV%>^I=TJ{zN-W>Vduc__tn7zu+#B-i0_vXW~18)eWH;l;POLl_i}##=AYBR z*|jiZ9l7+qgtdFZaL~ei31?|{wq7suH{plrj62|%OUNvvbwQT&K^k+Bf^)X+7 ziD;J=lqXlcCa0PY8|5ezK&N4RmaoT^(?E$$g6J}a&qjN4YMy=bGX$^UE9+an6@pM& zzaO{pYpjGvxAz4tkFCq~Xwo;P-VL};Drhqrj=%Y{ zrnodt@0EO+$9U?!3^#UgBfa+)eR0wYY|t zSdUo682g}u$x~t8j(vg;a@0x3OdPOAH=!Zx~p{wcaHN|4nQJFk1YOc z*ux4pCOGpQ_}MDu7B-@rH~csqIDvVU`<^g6)pA9BXwxOhr&XsW`ErAF`#c6e7p!)X zCn-)6Zg~u@XJHwD_g3f7)8wMXJT`3>ous#y)dWd1+**IkJa=|s)lbb1R*FPy-nd%G zxkt(&;Z63F9ibP+#)cdC{Y026Uer(hK@a8W$lK>6p2;4H3pF|3f#r$aLDz^yATYy2 zutGomh=j&X^!-o$uV^0nV;A$)R^ac(PJCqbopj$=B-x92MKb8f$M5uUM*3~sYx?O| z1taZTvI_u?n>fi|Xi!0NehT;5sBWc2() 
z;h8<8l1EG-zcd!WJeGAHF=?MEvH;8N<4X1tZzT6D?hW>;>F|7z9u5!B_!pe;Y=8H{xg$|2NT%g~2F@TvmW)XN`f&`KazCWavt(a` zACe#JN7ydN9H1N>G@he+Q9LBv2XSW}5nP77%YoC)hC(;tU{vm-xKp_Ym2xXFiY?r{ zdGWfBuR9MZx{s+yF%-e2wn^_#Kaq%h_`kWvuU6a?GnSZg$3Db_Rw%HcmP#*RF?m58 z!jIDA(k*lz`u;|CjTiA?QMX8=2)nAuO4oAw{-u@e?S&^Ee|+WY#`qZ{J(WNIk z3oD(qD_8Izm<@;_1fCop;eGyw_!%JHToPgaYKTxqKVJJ3`VsZ3b2upv|!HhBcS&OgjD5gH(pcR!N+&q(>Fai=~6`S}?8 zhD@~Ma;5oDTW0Um?T?syQEcLe34_B0r{|qJYi@mVLn-ixxjXQK)Wqz_%Ajj5pEZ}4 zmd}_=OBeZ{xqDq6{0zWXn} zMBn^*{-bx;?>ukgPW9tJ%)era5ss-v*O7g@yluMtYBh{+{l~`Lgw5k8J(C==rY?sj zq}aT{&HQUs=vTh6hHjD@o~bBvwsi=tcrh36T|Q zl0)`VJJUh-Q(PzepYi*2V$C4?ru_oqHyz{ZTPbf*$IVs!`9H-Hnc?+SGoTsJ3}^;4 z1DXNNfM!55pc&8%Xa+O`ngPv#WGy|Fe&46Y= zGoTsJ3}^;41DXNNfM!55pc&8%Xa+O`ngPv#W VGy|Fe&46Y=GoTsJ47^he{2!^$TUh`A From 7955148204dc5f284e6e712812789279e9139c70 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Tue, 6 Aug 2024 15:12:28 +0100 Subject: [PATCH 03/41] ensure seed docs have a vector per https://github.com/apache/lucene/pull/13635#discussion_r1705418758 --- .../org/apache/lucene/search/AbstractKnnVectorQuery.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index db939db0d324..4455c4a3bcd4 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -96,7 +96,12 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { final Weight seedWeight; if (seed != null) { - Query seedRewritten = indexSearcher.rewrite(seed); + BooleanQuery booleanSeedQuery = + new BooleanQuery.Builder() + .add(seed, BooleanClause.Occur.MUST) + .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER) + .build(); + Query seedRewritten = indexSearcher.rewrite(booleanSeedQuery); seedWeight = indexSearcher.createWeight(seedRewritten, 
ScoreMode.TOP_SCORES, 1f); } else { seedWeight = null; From 40d972d3f10d8c81387bc32e6e8d4519c3b20ee9 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Tue, 6 Aug 2024 15:20:34 +0100 Subject: [PATCH 04/41] apply filter to seed queries per https://github.com/apache/lucene/pull/13635#discussion_r1705415253 --- .../apache/lucene/search/AbstractKnnVectorQuery.java | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index 4455c4a3bcd4..c47bbcc07f03 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -96,12 +96,14 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { final Weight seedWeight; if (seed != null) { - BooleanQuery booleanSeedQuery = + BooleanQuery.Builder booleanSeedQueryBuilder = new BooleanQuery.Builder() .add(seed, BooleanClause.Occur.MUST) - .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER) - .build(); - Query seedRewritten = indexSearcher.rewrite(booleanSeedQuery); + .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER); + if (filter != null) { + booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER); + } + Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build()); seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f); } else { seedWeight = null; From f36a4cd36c2b1db81d2587f4b93a810f5dee069b Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Thu, 5 Sep 2024 14:47:10 +0100 Subject: [PATCH 05/41] Update lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java Co-authored-by: Christine Poerschke --- .../src/java/org/apache/lucene/codecs/KnnVectorsReader.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java index 375f547e5182..98a7f9d42f54 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java @@ -83,7 +83,7 @@ protected KnnVectorsReader() {} * @param knnCollector a KnnResults collector and relevant settings for gathering vector results * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} * if they are all allowed to match. - * @param seedDocs {@link Bits} that represents the documents used to seed the search, or {@code + * @param seedDocs {@link DocIdSetIterator} that represents the documents used to seed the search, or {@code * null} to perform a search without seeds. */ public abstract void search( From 539b29a3f8366fd152d15d1082ebc25e431fb3d2 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Thu, 5 Sep 2024 14:47:22 +0100 Subject: [PATCH 06/41] Update lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java Co-authored-by: Christine Poerschke --- .../src/java/org/apache/lucene/codecs/KnnVectorsReader.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java index 98a7f9d42f54..9d57705b2604 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java @@ -118,7 +118,7 @@ public abstract void search( * @param knnCollector a KnnResults collector and relevant settings for gathering vector results * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} * if they are all allowed to match. 
- * @param seedDocs {@link Bits} that represents the documents used to seed the search, or {@code + * @param seedDocs {@link DocIdSetIterator} that represents the documents used to seed the search, or {@code * null} to perform a search without seeds. */ public abstract void search( From 3df6ad254bd354bd4d02f76074ff65618127eeb5 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Thu, 5 Sep 2024 14:47:29 +0100 Subject: [PATCH 07/41] Update lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java Co-authored-by: Christine Poerschke --- .../src/java/org/apache/lucene/search/KnnByteVectorQuery.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index 8e993cd93271..88b28d99880b 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -79,7 +79,7 @@ public KnnByteVectorQuery(String field, byte[] target, int k, Query filter) { * Find the k nearest documents to the target vector according to the vectors in the * given field. target vector. * - * @param field a field that has been indexed as a {@link KnnFloatVectorField}. + * @param field a field that has been indexed as a {@link KnnByteVectorField}. 
* @param target the target of the search * @param k the number of documents to find * @param filter a filter applied before the vector search From 0508d87f0f7e18ca6acfd00bf1383e677f07929c Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Thu, 5 Sep 2024 14:47:36 +0100 Subject: [PATCH 08/41] Update lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java Co-authored-by: Christine Poerschke --- .../java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 46df6d0ba4ad..20e52a66f232 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -91,8 +91,9 @@ public static void search( Bits acceptOrds, DocIdSetIterator entryPointOrds) throws IOException { - ArrayList entryPointOrdInts = new ArrayList(); + ArrayList entryPointOrdInts = null; if (entryPointOrds != null) { + entryPointOrdInts = new ArrayList(); int entryPointOrdInt; while ((entryPointOrdInt = entryPointOrds.nextDoc()) != NO_MORE_DOCS) { entryPointOrdInts.add(entryPointOrdInt); From 732f69cf87de678e3e98f58599c8b627b62390e7 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Thu, 5 Sep 2024 14:47:49 +0100 Subject: [PATCH 09/41] Update lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java Co-authored-by: Christine Poerschke --- .../src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 20e52a66f232..63d5ff957dfb 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ 
b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -99,7 +99,7 @@ public static void search( entryPointOrdInts.add(entryPointOrdInt); } } - if (entryPointOrdInts.size() == 0) { + if (entryPointOrdInts == null || entryPointOrdInts.isEmpty()) { search(scorer, knnCollector, graph, acceptOrds); } else { HnswGraphSearcher graphSearcher = From c02b4cce49166a3d9fd82f0de62f218b62b5db79 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Thu, 5 Sep 2024 14:49:09 +0100 Subject: [PATCH 10/41] Update lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java Co-authored-by: Christine Poerschke --- .../java/org/apache/lucene/search/AbstractKnnVectorQuery.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index c47bbcc07f03..bfe0d40237e7 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -586,7 +586,7 @@ private class TopDocsDISI extends DocIdSetIterator { List sortedDocIdList; public TopDocsDISI(TopDocs topDocs) { - sortedDocIdList = new ArrayList(); + sortedDocIdList = new ArrayList(topDocs.scoreDocs.length); for (int i = 0; i < topDocs.scoreDocs.length; i++) { sortedDocIdList.add(topDocs.scoreDocs[i].doc); } From 285ebfe1008eaa5b7c8ea563aa079e9ba3661e2a Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Thu, 5 Sep 2024 14:51:18 +0100 Subject: [PATCH 11/41] Update lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java Co-authored-by: Christine Poerschke --- .../java/org/apache/lucene/search/AbstractKnnVectorQuery.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java 
index bfe0d40237e7..d92ba3acc5ae 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -201,7 +201,7 @@ private DocIdSetIterator executeSeedQuery(LeafReaderContext ctx, Weight seedWeig if (seedWeight != null) { // Execute the seed query TopScoreDocCollector seedCollector = - new TopScoreDocCollectorManager(k, Integer.MAX_VALUE).newCollector(); + new TopScoreDocCollectorManager(k /* numHits */, null /* after */, Integer.MAX_VALUE /* totalHitsThreshold */, false /* supportsConcurrency */).newCollector(); LeafCollector leafCollector; try { leafCollector = seedCollector.getLeafCollector(ctx); From 9f1be670d274a7fcf554d0eb1aab89edc6ce2580 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Tue, 10 Sep 2024 09:57:44 +0100 Subject: [PATCH 12/41] mapping docIds to ordinals --- .../backward_codecs/lucene80/IndexedDISI.java | 55 +++++++++++++++++++ .../lucene92/OffHeapFloatVectorValues.java | 10 ++++ .../lucene94/OffHeapByteVectorValues.java | 10 ++++ .../lucene94/OffHeapFloatVectorValues.java | 10 ++++ .../lucene/codecs/KnnVectorsReader.java | 8 +-- .../lucene/codecs/lucene90/IndexedDISI.java | 55 +++++++++++++++++++ .../lucene95/OffHeapByteVectorValues.java | 10 ++++ .../lucene95/OffHeapFloatVectorValues.java | 10 ++++ .../apache/lucene/index/ByteVectorValues.java | 12 ++++ .../lucene/index/FloatVectorValues.java | 12 ++++ .../lucene/search/AbstractKnnVectorQuery.java | 18 +++++- .../lucene/search/KnnByteVectorQuery.java | 14 +++++ .../lucene/search/KnnFloatVectorQuery.java | 14 +++++ .../search/BaseKnnVectorQueryTestCase.java | 6 +- 14 files changed, 237 insertions(+), 7 deletions(-) diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene80/IndexedDISI.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene80/IndexedDISI.java index 639bdbd73339..e50376ca706a 100644 --- 
a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene80/IndexedDISI.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene80/IndexedDISI.java @@ -706,4 +706,59 @@ private static void rankSkip(IndexedDISI disi, int targetInBlock) throws IOExcep disi.word = rankWord; disi.numberOfOnes = disi.denseOrigoIndex + denseNOO; } + + /** + * Implementation of a {@link DocIdSetIterator} which maps a source iterator to the indexes + * produced by an {@link IndexedDISI}. + * + *

    This implementation assumes that all IDs produced by the source iterator are also present in + * the indexed iterator. + * + * @lucene.internal + */ + public static class MappedDISI extends DocIdSetIterator { + IndexedDISI indexedDISI; + DocIdSetIterator sourceDISI; + + public MappedDISI(IndexedDISI indexedDISI, DocIdSetIterator sourceDISI) { + this.indexedDISI = indexedDISI; + this.sourceDISI = sourceDISI; + } + + /** + * Advances the source iterator to the first document number that is greater than or equal to + * the provided target and returns the corresponding index. + */ + @Override + public int advance(int target) throws IOException { + int newTarget = sourceDISI.advance(target); + if (newTarget != NO_MORE_DOCS) { + indexedDISI.advance(newTarget); + } + return docID(); + } + + @Override + public long cost() { + return this.sourceDISI.cost(); + } + + @Override + public int docID() { + if (indexedDISI.docID() == NO_MORE_DOCS || sourceDISI.docID() == NO_MORE_DOCS) { + return NO_MORE_DOCS; + } + return indexedDISI.index(); + } + + /** Advances to the next document in the source iterator and returns the corresponding index. 
*/ + @Override + public int nextDoc() throws IOException { + int newTarget = sourceDISI.nextDoc(); + if (newTarget != NO_MORE_DOCS) { + indexedDISI.advance(newTarget); + } + return docID(); + } + } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java index 19dc82cc46d5..2027949957b2 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java @@ -251,6 +251,11 @@ public DocIdSetIterator iterator() { } }; } + + @Override + public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { + return new IndexedDISI.MappedDISI(disi, docIds); + } } private static class EmptyOffHeapVectorValues extends OffHeapFloatVectorValues { @@ -315,5 +320,10 @@ public Bits getAcceptOrds(Bits acceptDocs) { public VectorScorer scorer(float[] query) { return null; } + + @Override + public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { + return DocIdSetIterator.empty(); + } } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java index 0c909e3839df..1ab1a968a965 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java @@ -267,6 +267,11 @@ public DocIdSetIterator iterator() { } }; } + + @Override + public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { + return new IndexedDISI.MappedDISI(disi, docIds); + } } private static class EmptyOffHeapVectorValues 
extends OffHeapByteVectorValues { @@ -331,5 +336,10 @@ public Bits getAcceptOrds(Bits acceptDocs) { public VectorScorer scorer(byte[] query) { return null; } + + @Override + public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { + return DocIdSetIterator.empty(); + } } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java index 91f97b8a41fa..b2392796b795 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java @@ -263,6 +263,11 @@ public DocIdSetIterator iterator() { } }; } + + @Override + public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { + return new IndexedDISI.MappedDISI(disi, docIds); + } } private static class EmptyOffHeapVectorValues extends OffHeapFloatVectorValues { @@ -327,5 +332,10 @@ public Bits getAcceptOrds(Bits acceptDocs) { public VectorScorer scorer(float[] query) { return null; } + + @Override + public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { + return DocIdSetIterator.empty(); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java index 9d57705b2604..b93e7b66e66b 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java @@ -83,8 +83,8 @@ protected KnnVectorsReader() {} * @param knnCollector a KnnResults collector and relevant settings for gathering vector results * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} * if they are all allowed to match. 
- * @param seedDocs {@link DocIdSetIterator} that represents the documents used to seed the search, or {@code - * null} to perform a search without seeds. + * @param seedDocs {@link DocIdSetIterator} that represents the documents used to seed the search, + * or {@code null} to perform a search without seeds. */ public abstract void search( String field, @@ -118,8 +118,8 @@ public abstract void search( * @param knnCollector a KnnResults collector and relevant settings for gathering vector results * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} * if they are all allowed to match. - * @param seedDocs {@link DocIdSetIterator} that represents the documents used to seed the search, or {@code - * null} to perform a search without seeds. + * @param seedDocs {@link DocIdSetIterator} that represents the documents used to seed the search, + * or {@code null} to perform a search without seeds. */ public abstract void search( String field, diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java index a2b2c84e12ae..9333160512d0 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java @@ -723,4 +723,59 @@ private static void rankSkip(IndexedDISI disi, int targetInBlock) throws IOExcep disi.word = rankWord; disi.numberOfOnes = disi.denseOrigoIndex + denseNOO; } + + /** + * Implementation of a {@link DocIdSetIterator} which maps a source iterator to the indexes + * produced by an {@link IndexedDISI}. + * + *

    This implementation assumes that all IDs produced by the source iterator are also present in + * the indexed iterator. + * + * @lucene.internal + */ + public static class MappedDISI extends DocIdSetIterator { + IndexedDISI indexedDISI; + DocIdSetIterator sourceDISI; + + public MappedDISI(IndexedDISI indexedDISI, DocIdSetIterator sourceDISI) { + this.indexedDISI = indexedDISI; + this.sourceDISI = sourceDISI; + } + + /** + * Advances the source iterator to the first document number that is greater than or equal to + * the provided target and returns the corresponding index. + */ + @Override + public int advance(int target) throws IOException { + int newTarget = sourceDISI.advance(target); + if (newTarget != NO_MORE_DOCS) { + indexedDISI.advance(newTarget); + } + return docID(); + } + + @Override + public long cost() { + return this.sourceDISI.cost(); + } + + @Override + public int docID() { + if (indexedDISI.docID() == NO_MORE_DOCS || sourceDISI.docID() == NO_MORE_DOCS) { + return NO_MORE_DOCS; + } + return indexedDISI.index(); + } + + /** Advances to the next document in the source iterator and returns the corresponding index. 
*/ + @Override + public int nextDoc() throws IOException { + int newTarget = sourceDISI.nextDoc(); + if (newTarget != NO_MORE_DOCS) { + indexedDISI.advance(newTarget); + } + return docID(); + } + } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index f45158eadac7..0c450939eb66 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -311,6 +311,11 @@ public DocIdSetIterator iterator() { } }; } + + @Override + public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { + return new IndexedDISI.MappedDISI(disi, docIds); + } } private static class EmptyOffHeapVectorValues extends OffHeapByteVectorValues { @@ -378,5 +383,10 @@ public Bits getAcceptOrds(Bits acceptDocs) { public VectorScorer scorer(byte[] query) { return null; } + + @Override + public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { + return DocIdSetIterator.empty(); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index 1f61283b5002..accfe4150575 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -300,6 +300,11 @@ public DocIdSetIterator iterator() { } }; } + + @Override + public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { + return new IndexedDISI.MappedDISI(disi, docIds); + } } private static class EmptyOffHeapVectorValues extends OffHeapFloatVectorValues { @@ -367,5 +372,10 @@ public Bits getAcceptOrds(Bits acceptDocs) { public VectorScorer scorer(float[] query) { return null; } + 
+ @Override + public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { + return DocIdSetIterator.empty(); + } } } diff --git a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java index d33ca1ca3544..23c0f8740950 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java @@ -86,4 +86,16 @@ public static void checkField(LeafReader in, String field) { * @return a {@link VectorScorer} instance or null */ public abstract VectorScorer scorer(byte[] query) throws IOException; + + /** + * Returns a new iterator that maps the provided docIds to the vector ordinals. + * + *

    This method assumes that all docIds have corresponding orginals. + * + * @lucene.internal + * @lucene.experimental + */ + public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { + return docIds; + } } diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index e5dbc620f5c3..34d8c2b63f4b 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -87,4 +87,16 @@ public static void checkField(LeafReader in, String field) { * @return a {@link VectorScorer} instance or null */ public abstract VectorScorer scorer(float[] query) throws IOException; + + /** + * Returns a new iterator that maps the provided docIds to the vector ordinals. + * + *

    This method assumes that all docIds have corresponding orginals. + * + * @lucene.internal + * @lucene.experimental + */ + public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { + return docIds; + } } diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index d92ba3acc5ae..f61fe570730e 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -201,7 +201,12 @@ private DocIdSetIterator executeSeedQuery(LeafReaderContext ctx, Weight seedWeig if (seedWeight != null) { // Execute the seed query TopScoreDocCollector seedCollector = - new TopScoreDocCollectorManager(k /* numHits */, null /* after */, Integer.MAX_VALUE /* totalHitsThreshold */, false /* supportsConcurrency */).newCollector(); + new TopScoreDocCollectorManager( + k /* numHits */, + null /* after */, + Integer.MAX_VALUE /* totalHitsThreshold */, + false /* supportsConcurrency */) + .newCollector(); LeafCollector leafCollector; try { leafCollector = seedCollector.getLeafCollector(ctx); @@ -228,12 +233,21 @@ private DocIdSetIterator executeSeedQuery(LeafReaderContext ctx, Weight seedWeig } TopDocs seedTopDocs = seedCollector.topDocs(); - return new TopDocsDISI(seedTopDocs); + return convertDocIdsToVectorOrdinals(ctx, new TopDocsDISI(seedTopDocs)); } else { return null; } } + /** + * Returns a new iterator that maps the provided docIds to the vector ordinals. 
+ * + * @lucene.internal + * @lucene.experimental + */ + protected abstract DocIdSetIterator convertDocIdsToVectorOrdinals( + LeafReaderContext ctx, DocIdSetIterator docIds) throws IOException; + private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc) throws IOException { if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) { diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index 3976fc0f53e9..8ff94eca429c 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -149,4 +149,18 @@ public int hashCode() { public byte[] getTargetCopy() { return ArrayUtil.copyArray(target); } + + /** + * Returns a new iterator that maps the provided docIds to the vector ordinals. + * + *

    This method assumes that all docIds have corresponding orginals. + * + * @lucene.internal + * @lucene.experimental + */ + @Override + protected DocIdSetIterator convertDocIdsToVectorOrdinals( + LeafReaderContext ctx, DocIdSetIterator docIds) throws IOException { + return ctx.reader().getByteVectorValues(field).convertDocIdsToVectorOrdinals(docIds); + } } diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index a1f6a680ea0a..eccad9ea3516 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -152,4 +152,18 @@ public int hashCode() { public float[] getTargetCopy() { return ArrayUtil.copyArray(target); } + + /** + * Returns a new iterator that maps the provided docIds to the vector ordinals. + * + *

    This method assumes that all docIds have corresponding orginals. + * + * @lucene.internal + * @lucene.experimental + */ + @Override + protected DocIdSetIterator convertDocIdsToVectorOrdinals( + LeafReaderContext ctx, DocIdSetIterator docIds) throws IOException { + return ctx.reader().getFloatVectorValues(field).convertDocIdsToVectorOrdinals(docIds); + } } diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index b809fb06a757..8b5bd7526320 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -628,7 +628,11 @@ public void testRandomWithSeed() throws IOException { RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc); for (int i = 0; i < numDocs; i++) { Document doc = new Document(); - doc.add(getKnnVectorField("field", randomVector(dimension))); + if (random() + .nextBoolean()) { // Randomly skip some vectors to test the mapping from docid to + // ordinals + doc.add(getKnnVectorField("field", randomVector(dimension))); + } doc.add(new NumericDocValuesField("tag", i)); doc.add(new IntPoint("tag", i)); w.addDocument(doc); From 244f46b94a9a6dcbbdcee2e417cd60d1c34f87b6 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Tue, 10 Sep 2024 11:03:57 +0100 Subject: [PATCH 13/41] fixed test warning --- .../src/test/org/apache/lucene/document/TestManyKnnDocs.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java index db08a058998e..480ffd949a7b 100644 --- a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java +++ b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java @@ -64,7 +64,7 @@ public static void init_index() throws Exception { float[] 
vector = new float[128]; Document doc = new Document(); vector[0] = (i % 256); - vector[1] = (i / 256); + vector[1] = (i / 256.); doc.add(new KnnFloatVectorField("field", vector, similarityFunction)); doc.add(new KeywordField("int", "" + i, org.apache.lucene.document.Field.Store.YES)); doc.add(new StoredField("intValue", i)); From b73e7a34881088d229516601f30512a32f54eebb Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Tue, 10 Sep 2024 11:30:52 +0100 Subject: [PATCH 14/41] fix test warning --- .../src/test/org/apache/lucene/document/TestManyKnnDocs.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java index 480ffd949a7b..2b4b8f4b00b9 100644 --- a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java +++ b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java @@ -64,7 +64,7 @@ public static void init_index() throws Exception { float[] vector = new float[128]; Document doc = new Document(); vector[0] = (i % 256); - vector[1] = (i / 256.); + vector[1] = (float)(i / 256.); doc.add(new KnnFloatVectorField("field", vector, similarityFunction)); doc.add(new KeywordField("int", "" + i, org.apache.lucene.document.Field.Store.YES)); doc.add(new StoredField("intValue", i)); From 3134132d64674f8ce7695e2063c39fdef9450cbe Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Tue, 10 Sep 2024 11:36:26 +0100 Subject: [PATCH 15/41] tidy --- .../src/test/org/apache/lucene/document/TestManyKnnDocs.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java index 2b4b8f4b00b9..05adc70596f9 100644 --- a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java +++ b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java @@ -64,7 
+64,7 @@ public static void init_index() throws Exception { float[] vector = new float[128]; Document doc = new Document(); vector[0] = (i % 256); - vector[1] = (float)(i / 256.); + vector[1] = (float) (i / 256.); doc.add(new KnnFloatVectorField("field", vector, similarityFunction)); doc.add(new KeywordField("int", "" + i, org.apache.lucene.document.Field.Store.YES)); doc.add(new StoredField("intValue", i)); From 69db4d4e52c6153f3b006d673dc1acb02dcffb2f Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Thu, 26 Sep 2024 10:35:08 +0100 Subject: [PATCH 16/41] Update lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java Co-authored-by: Christine Poerschke --- .../org/apache/lucene/search/AbstractKnnVectorQuery.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index f61fe570730e..7c47c45f6718 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -595,9 +595,9 @@ public int hashCode() { } } - private class TopDocsDISI extends DocIdSetIterator { - int idx = -1; - List sortedDocIdList; + private static class TopDocsDISI extends DocIdSetIterator { + private final List sortedDocIdList; + private int idx = -1; public TopDocsDISI(TopDocs topDocs) { sortedDocIdList = new ArrayList(topDocs.scoreDocs.length); From fe4bef3a00f9cea7a2ecb8f3c9f3a7614a18bb46 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Thu, 26 Sep 2024 10:35:20 +0100 Subject: [PATCH 17/41] Update lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java Co-authored-by: Christine Poerschke --- .../src/java/org/apache/lucene/search/KnnByteVectorQuery.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java 
b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index 8ff94eca429c..c11c63a39ecd 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -153,7 +153,7 @@ public byte[] getTargetCopy() { /** * Returns a new iterator that maps the provided docIds to the vector ordinals. * - *

    This method assumes that all docIds have corresponding orginals. + *

    This method assumes that all docIds have corresponding ordinals. * * @lucene.internal * @lucene.experimental From 8e044f8ea48327dc064511f586fe060eaac86b97 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Thu, 26 Sep 2024 10:35:26 +0100 Subject: [PATCH 18/41] Update lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java Co-authored-by: Christine Poerschke --- .../src/java/org/apache/lucene/search/KnnFloatVectorQuery.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index eccad9ea3516..4f9ffc265e06 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -156,7 +156,7 @@ public float[] getTargetCopy() { /** * Returns a new iterator that maps the provided docIds to the vector ordinals. * - *

    This method assumes that all docIds have corresponding orginals. + *

    This method assumes that all docIds have corresponding ordinals. * * @lucene.internal * @lucene.experimental From 33231b32bb611075096f6f0394e152b347a03a85 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Thu, 26 Sep 2024 10:36:04 +0100 Subject: [PATCH 19/41] Update lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java Co-authored-by: Christine Poerschke --- .../src/java/org/apache/lucene/index/FloatVectorValues.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index 34d8c2b63f4b..554def9ae99a 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -91,7 +91,7 @@ public static void checkField(LeafReader in, String field) { /** * Returns a new iterator that maps the provided docIds to the vector ordinals. * - *

    This method assumes that all docIds have corresponding orginals. + *

    This method assumes that all docIds have corresponding ordinals. * * @lucene.internal * @lucene.experimental From 2e86e4f0b78302b911701d0da820b816194a32d6 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Thu, 26 Sep 2024 10:36:13 +0100 Subject: [PATCH 20/41] Update lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java Co-authored-by: Christine Poerschke --- .../core/src/java/org/apache/lucene/index/ByteVectorValues.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java index 23c0f8740950..c22cb296a5e3 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java @@ -90,7 +90,7 @@ public static void checkField(LeafReader in, String field) { /** * Returns a new iterator that maps the provided docIds to the vector ordinals. * - *

    This method assumes that all docIds have corresponding orginals. + *

    This method assumes that all docIds have corresponding ordinals. * * @lucene.internal * @lucene.experimental From c0c18b2ceac331e73015cd378a042a0378b1ed94 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Thu, 26 Sep 2024 10:36:29 +0100 Subject: [PATCH 21/41] Update lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java Co-authored-by: Christine Poerschke --- .../src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 63d5ff957dfb..a0e7f3a0e825 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -73,7 +73,7 @@ public static void search( } /** - * Searches the HNSW graph for for the nerest neighbors of a query vector, starting from the + * Searches the HNSW graph for the nearest neighbors of a query vector, starting from the * provided entry points. 
* * @param scorer the scorer to compare the query with the nodes From 6190aca8990b273b18c7438ef633549bb5996ddf Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Thu, 26 Sep 2024 10:37:26 +0100 Subject: [PATCH 22/41] Update lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java Co-authored-by: Christine Poerschke --- .../src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index a0e7f3a0e825..842a287d5630 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -104,7 +104,7 @@ public static void search( } else { HnswGraphSearcher graphSearcher = new HnswGraphSearcher( - new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(graph.size())); + new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(getGraphSize(graph))); int[] entryPointOrdIntsArr = entryPointOrdInts.stream().mapToInt(Integer::intValue).toArray(); graphSearcher.searchLevel(knnCollector, scorer, 0, entryPointOrdIntsArr, graph, acceptOrds); } From a49ba2f7030216a34869164084fcfd7861670b7e Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Thu, 26 Sep 2024 10:49:59 +0100 Subject: [PATCH 23/41] address review comments --- .../org/apache/lucene/index/LeafReader.java | 4 +++ .../lucene/search/AbstractKnnVectorQuery.java | 25 +++++++------------ .../lucene/search/KnnByteVectorQuery.java | 6 ++--- .../lucene/search/KnnFloatVectorQuery.java | 4 +-- .../lucene/util/hnsw/HnswGraphSearcher.java | 11 +++++--- 5 files changed, 25 insertions(+), 25 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/index/LeafReader.java b/lucene/core/src/java/org/apache/lucene/index/LeafReader.java index 0988c7d289af..d0445c575ac4 100644 --- 
a/lucene/core/src/java/org/apache/lucene/index/LeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/LeafReader.java @@ -247,6 +247,8 @@ public final PostingsEnum postings(Term term) throws IOException { * @param k the number of docs to return * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} * if they are all allowed to match. + * @param seedDocs candidate documents to seed the KNN search, or {@code null} to search without + * using seeds. * @param visitedLimit the maximum number of nodes that the search is allowed to visit * @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores. * @lucene.experimental @@ -297,6 +299,8 @@ public final TopDocs searchNearestVectors( * @param k the number of docs to return * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} * if they are all allowed to match. + * @param seedDocs candidate documents to seed the KNN search, or {@code null} to search without + * using seeds. * @param visitedLimit the maximum number of nodes that the search is allowed to visit * @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores. 
* @lucene.experimental diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index 7c47c45f6718..da13866b6c30 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -210,6 +210,13 @@ private DocIdSetIterator executeSeedQuery(LeafReaderContext ctx, Weight seedWeig LeafCollector leafCollector; try { leafCollector = seedCollector.getLeafCollector(ctx); + if (leafCollector != null) { + BulkScorer scorer = seedWeight.bulkScorer(ctx); + if (scorer != null) { + scorer.score(leafCollector, ctx.reader().getLiveDocs()); + } + leafCollector.finish(); + } } catch ( @SuppressWarnings("unused") CollectionTerminatedException e) { @@ -217,23 +224,9 @@ private DocIdSetIterator executeSeedQuery(LeafReaderContext ctx, Weight seedWeig // continue with the following leaf leafCollector = null; } - if (leafCollector != null) { - BulkScorer scorer = seedWeight.bulkScorer(ctx); - if (scorer != null) { - try { - scorer.score(leafCollector, ctx.reader().getLiveDocs()); - } catch ( - @SuppressWarnings("unused") - CollectionTerminatedException e) { - // collection was terminated prematurely - // continue with the following leaf - } - } - leafCollector.finish(); - } TopDocs seedTopDocs = seedCollector.topDocs(); - return convertDocIdsToVectorOrdinals(ctx, new TopDocsDISI(seedTopDocs)); + return convertDocIdsToVectorOrdinals(ctx.reader(), new TopDocsDISI(seedTopDocs)); } else { return null; } @@ -246,7 +239,7 @@ private DocIdSetIterator executeSeedQuery(LeafReaderContext ctx, Weight seedWeig * @lucene.experimental */ protected abstract DocIdSetIterator convertDocIdsToVectorOrdinals( - LeafReaderContext ctx, DocIdSetIterator docIds) throws IOException; + LeafReader reader, DocIdSetIterator docIds) throws IOException; private BitSet createBitSet(DocIdSetIterator iterator, Bits 
liveDocs, int maxDoc) throws IOException { diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index c11c63a39ecd..55ca449a28af 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -153,14 +153,14 @@ public byte[] getTargetCopy() { /** * Returns a new iterator that maps the provided docIds to the vector ordinals. * - *

    This method assumes that all docIds have corresponding ordinals. + *

    This method assumes that all docIds have corresponding . * * @lucene.internal * @lucene.experimental */ @Override protected DocIdSetIterator convertDocIdsToVectorOrdinals( - LeafReaderContext ctx, DocIdSetIterator docIds) throws IOException { - return ctx.reader().getByteVectorValues(field).convertDocIdsToVectorOrdinals(docIds); + LeafReader reader, DocIdSetIterator docIds) throws IOException { + return reader.getByteVectorValues(field).convertDocIdsToVectorOrdinals(docIds); } } diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index 4f9ffc265e06..7d0052d1325b 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -163,7 +163,7 @@ public float[] getTargetCopy() { */ @Override protected DocIdSetIterator convertDocIdsToVectorOrdinals( - LeafReaderContext ctx, DocIdSetIterator docIds) throws IOException { - return ctx.reader().getFloatVectorValues(field).convertDocIdsToVectorOrdinals(docIds); + LeafReader reader, DocIdSetIterator docIds) throws IOException { + return reader.getFloatVectorValues(field).convertDocIdsToVectorOrdinals(docIds); } } diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 842a287d5630..1005500dfe91 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -73,8 +73,8 @@ public static void search( } /** - * Searches the HNSW graph for the nearest neighbors of a query vector, starting from the - * provided entry points. + * Searches the HNSW graph for the nearest neighbors of a query vector, starting from the provided + * entry points. 
* * @param scorer the scorer to compare the query with the nodes * @param knnCollector a collector of top knn results to be returned @@ -104,9 +104,12 @@ public static void search( } else { HnswGraphSearcher graphSearcher = new HnswGraphSearcher( - new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(getGraphSize(graph))); + new NeighborQueue(knnCollector.k(), true), + new SparseFixedBitSet(getGraphSize(graph))); int[] entryPointOrdIntsArr = entryPointOrdInts.stream().mapToInt(Integer::intValue).toArray(); - graphSearcher.searchLevel(knnCollector, scorer, 0, entryPointOrdIntsArr, graph, acceptOrds); + // We use provided entry point ordinals to search the complete graph (level 0) + graphSearcher.searchLevel( + knnCollector, scorer, 0 /* level */, entryPointOrdIntsArr, graph, acceptOrds); } } From 1f8a9f4996705ff9b006dca05bdd6dc74a89c595 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Thu, 26 Sep 2024 11:15:18 +0100 Subject: [PATCH 24/41] merge issues --- .../org/apache/lucene/search/AbstractKnnVectorQuery.java | 6 +++++- .../apache/lucene/search/BaseKnnVectorQueryTestCase.java | 6 +++--- .../lucene/tests/index/BaseKnnVectorsFormatTestCase.java | 2 +- 3 files changed, 9 insertions(+), 5 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index bca3871c2f7f..337a4bdde887 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -213,7 +213,11 @@ private DocIdSetIterator executeSeedQuery(LeafReaderContext ctx, Weight seedWeig if (leafCollector != null) { BulkScorer scorer = seedWeight.bulkScorer(ctx); if (scorer != null) { - scorer.score(leafCollector, ctx.reader().getLiveDocs()); + scorer.score( + leafCollector, + ctx.reader().getLiveDocs(), + 0 /* min */, + DocIdSetIterator.NO_MORE_DOCS /* max */); } 
leafCollector.finish(); } diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index 7dd82dc61c49..a5325a009d09 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -656,7 +656,7 @@ public void testRandomWithSeed() throws IOException { // test that assert reader.hasDeletions() == false; assertEquals(expected, results.scoreDocs.length); - assertTrue(results.totalHits.value >= results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); // verify the results are in descending score order float last = Float.MAX_VALUE; for (ScoreDoc scoreDoc : results.scoreDocs) { @@ -673,7 +673,7 @@ public void testRandomWithSeed() throws IOException { // test that assert reader.hasDeletions() == false; assertEquals(expected, results.scoreDocs.length); - assertTrue(results.totalHits.value >= results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); // verify the results are in descending score order last = Float.MAX_VALUE; for (ScoreDoc scoreDoc : results.scoreDocs) { @@ -690,7 +690,7 @@ public void testRandomWithSeed() throws IOException { // test that assert reader.hasDeletions() == false; assertEquals(expected, results.scoreDocs.length); - assertTrue(results.totalHits.value >= results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); // verify the results are in descending score order last = Float.MAX_VALUE; for (ScoreDoc scoreDoc : results.scoreDocs) { diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index 81a27217f247..fe5f084c4963 100644 --- 
a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -1466,7 +1466,7 @@ public void testSearchWithVisitedLimit() throws Exception { null, visitedLimit); assertEquals(TotalHits.Relation.EQUAL_TO, results.totalHits.relation()); - assertTrue(results.totalHits.value <= visitedLimit()); + assertTrue(results.totalHits.value() <= visitedLimit); } } } From 58f34df7da180258b784523e1f3bc15358f98eb7 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Fri, 27 Sep 2024 11:01:14 +0100 Subject: [PATCH 25/41] addresses review comments --- .../lucene/search/AbstractKnnVectorQuery.java | 60 +++++++++---------- .../lucene/search/KnnByteVectorQuery.java | 2 +- .../tests/index/MismatchedLeafReader.java | 25 -------- 3 files changed, 28 insertions(+), 59 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index 337a4bdde887..e6ff0a931851 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -198,42 +198,36 @@ private TopDocs getLeafResults( private DocIdSetIterator executeSeedQuery(LeafReaderContext ctx, Weight seedWeight) throws IOException { - if (seedWeight != null) { - // Execute the seed query - TopScoreDocCollector seedCollector = - new TopScoreDocCollectorManager( - k /* numHits */, - null /* after */, - Integer.MAX_VALUE /* totalHitsThreshold */, - false /* supportsConcurrency */) - .newCollector(); - LeafCollector leafCollector; - try { - leafCollector = seedCollector.getLeafCollector(ctx); - if (leafCollector != null) { - BulkScorer scorer = seedWeight.bulkScorer(ctx); - if (scorer != null) { - scorer.score( - leafCollector, - ctx.reader().getLiveDocs(), - 0 /* min */, - 
DocIdSetIterator.NO_MORE_DOCS /* max */); - } - leafCollector.finish(); + if (seedWeight == null) return null; + // Execute the seed query + TopScoreDocCollector seedCollector = + new TopScoreDocCollectorManager( + k /* numHits */, + null /* after */, + Integer.MAX_VALUE /* totalHitsThreshold */, + false /* supportsConcurrency */) + .newCollector(); + final LeafReader leafReader = ctx.reader(); + try { + final LeafCollector leafCollector = seedCollector.getLeafCollector(ctx); + if (leafCollector != null) { + BulkScorer scorer = seedWeight.bulkScorer(ctx); + if (scorer != null) { + scorer.score( + leafCollector, + leafReader.getLiveDocs(), + 0 /* min */, + DocIdSetIterator.NO_MORE_DOCS /* max */); } - } catch ( - @SuppressWarnings("unused") - CollectionTerminatedException e) { - // there is no doc of interest in this reader context - // continue with the following leaf - leafCollector = null; + leafCollector.finish(); } - - TopDocs seedTopDocs = seedCollector.topDocs(); - return convertDocIdsToVectorOrdinals(ctx.reader(), new TopDocsDISI(seedTopDocs)); - } else { - return null; + } catch ( + @SuppressWarnings("unused") + CollectionTerminatedException e) { } + + TopDocs seedTopDocs = seedCollector.topDocs(); + return convertDocIdsToVectorOrdinals(leafReader, new TopDocsDISI(seedTopDocs)); } /** diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index 55ca449a28af..94f72efcc8c9 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -153,7 +153,7 @@ public byte[] getTargetCopy() { /** * Returns a new iterator that maps the provided docIds to the vector ordinals. * - *

    This method assumes that all docIds have corresponding . + *

    This method assumes that all docIds have corresponding ordinals. * * @lucene.internal * @lucene.experimental diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/MismatchedLeafReader.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/MismatchedLeafReader.java index 87a2985a4f02..46404f514c6a 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/MismatchedLeafReader.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/MismatchedLeafReader.java @@ -28,9 +28,6 @@ import org.apache.lucene.index.LeafReader; import org.apache.lucene.index.StoredFieldVisitor; import org.apache.lucene.index.StoredFields; -import org.apache.lucene.search.DocIdSetIterator; -import org.apache.lucene.search.KnnCollector; -import org.apache.lucene.util.Bits; /** * Shuffles field numbers around to try to trip bugs where field numbers are assumed to always be @@ -71,28 +68,6 @@ public CacheHelper getReaderCacheHelper() { return in.getReaderCacheHelper(); } - @Override - public void searchNearestVectors( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) - throws IOException { - in.searchNearestVectors(field, target, knnCollector, acceptDocs, seedDocs); - } - - @Override - public void searchNearestVectors( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) - throws IOException { - in.searchNearestVectors(field, target, knnCollector, acceptDocs, seedDocs); - } - static FieldInfos shuffleInfos(FieldInfos infos, Random random) { // first, shuffle the order List shuffled = new ArrayList<>(); From fc2129f7049917c5e5cfc174e02f680ed4d3b033 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Wed, 2 Oct 2024 10:30:02 +0100 Subject: [PATCH 26/41] refactor wip --- .../lucene90/Lucene90HnswVectorsReader.java | 14 +--- .../lucene91/Lucene91HnswVectorsReader.java | 17 +---- 
.../lucene92/Lucene92HnswVectorsReader.java | 15 +--- .../lucene94/Lucene94HnswVectorsReader.java | 21 +----- .../lucene95/Lucene95HnswVectorsReader.java | 21 +----- .../SimpleTextKnnVectorsReader.java | 14 +--- .../bitvectors/TestHnswBitVectorsFormat.java | 2 +- .../lucene/codecs/KnnVectorsFormat.java | 13 +--- .../lucene/codecs/KnnVectorsReader.java | 19 +---- .../lucene/codecs/hnsw/FlatVectorsReader.java | 15 +--- .../lucene99/Lucene99HnswVectorsReader.java | 15 +--- .../perfield/PerFieldKnnVectorsFormat.java | 19 +---- .../org/apache/lucene/index/CheckIndex.java | 6 +- .../org/apache/lucene/index/CodecReader.java | 19 +---- .../lucene/index/DocValuesLeafReader.java | 15 +--- .../lucene/index/ExitableDirectoryReader.java | 16 +--- .../apache/lucene/index/FilterLeafReader.java | 19 +---- .../org/apache/lucene/index/LeafReader.java | 43 ++--------- .../lucene/index/ParallelLeafReader.java | 17 +---- .../lucene/index/SlowCodecReaderWrapper.java | 19 +---- .../SlowCompositeCodecReaderWrapper.java | 14 +--- .../lucene/index/SortingCodecReader.java | 15 +--- .../search/ByteVectorSimilarityQuery.java | 2 +- .../search/FloatVectorSimilarityQuery.java | 2 +- .../lucene/search/KnnByteVectorQuery.java | 9 ++- .../apache/lucene/search/KnnCollector.java | 5 ++ .../lucene/search/KnnFloatVectorQuery.java | 9 ++- .../lucene/search/SeededKnnCollector.java | 73 +++++++++++++++++++ .../lucene/util/hnsw/HnswGraphSearcher.java | 40 ++-------- ...estLucene99HnswQuantizedVectorsFormat.java | 2 +- ...stLucene99ScalarQuantizedVectorScorer.java | 2 +- ...tLucene99ScalarQuantizedVectorsFormat.java | 2 +- .../TestPerFieldKnnVectorsFormat.java | 7 +- .../index/TestExitableDirectoryReader.java | 4 - .../org/apache/lucene/index/TestKnnGraph.java | 2 +- .../index/TestSegmentToThreadMapping.java | 13 +--- .../highlight/TermVectorLeafReader.java | 13 +--- ...iversifyingChildrenByteKnnVectorQuery.java | 2 +- ...versifyingChildrenFloatKnnVectorQuery.java | 2 +- 
.../lucene/index/memory/MemoryIndex.java | 12 +-- .../asserting/AssertingKnnVectorsFormat.java | 19 +---- .../index/BaseKnnVectorsFormatTestCase.java | 28 +------ .../tests/index/MergeReaderWrapper.java | 19 +---- .../lucene/tests/search/QueryUtils.java | 12 +-- 44 files changed, 193 insertions(+), 454 deletions(-) create mode 100644 lucene/core/src/java/org/apache/lucene/search/SeededKnnCollector.java diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java index 9e1753bc6465..665d31403214 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene90/Lucene90HnswVectorsReader.java @@ -236,12 +236,7 @@ public ByteVectorValues getByteVectorValues(String field) { } @Override - public void search( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); @@ -273,12 +268,7 @@ public void search( } @Override - public void search( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index 58931c9d7976..d73e194a24dc 100644 --- 
a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -230,12 +230,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); @@ -251,17 +246,11 @@ public void search( scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), - getAcceptOrds(acceptDocs, fieldEntry), - seedDocs); + getAcceptOrds(acceptDocs, fieldEntry)); } @Override - public void search( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocsBits, - DocIdSetIterator seedDocs) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocsBits) throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java index aa285ce8df19..39fe109a9f13 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/Lucene92HnswVectorsReader.java @@ -34,7 +34,6 @@ import org.apache.lucene.index.IndexFileNames; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import 
org.apache.lucene.store.DataInput; @@ -228,12 +227,7 @@ public ByteVectorValues getByteVectorValues(String field) { } @Override - public void search( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); @@ -253,12 +247,7 @@ public void search( } @Override - public void search( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java index c514f34f01d0..d5beae1e6811 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/Lucene94HnswVectorsReader.java @@ -35,7 +35,6 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; @@ -268,12 +267,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { FieldEntry fieldEntry 
= fields.get(field); @@ -289,17 +283,11 @@ public void search( scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs), - seedDocs); + vectorValues.getAcceptOrds(acceptDocs)); } @Override - public void search( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); @@ -315,8 +303,7 @@ public void search( scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs), - seedDocs); + vectorValues.getAcceptOrds(acceptDocs)); } private HnswGraph getGraph(FieldEntry entry) throws IOException { diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java index f989b2e54d5d..2e6714d6eb8e 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene95/Lucene95HnswVectorsReader.java @@ -39,7 +39,6 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; @@ -292,12 +291,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) + public void search(String 
field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); @@ -324,17 +318,11 @@ public void search( scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs), - seedDocs); + vectorValues.getAcceptOrds(acceptDocs)); } @Override - public void search( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { FieldEntry fieldEntry = fields.get(field); @@ -361,8 +349,7 @@ public void search( scorer, new OrdinalTranslatedKnnCollector(knnCollector, vectorValues::ordToDoc), getGraph(fieldEntry), - vectorValues.getAcceptOrds(acceptDocs), - seedDocs); + vectorValues.getAcceptOrds(acceptDocs)); } /** Get knn graph values; used for testing */ diff --git a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java index 74f3a15b4db6..faba629715b7 100644 --- a/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java +++ b/lucene/codecs/src/java/org/apache/lucene/codecs/simpletext/SimpleTextKnnVectorsReader.java @@ -180,12 +180,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { FloatVectorValues values = getFloatVectorValues(field); if (target.length != values.dimension()) { @@ -215,12 +210,7 @@ public void search( } @Override - public void search( - String field, - byte[] target, - KnnCollector knnCollector, - Bits 
acceptDocs, - DocIdSetIterator seedDocs) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { ByteVectorValues values = getByteVectorValues(field); if (target.length != values.dimension()) { diff --git a/lucene/codecs/src/test/org/apache/lucene/codecs/bitvectors/TestHnswBitVectorsFormat.java b/lucene/codecs/src/test/org/apache/lucene/codecs/bitvectors/TestHnswBitVectorsFormat.java index ec2bdecb1eb9..ab20ee67c8c9 100644 --- a/lucene/codecs/src/test/org/apache/lucene/codecs/bitvectors/TestHnswBitVectorsFormat.java +++ b/lucene/codecs/src/test/org/apache/lucene/codecs/bitvectors/TestHnswBitVectorsFormat.java @@ -89,7 +89,7 @@ public void testIndexAndSearchBitVectors() throws IOException { try (IndexReader reader = DirectoryReader.open(w)) { LeafReader r = getOnlyLeafReader(reader); TopKnnCollector collector = new TopKnnCollector(3, Integer.MAX_VALUE); - r.searchNearestVectors("v1", vectors[0], collector, null, null); + r.searchNearestVectors("v1", vectors[0], collector, null); TopDocs topDocs = collector.topDocs(); assertEquals(3, topDocs.scoreDocs.length); diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java index 51dd9a83f424..ad6e4aba607c 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsFormat.java @@ -23,7 +23,6 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; import org.apache.lucene.util.NamedSPILoader; @@ -139,21 +138,13 @@ public ByteVectorValues getByteVectorValues(String field) { @Override public void search( - String field, - float[] target, - KnnCollector knnCollector, - 
Bits acceptDocs, - DocIdSetIterator seedDocs) { + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) { throw new UnsupportedOperationException(); } @Override public void search( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) { + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java index b93e7b66e66b..e054ebeb2bb1 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/KnnVectorsReader.java @@ -22,7 +22,6 @@ import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; @@ -83,16 +82,9 @@ protected KnnVectorsReader() {} * @param knnCollector a KnnResults collector and relevant settings for gathering vector results * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} * if they are all allowed to match. - * @param seedDocs {@link DocIdSetIterator} that represents the documents used to seed the search, - * or {@code null} to perform a search without seeds. 
*/ public abstract void search( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) - throws IOException; + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException; /** * Return the k nearest neighbor documents as determined by comparison of their vector values for @@ -118,16 +110,9 @@ public abstract void search( * @param knnCollector a KnnResults collector and relevant settings for gathering vector results * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} * if they are all allowed to match. - * @param seedDocs {@link DocIdSetIterator} that represents the documents used to seed the search, - * or {@code null} to perform a search without seeds. */ public abstract void search( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) - throws IOException; + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException; /** * Returns an instance optimized for merging. 
This instance may only be consumed in the thread diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsReader.java index 09243897efd5..8d6f66c54226 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsReader.java @@ -19,7 +19,6 @@ import java.io.IOException; import org.apache.lucene.codecs.KnnVectorsReader; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Accountable; import org.apache.lucene.util.Bits; @@ -57,23 +56,13 @@ public FlatVectorsScorer getFlatVectorScorer() { } @Override - public void search( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { // don't scan stored field data. If we didn't index it, produce no search results } @Override - public void search( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDoc, - DocIdSetIterator seedDocs) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDoc) throws IOException { // don't scan stored field data. 
If we didn't index it, produce no search results } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java index 88eca243b364..f27a826e9c35 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene99/Lucene99HnswVectorsReader.java @@ -37,7 +37,6 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.DataInput; @@ -248,12 +247,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seeds) + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { search( fields.get(field), @@ -264,12 +258,7 @@ public void search( } @Override - public void search( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seeds) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { search( fields.get(field), diff --git a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java index 0c1753bce2e9..5dc4db8db6a8 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/perfield/PerFieldKnnVectorsFormat.java @@ -33,7 +33,6 @@ import 
org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; import org.apache.lucene.util.IOUtils; @@ -271,25 +270,15 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - fields.get(field).search(field, target, knnCollector, acceptDocs, seedDocs); + fields.get(field).search(field, target, knnCollector, acceptDocs); } @Override - public void search( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - fields.get(field).search(field, target, knnCollector, acceptDocs, seedDocs); + fields.get(field).search(field, target, knnCollector, acceptDocs); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java index dd8b7f74443e..b8256ecf5875 100644 --- a/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java +++ b/lucene/core/src/java/org/apache/lucene/index/CheckIndex.java @@ -2769,7 +2769,7 @@ private static void checkFloatVectorValues( if (vectorsReaderSupportsSearch(codecReader, fieldInfo.name)) { codecReader .getVectorReader() - .search(fieldInfo.name, values.vectorValue(), collector, null, null); + .search(fieldInfo.name, values.vectorValue(), collector, null); TopDocs docs = collector.topDocs(); if (docs.scoreDocs.length == 0) { throw new CheckIndexException( @@ -2815,9 +2815,7 @@ private static void 
checkByteVectorValues( // search the first maxNumSearches vectors to exercise the graph if (supportsSearch && values.docID() % everyNdoc == 0) { KnnCollector collector = new TopKnnCollector(10, Integer.MAX_VALUE); - codecReader - .getVectorReader() - .search(fieldInfo.name, values.vectorValue(), collector, null, null); + codecReader.getVectorReader().search(fieldInfo.name, values.vectorValue(), collector, null); TopDocs docs = collector.topDocs(); if (docs.scoreDocs.length == 0) { throw new CheckIndexException( diff --git a/lucene/core/src/java/org/apache/lucene/index/CodecReader.java b/lucene/core/src/java/org/apache/lucene/index/CodecReader.java index a6f78c13dce1..20be7e1a45a8 100644 --- a/lucene/core/src/java/org/apache/lucene/index/CodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/CodecReader.java @@ -25,7 +25,6 @@ import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.TermVectorsReader; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; @@ -261,12 +260,7 @@ public final ByteVectorValues getByteVectorValues(String field) throws IOExcepti @Override public final void searchNearestVectors( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) - throws IOException { + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { ensureOpen(); FieldInfo fi = getFieldInfos().fieldInfo(field); if (fi == null @@ -275,17 +269,12 @@ public final void searchNearestVectors( // Field does not exist or does not index vectors return; } - getVectorReader().search(field, target, knnCollector, acceptDocs, seedDocs); + getVectorReader().search(field, target, knnCollector, acceptDocs); } @Override public final void searchNearestVectors( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - 
DocIdSetIterator seedDocs) - throws IOException { + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { ensureOpen(); FieldInfo fi = getFieldInfos().fieldInfo(field); if (fi == null @@ -294,7 +283,7 @@ public final void searchNearestVectors( // Field does not exist or does not index vectors return; } - getVectorReader().search(field, target, knnCollector, acceptDocs, seedDocs); + getVectorReader().search(field, target, knnCollector, acceptDocs); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java index 717c0e64d0ba..3504c7429a5e 100644 --- a/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/DocValuesLeafReader.java @@ -18,7 +18,6 @@ package org.apache.lucene.index; import java.io.IOException; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; @@ -60,23 +59,13 @@ public final ByteVectorValues getByteVectorValues(String field) throws IOExcepti @Override public void searchNearestVectors( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) - throws IOException { + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { throw new UnsupportedOperationException(); } @Override public void searchNearestVectors( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) - throws IOException { + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java index 53024ba41658..ca2cb1a27d45 
100644 --- a/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/ExitableDirectoryReader.java @@ -333,11 +333,7 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { @Override public void searchNearestVectors( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { // when acceptDocs is null due to no doc deleted, we will instantiate a new one that would @@ -365,16 +361,12 @@ public int length() { } }; - in.searchNearestVectors(field, target, knnCollector, timeoutCheckingAcceptDocs, seedDocs); + in.searchNearestVectors(field, target, knnCollector, timeoutCheckingAcceptDocs); } @Override public void searchNearestVectors( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { // when acceptDocs is null due to no doc deleted, we will instantiate a new one that would // match all docs to allow timeout checking. 
@@ -401,7 +393,7 @@ public int length() { } }; - in.searchNearestVectors(field, target, knnCollector, timeoutCheckingAcceptDocs, seedDocs); + in.searchNearestVectors(field, target, knnCollector, timeoutCheckingAcceptDocs); } private void checkAndThrowForSearchVectors() { diff --git a/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java index c6996cb96259..87d62f22d041 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/FilterLeafReader.java @@ -18,7 +18,6 @@ import java.io.IOException; import java.util.Iterator; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.AttributeSource; import org.apache.lucene.util.Bits; @@ -366,24 +365,14 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { @Override public void searchNearestVectors( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) - throws IOException { - in.searchNearestVectors(field, target, knnCollector, acceptDocs, seedDocs); + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + in.searchNearestVectors(field, target, knnCollector, acceptDocs); } @Override public void searchNearestVectors( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) - throws IOException { - in.searchNearestVectors(field, target, knnCollector, acceptDocs, seedDocs); + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + in.searchNearestVectors(field, target, knnCollector, acceptDocs); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/LeafReader.java b/lucene/core/src/java/org/apache/lucene/index/LeafReader.java index d0445c575ac4..0f39d1ae1e8d 100644 --- 
a/lucene/core/src/java/org/apache/lucene/index/LeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/LeafReader.java @@ -17,7 +17,6 @@ package org.apache.lucene.index; import java.io.IOException; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; @@ -247,20 +246,12 @@ public final PostingsEnum postings(Term term) throws IOException { * @param k the number of docs to return * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} * if they are all allowed to match. - * @param seedDocs candidate documents to seed the KNN search, or {@code null} to search without - * using seeds. * @param visitedLimit the maximum number of nodes that the search is allowed to visit * @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores. * @lucene.experimental */ public final TopDocs searchNearestVectors( - String field, - float[] target, - int k, - Bits acceptDocs, - DocIdSetIterator seedDocs, - int visitedLimit) - throws IOException { + String field, float[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { FieldInfo fi = getFieldInfos().fieldInfo(field); if (fi == null || fi.getVectorDimension() == 0) { return TopDocsCollector.EMPTY_TOPDOCS; @@ -274,7 +265,7 @@ public final TopDocs searchNearestVectors( return TopDocsCollector.EMPTY_TOPDOCS; } KnnCollector collector = new TopKnnCollector(k, visitedLimit); - searchNearestVectors(field, target, collector, acceptDocs, seedDocs); + searchNearestVectors(field, target, collector, acceptDocs); return collector.topDocs(); } @@ -299,20 +290,12 @@ public final TopDocs searchNearestVectors( * @param k the number of docs to return * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} * if they are all allowed to match. 
- * @param seedDocs candidate documents to seed the KNN search, or {@code null} to search without - * using seeds. * @param visitedLimit the maximum number of nodes that the search is allowed to visit * @return the k nearest neighbor documents, along with their (searchStrategy-specific) scores. * @lucene.experimental */ public final TopDocs searchNearestVectors( - String field, - byte[] target, - int k, - Bits acceptDocs, - DocIdSetIterator seedDocs, - int visitedLimit) - throws IOException { + String field, byte[] target, int k, Bits acceptDocs, int visitedLimit) throws IOException { FieldInfo fi = getFieldInfos().fieldInfo(field); if (fi == null || fi.getVectorDimension() == 0) { return TopDocsCollector.EMPTY_TOPDOCS; @@ -326,7 +309,7 @@ public final TopDocs searchNearestVectors( return TopDocsCollector.EMPTY_TOPDOCS; } KnnCollector collector = new TopKnnCollector(k, visitedLimit); - searchNearestVectors(field, target, collector, acceptDocs, seedDocs); + searchNearestVectors(field, target, collector, acceptDocs); return collector.topDocs(); } @@ -354,17 +337,10 @@ public final TopDocs searchNearestVectors( * @param knnCollector collector with settings for gathering the vector results. * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} * if they are all allowed to match. - * @param seedDocs {@link Bits} that represents an initial set of documents to seed the search, or - * {@code null} if a full search is to be conducted. 
* @lucene.experimental */ public abstract void searchNearestVectors( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) - throws IOException; + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException; /** * Return the k nearest neighbor documents as determined by comparison of their vector values for @@ -390,17 +366,10 @@ public abstract void searchNearestVectors( * @param knnCollector collector with settings for gathering the vector results. * @param acceptDocs {@link Bits} that represents the allowed documents to match, or {@code null} * if they are all allowed to match. - * @param seedDocs {@link Bits} that represents an initial set of documents to seed the search, or - * {@code null} if a full search is to be conducted. * @lucene.experimental */ public abstract void searchNearestVectors( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) - throws IOException; + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException; /** * Get the {@link FieldInfos} describing all fields in this reader. 
diff --git a/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java b/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java index c273f8418a7a..c3ace74fb5b6 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/ParallelLeafReader.java @@ -26,7 +26,6 @@ import java.util.Set; import java.util.SortedMap; import java.util.TreeMap; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.Sort; import org.apache.lucene.util.Bits; @@ -457,31 +456,23 @@ public ByteVectorValues getByteVectorValues(String fieldName) throws IOException @Override public void searchNearestVectors( - String fieldName, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) + String fieldName, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { ensureOpen(); LeafReader reader = fieldToReader.get(fieldName); if (reader != null) { - reader.searchNearestVectors(fieldName, target, knnCollector, acceptDocs, seedDocs); + reader.searchNearestVectors(fieldName, target, knnCollector, acceptDocs); } } @Override public void searchNearestVectors( - String fieldName, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) + String fieldName, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { ensureOpen(); LeafReader reader = fieldToReader.get(fieldName); if (reader != null) { - reader.searchNearestVectors(fieldName, target, knnCollector, acceptDocs, seedDocs); + reader.searchNearestVectors(fieldName, target, knnCollector, acceptDocs); } } diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java index 68d15f8f7d58..57836ae482fd 100644 --- 
a/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCodecReaderWrapper.java @@ -28,7 +28,6 @@ import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.TermVectorsReader; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; @@ -174,25 +173,15 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - reader.searchNearestVectors(field, target, knnCollector, acceptDocs, seedDocs); + reader.searchNearestVectors(field, target, knnCollector, acceptDocs); } @Override - public void search( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { - reader.searchNearestVectors(field, target, knnCollector, acceptDocs, seedDocs); + reader.searchNearestVectors(field, target, knnCollector, acceptDocs); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java index 21119c0d34ad..b2f7f21fb7ed 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java +++ b/lucene/core/src/java/org/apache/lucene/index/SlowCompositeCodecReaderWrapper.java @@ -959,23 +959,13 @@ public VectorScorer scorer(byte[] target) { } @Override - public void search( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seeds) + 
public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { throw new UnsupportedOperationException(); } @Override - public void search( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seeds) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java index c36ec2fbb9da..fee0fc2f7309 100644 --- a/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java +++ b/lucene/core/src/java/org/apache/lucene/index/SortingCodecReader.java @@ -32,7 +32,6 @@ import org.apache.lucene.codecs.PointsReader; import org.apache.lucene.codecs.StoredFieldsReader; import org.apache.lucene.codecs.TermVectorsReader; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.Sort; import org.apache.lucene.search.SortField; @@ -513,22 +512,12 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) { + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) { throw new UnsupportedOperationException(); } @Override - public void search( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) { + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java index 
9ff12994103f..c547f1face7b 100644 --- a/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/ByteVectorSimilarityQuery.java @@ -114,7 +114,7 @@ protected TopDocs approximateSearch( KnnCollectorManager knnCollectorManager) throws IOException { KnnCollector collector = knnCollectorManager.newCollector(visitLimit, context); - context.reader().searchNearestVectors(field, target, collector, acceptDocs, null); + context.reader().searchNearestVectors(field, target, collector, acceptDocs); return collector.topDocs(); } diff --git a/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java index ebbd30e01f5d..4c7078ac1404 100644 --- a/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/FloatVectorSimilarityQuery.java @@ -116,7 +116,7 @@ protected TopDocs approximateSearch( KnnCollectorManager knnCollectorManager) throws IOException { KnnCollector collector = knnCollectorManager.newCollector(visitLimit, context); - context.reader().searchNearestVectors(field, target, collector, acceptDocs, null); + context.reader().searchNearestVectors(field, target, collector, acceptDocs); return collector.topDocs(); } diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index 94f72efcc8c9..35e8bcd2a955 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -30,8 +30,8 @@ import org.apache.lucene.util.Bits; /** - * Uses {@link KnnVectorsReader#search(String, byte[], KnnCollector, Bits, DocIdSetIterator)} to - * perform nearest neighbour search. 
+ * Uses {@link KnnVectorsReader#search(String, byte[], KnnCollector, Bits)} to perform nearest + * neighbour search. * *

    This query also allows for performing a kNN search subject to a filter. In this case, it first * executes the filter for each leaf, then chooses a strategy dynamically: @@ -100,6 +100,9 @@ protected TopDocs approximateSearch( KnnCollectorManager knnCollectorManager) throws IOException { KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context); + if (seedDocs != null) { + knnCollector = new SeededKnnCollector(knnCollector, seedDocs); + } LeafReader reader = context.reader(); ByteVectorValues byteVectorValues = reader.getByteVectorValues(field); if (byteVectorValues == null) { @@ -109,7 +112,7 @@ protected TopDocs approximateSearch( if (Math.min(knnCollector.k(), byteVectorValues.size()) == 0) { return NO_RESULTS; } - reader.searchNearestVectors(field, target, knnCollector, acceptDocs, seedDocs); + reader.searchNearestVectors(field, target, knnCollector, acceptDocs); TopDocs results = knnCollector.topDocs(); return results != null ? results : NO_RESULTS; } diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java index 43bac9fbc309..b0aa2e5ce9f3 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java @@ -85,4 +85,9 @@ public interface KnnCollector { * @return The collected top documents */ TopDocs topDocs(); + + /** TODO */ + default DocIdSetIterator getSeedEntryPoints() { + return null; + } } diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index 7d0052d1325b..b154421e7a4b 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -31,8 +31,8 @@ import org.apache.lucene.util.VectorUtil; /** - * Uses {@link KnnVectorsReader#search(String, float[], 
KnnCollector, Bits, DocIdSetIterator)} to - * perform nearest neighbour search. + * Uses {@link KnnVectorsReader#search(String, float[], KnnCollector, Bits)} to perform nearest + * neighbour search. * *

    This query also allows for performing a kNN search subject to a filter. In this case, it first * executes the filter for each leaf, then chooses a strategy dynamically: @@ -101,6 +101,9 @@ protected TopDocs approximateSearch( KnnCollectorManager knnCollectorManager) throws IOException { KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context); + if (seedDocs != null) { + knnCollector = new SeededKnnCollector(knnCollector, seedDocs); + } LeafReader reader = context.reader(); FloatVectorValues floatVectorValues = reader.getFloatVectorValues(field); if (floatVectorValues == null) { @@ -110,7 +113,7 @@ protected TopDocs approximateSearch( if (Math.min(knnCollector.k(), floatVectorValues.size()) == 0) { return NO_RESULTS; } - reader.searchNearestVectors(field, target, knnCollector, acceptDocs, seedDocs); + reader.searchNearestVectors(field, target, knnCollector, acceptDocs); TopDocs results = knnCollector.topDocs(); return results != null ? results : NO_RESULTS; } diff --git a/lucene/core/src/java/org/apache/lucene/search/SeededKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/SeededKnnCollector.java new file mode 100644 index 000000000000..319ff0bc3e5b --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/SeededKnnCollector.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +/** A {@link KnnCollector} that provides an initial set of seeds to initialize the search. */ +class SeededKnnCollector implements KnnCollector { + private final KnnCollector collector; + private final DocIdSetIterator seedEntryPoints; + + public SeededKnnCollector(KnnCollector collector, DocIdSetIterator seedEntryPoints) { + this.collector = collector; + this.seedEntryPoints = seedEntryPoints; + } + + @Override + public boolean earlyTerminated() { + return collector.earlyTerminated(); + } + + @Override + public void incVisitedCount(int count) { + collector.incVisitedCount(count); + } + + @Override + public long visitedCount() { + return collector.visitedCount(); + } + + @Override + public long visitLimit() { + return collector.visitLimit(); + } + + @Override + public int k() { + return collector.k(); + } + + @Override + public boolean collect(int docId, float similarity) { + return collector.collect(docId, similarity); + } + + @Override + public float minCompetitiveSimilarity() { + return collector.minCompetitiveSimilarity(); + } + + @Override + public TopDocs topDocs() { + return collector.topDocs(); + } + + @Override + public DocIdSetIterator getSeedEntryPoints() { + return seedEntryPoints; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index 1005500dfe91..b3f400372b9d 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ 
b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -53,25 +53,6 @@ public HnswGraphSearcher(NeighborQueue candidates, BitSet visited) { this.visited = visited; } - /** - * Searches HNSW graph for the nearest neighbors of a query vector. - * - * @param scorer the scorer to compare the query with the nodes - * @param knnCollector a collector of top knn results to be returned - * @param graph the graph values. May represent the entire graph, or a level in a hierarchical - * graph. - * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or - * {@code null} if they are all allowed to match. - */ - public static void search( - RandomVectorScorer scorer, KnnCollector knnCollector, HnswGraph graph, Bits acceptOrds) - throws IOException { - HnswGraphSearcher graphSearcher = - new HnswGraphSearcher( - new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(getGraphSize(graph))); - search(scorer, knnCollector, graph, graphSearcher, acceptOrds); - } - /** * Searches the HNSW graph for the nearest neighbors of a query vector, starting from the provided * entry points. @@ -82,30 +63,25 @@ public static void search( * graph. * @param acceptOrds {@link Bits} that represents the allowed document ordinals to match, or * {@code null} if they are all allowed to match. - * @param entryPointOrds the entry points for search. 
*/ public static void search( - RandomVectorScorer scorer, - KnnCollector knnCollector, - HnswGraph graph, - Bits acceptOrds, - DocIdSetIterator entryPointOrds) + RandomVectorScorer scorer, KnnCollector knnCollector, HnswGraph graph, Bits acceptOrds) throws IOException { ArrayList entryPointOrdInts = null; - if (entryPointOrds != null) { + DocIdSetIterator entryPoints = knnCollector.getSeedEntryPoints(); + if (entryPoints != null) { entryPointOrdInts = new ArrayList(); int entryPointOrdInt; - while ((entryPointOrdInt = entryPointOrds.nextDoc()) != NO_MORE_DOCS) { + while ((entryPointOrdInt = entryPoints.nextDoc()) != NO_MORE_DOCS) { entryPointOrdInts.add(entryPointOrdInt); } } + HnswGraphSearcher graphSearcher = + new HnswGraphSearcher( + new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(getGraphSize(graph))); if (entryPointOrdInts == null || entryPointOrdInts.isEmpty()) { - search(scorer, knnCollector, graph, acceptOrds); + search(scorer, knnCollector, graph, graphSearcher, acceptOrds); } else { - HnswGraphSearcher graphSearcher = - new HnswGraphSearcher( - new NeighborQueue(knnCollector.k(), true), - new SparseFixedBitSet(getGraphSize(graph))); int[] entryPointOrdIntsArr = entryPointOrdInts.stream().mapToInt(Integer::intValue).toArray(); // We use provided entry point ordinals to search the complete graph (level 0) graphSearcher.searchLevel( diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java index d053ab212cdb..825de3ab725a 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99HnswQuantizedVectorsFormat.java @@ -234,7 +234,7 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { try (IndexReader reader = DirectoryReader.open(w)) { LeafReader r = 
getOnlyLeafReader(reader); TopKnnCollector topKnnCollector = new TopKnnCollector(5, Integer.MAX_VALUE); - r.searchNearestVectors("f", new float[] {0.6f, 0.8f}, topKnnCollector, null, null); + r.searchNearestVectors("f", new float[] {0.6f, 0.8f}, topKnnCollector, null); TopDocs topDocs = topKnnCollector.topDocs(); assertEquals(3, topDocs.totalHits.value()); for (ScoreDoc scoreDoc : topDocs.scoreDocs) { diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java index 286026fe960c..a0f640fa650b 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorScorer.java @@ -299,7 +299,7 @@ private void testSingleVectorPerSegment(VectorSimilarityFunction sim) throws IOE LeafReader leafReader = getOnlyLeafReader(reader); StoredFields storedFields = reader.storedFields(); float[] queryVector = new float[] {0.6f, 0.8f}; - var hits = leafReader.searchNearestVectors("field", queryVector, 3, null, null, 100); + var hits = leafReader.searchNearestVectors("field", queryVector, 3, null, 100); assertEquals(hits.scoreDocs.length, 3); assertEquals("B", storedFields.document(hits.scoreDocs[0].doc).get("id")); assertEquals("A", storedFields.document(hits.scoreDocs[1].doc).get("id")); diff --git a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java index be0733ef5009..64df927c7650 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/lucene99/TestLucene99ScalarQuantizedVectorsFormat.java @@ -93,7 +93,7 @@ public void testSearch() 
throws Exception { KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); // if this search found any results it would raise NPE attempting to collect them in our // null collector - knnVectorsReader.search("f", new float[] {1, 0}, null, null, null); + knnVectorsReader.search("f", new float[] {1, 0}, null, null); } else { fail("reader is not CodecReader"); } diff --git a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java index 0f40b22f9cd7..45814144d10a 100644 --- a/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java +++ b/lucene/core/src/test/org/apache/lucene/codecs/perfield/TestPerFieldKnnVectorsFormat.java @@ -95,12 +95,11 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { new float[] {1, 2, 3}, 10, reader.getLiveDocs(), - null, Integer.MAX_VALUE); assertEquals(0, hits.scoreDocs.length); hits = reader.searchNearestVectors( - "id", new float[] {1, 2, 3}, 10, reader.getLiveDocs(), null, Integer.MAX_VALUE); + "id", new float[] {1, 2, 3}, 10, reader.getLiveDocs(), Integer.MAX_VALUE); assertEquals(0, hits.scoreDocs.length); } } @@ -147,12 +146,12 @@ public KnnVectorsFormat getKnnVectorsFormatForField(String field) { LeafReader reader = ireader.leaves().get(0).reader(); TopDocs hits1 = reader.searchNearestVectors( - "field1", new float[] {1, 2, 3}, 10, reader.getLiveDocs(), null, Integer.MAX_VALUE); + "field1", new float[] {1, 2, 3}, 10, reader.getLiveDocs(), Integer.MAX_VALUE); assertEquals(1, hits1.scoreDocs.length); TopDocs hits2 = reader.searchNearestVectors( - "field2", new float[] {1, 2, 3}, 10, reader.getLiveDocs(), null, Integer.MAX_VALUE); + "field2", new float[] {1, 2, 3}, 10, reader.getLiveDocs(), Integer.MAX_VALUE); assertEquals(1, hits2.scoreDocs.length); } } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java 
b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java index ddfb2ddcc438..3c82cd6b33e4 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestExitableDirectoryReader.java @@ -471,7 +471,6 @@ public void testFloatVectorValues() throws IOException { TestVectorUtil.randomVector(dimension), 5, leaf.getLiveDocs(), - null, Integer.MAX_VALUE)); } else { DocIdSetIterator iter = leaf.getFloatVectorValues("vector"); @@ -482,7 +481,6 @@ public void testFloatVectorValues() throws IOException { TestVectorUtil.randomVector(dimension), 5, leaf.getLiveDocs(), - null, Integer.MAX_VALUE); } @@ -548,7 +546,6 @@ public void testByteVectorValues() throws IOException { TestVectorUtil.randomVectorBytes(dimension), 5, leaf.getLiveDocs(), - null, Integer.MAX_VALUE)); } else { @@ -560,7 +557,6 @@ public void testByteVectorValues() throws IOException { TestVectorUtil.randomVectorBytes(dimension), 5, leaf.getLiveDocs(), - null, Integer.MAX_VALUE); } diff --git a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java index 553e72c66b10..72be0bd929fa 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestKnnGraph.java @@ -359,7 +359,7 @@ private static TopDocs doKnnSearch(IndexReader reader, float[] vector, int k) th Bits liveDocs = ctx.reader().getLiveDocs(); results[ctx.ord] = ctx.reader() - .searchNearestVectors(KNN_GRAPH_FIELD, vector, k, liveDocs, null, Integer.MAX_VALUE); + .searchNearestVectors(KNN_GRAPH_FIELD, vector, k, liveDocs, Integer.MAX_VALUE); if (ctx.docBase > 0) { for (ScoreDoc doc : results[ctx.ord].scoreDocs) { doc.doc += ctx.docBase; diff --git a/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java b/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java index 
1b7c49681092..f3016d4b82f3 100644 --- a/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java +++ b/lucene/core/src/test/org/apache/lucene/index/TestSegmentToThreadMapping.java @@ -23,7 +23,6 @@ import java.util.Collections; import java.util.List; import org.apache.lucene.document.Document; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.store.Directory; @@ -119,19 +118,11 @@ public ByteVectorValues getByteVectorValues(String field) { @Override public void searchNearestVectors( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) {} + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {} @Override public void searchNearestVectors( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) {} + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) {} @Override protected void doClose() {} diff --git a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java index 6843a07b66eb..cb8c71a089f0 100644 --- a/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java +++ b/lucene/highlighter/src/java/org/apache/lucene/search/highlight/TermVectorLeafReader.java @@ -42,7 +42,6 @@ import org.apache.lucene.index.Terms; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; import org.apache.lucene.util.Version; @@ -182,19 +181,11 @@ public ByteVectorValues getByteVectorValues(String fieldName) { @Override public void searchNearestVectors( - String field, - 
float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) {} + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {} @Override public void searchNearestVectors( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) {} + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) {} @Override public void checkIntegrity() throws IOException {} diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java index cc0539c68aea..63a8086dae3b 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java @@ -149,7 +149,7 @@ protected TopDocs approximateSearch( if (collector == null) { return NO_RESULTS; } - context.reader().searchNearestVectors(field, query, collector, acceptDocs, seedDocs); + context.reader().searchNearestVectors(field, query, collector, acceptDocs); return collector.topDocs(); } diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java index 4bb584a9e9a5..4c8195af4e7b 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java @@ -148,7 +148,7 @@ protected TopDocs approximateSearch( if (collector == null) { return NO_RESULTS; } - context.reader().searchNearestVectors(field, query, collector, acceptDocs, seedDocs); + context.reader().searchNearestVectors(field, query, collector, acceptDocs); return collector.topDocs(); } diff --git 
a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java index 2acd8fb4ab73..2d46b243d838 100644 --- a/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java +++ b/lucene/memory/src/java/org/apache/lucene/index/memory/MemoryIndex.java @@ -1729,19 +1729,11 @@ public ByteVectorValues getByteVectorValues(String fieldName) { @Override public void searchNearestVectors( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) {} + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {} @Override public void searchNearestVectors( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) {} + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) {} @Override public void checkIntegrity() throws IOException { diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java index b48135d52571..501e2e5616f0 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/codecs/asserting/AssertingKnnVectorsFormat.java @@ -32,7 +32,6 @@ import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.Sorter; import org.apache.lucene.index.VectorEncoding; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.util.Bits; @@ -147,33 +146,23 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { } @Override - public void search( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - 
DocIdSetIterator seedDocs) + public void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { FieldInfo fi = fis.fieldInfo(field); assert fi != null && fi.getVectorDimension() > 0 && fi.getVectorEncoding() == VectorEncoding.FLOAT32; - delegate.search(field, target, knnCollector, acceptDocs, seedDocs); + delegate.search(field, target, knnCollector, acceptDocs); } @Override - public void search( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { FieldInfo fi = fis.fieldInfo(field); assert fi != null && fi.getVectorDimension() > 0 && fi.getVectorEncoding() == VectorEncoding.BYTE; - delegate.search(field, target, knnCollector, acceptDocs, seedDocs); + delegate.search(field, target, knnCollector, acceptDocs); } @Override diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java index fe5f084c4963..63fe2b8f4c11 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/BaseKnnVectorsFormatTestCase.java @@ -802,12 +802,7 @@ public void testDeleteAllVectorDocs() throws Exception { // assert that knn search doesn't fail on a field with all deleted docs TopDocs results = leafReader.searchNearestVectors( - "v", - randomNormalizedVector(4), - 1, - leafReader.getLiveDocs(), - null, - Integer.MAX_VALUE); + "v", randomNormalizedVector(4), 1, leafReader.getLiveDocs(), Integer.MAX_VALUE); assertEquals(0, results.scoreDocs.length); } } @@ -1444,12 +1439,7 @@ public void testSearchWithVisitedLimit() throws Exception { TopDocs results = ctx.reader() .searchNearestVectors( - fieldName, - 
randomNormalizedVector(dimension), - k, - liveDocs, - null, - visitedLimit); + fieldName, randomNormalizedVector(dimension), k, liveDocs, visitedLimit); assertEquals(TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO, results.totalHits.relation()); assertEquals(visitedLimit, results.totalHits.value()); @@ -1459,12 +1449,7 @@ public void testSearchWithVisitedLimit() throws Exception { results = ctx.reader() .searchNearestVectors( - fieldName, - randomNormalizedVector(dimension), - k, - liveDocs, - null, - visitedLimit); + fieldName, randomNormalizedVector(dimension), k, liveDocs, visitedLimit); assertEquals(TotalHits.Relation.EQUAL_TO, results.totalHits.relation()); assertTrue(results.totalHits.value() <= visitedLimit); } @@ -1543,12 +1528,7 @@ public void testRandomWithUpdatesAndGraph() throws Exception { TopDocs results = ctx.reader() .searchNearestVectors( - fieldName, - randomNormalizedVector(dimension), - k, - liveDocs, - null, - Integer.MAX_VALUE); + fieldName, randomNormalizedVector(dimension), k, liveDocs, Integer.MAX_VALUE); assertEquals(Math.min(k, size), results.scoreDocs.length); for (int i = 0; i < k - 1; i++) { assertTrue(results.scoreDocs[i].score >= results.scoreDocs[i + 1].score); diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java b/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java index b2d215280ff5..3fee110f7836 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/index/MergeReaderWrapper.java @@ -41,7 +41,6 @@ import org.apache.lucene.index.StoredFields; import org.apache.lucene.index.TermVectors; import org.apache.lucene.index.Terms; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.util.Bits; @@ -242,24 +241,14 @@ public ByteVectorValues getByteVectorValues(String fieldName) throws IOException 
@Override public void searchNearestVectors( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) - throws IOException { - in.searchNearestVectors(field, target, knnCollector, acceptDocs, seedDocs); + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + in.searchNearestVectors(field, target, knnCollector, acceptDocs); } @Override public void searchNearestVectors( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) - throws IOException { - in.searchNearestVectors(field, target, knnCollector, acceptDocs, seedDocs); + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + in.searchNearestVectors(field, target, knnCollector, acceptDocs); } @Override diff --git a/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java b/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java index 36030f072e56..efd13121d930 100644 --- a/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java +++ b/lucene/test-framework/src/java/org/apache/lucene/tests/search/QueryUtils.java @@ -245,19 +245,11 @@ public ByteVectorValues getByteVectorValues(String field) throws IOException { @Override public void searchNearestVectors( - String field, - float[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) {} + String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) {} @Override public void searchNearestVectors( - String field, - byte[] target, - KnnCollector knnCollector, - Bits acceptDocs, - DocIdSetIterator seedDocs) {} + String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) {} @Override public FieldInfos getFieldInfos() { From e8417d38fbee39588f3e274e865faa30c2de27b8 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Wed, 2 Oct 2024 10:47:01 +0100 Subject: [PATCH 
27/41] consistent naming --- .../backward_codecs/lucene91/Lucene91HnswVectorsReader.java | 2 +- .../java/org/apache/lucene/codecs/hnsw/FlatVectorsReader.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java index d73e194a24dc..81f8d97a9a0c 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene91/Lucene91HnswVectorsReader.java @@ -250,7 +250,7 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocsBits) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsReader.java b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsReader.java index 8d6f66c54226..9d776567883e 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsReader.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/hnsw/FlatVectorsReader.java @@ -62,7 +62,7 @@ public void search(String field, float[] target, KnnCollector knnCollector, Bits } @Override - public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDoc) + public void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { // don't scan stored field data. 
If we didn't index it, produce no search results } From 440b0d0edfa08e75cc30a0561872c91b5fecdfcd Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Wed, 2 Oct 2024 10:56:01 +0100 Subject: [PATCH 28/41] javadoc --- .../java/org/apache/lucene/search/KnnCollector.java | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java index b0aa2e5ce9f3..2312364d19ed 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java @@ -86,7 +86,16 @@ public interface KnnCollector { */ TopDocs topDocs(); - /** TODO */ + /** + * This method returns a {@link DocIdSetIterator} over entry points that seed the KNN search, + * or {@code null} (default) to perform a full KNN search (without seeds). + * + *

    Note that the entry points should represent ordinals, rather than true document IDs. + * + * @return the seed entry points or {@code null}. + * + * @lucene.experimental + */ default DocIdSetIterator getSeedEntryPoints() { return null; } From 87e75abfe6e76a26fb87d4de6518e27c2f5927d8 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Wed, 2 Oct 2024 10:58:05 +0100 Subject: [PATCH 29/41] tidy --- .../src/java/org/apache/lucene/search/KnnCollector.java | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java index 2312364d19ed..0850ba82da28 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java @@ -87,13 +87,12 @@ public interface KnnCollector { TopDocs topDocs(); /** - * This method returns a {@link DocIdSetIterator} over entry points that seed the KNN search, - * or {@code null} (default) to perform a full KNN search (without seeds). - * + * This method returns a {@link DocIdSetIterator} over entry points that seed the KNN search, or + * {@code null} (default) to perform a full KNN search (without seeds). + * *

    Note that the entry points should represent ordinals, rather than true document IDs. * * @return the seed entry points or {@code null}. - * * @lucene.experimental */ default DocIdSetIterator getSeedEntryPoints() { From 7b3350f92daba03e6f71ee4c5af7e33694f3c248 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Wed, 2 Oct 2024 11:06:10 +0100 Subject: [PATCH 30/41] javadoc typo --- .../java/org/apache/lucene/search/AbstractKnnVectorQuery.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index e6ff0a931851..d01415d54467 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -51,7 +51,7 @@ * * *

    When a seed query is provided, this query is executed first to seed the kNN search (subject to - * the same rules about the filter). If the seed query fails to identify any documents, it falls + * the same rules provided by the filter). If the seed query fails to identify any documents, it falls * back on the strategy above. */ abstract class AbstractKnnVectorQuery extends Query { From 5bb40c220d578ced04d655a5a53e7f425f67b80e Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Wed, 2 Oct 2024 11:21:15 +0100 Subject: [PATCH 31/41] javadoc --- .../lucene/search/AbstractKnnVectorQuery.java | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index d01415d54467..3822f5f1e92d 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -51,8 +51,8 @@ * * *

    When a seed query is provided, this query is executed first to seed the kNN search (subject to - * the same rules provided by the filter). If the seed query fails to identify any documents, it falls - * back on the strategy above. + * the same rules provided by the filter). If the seed query fails to identify any documents, it + * falls back on the strategy above. */ abstract class AbstractKnnVectorQuery extends Query { @@ -208,9 +208,9 @@ private DocIdSetIterator executeSeedQuery(LeafReaderContext ctx, Weight seedWeig false /* supportsConcurrency */) .newCollector(); final LeafReader leafReader = ctx.reader(); - try { - final LeafCollector leafCollector = seedCollector.getLeafCollector(ctx); - if (leafCollector != null) { + final LeafCollector leafCollector = seedCollector.getLeafCollector(ctx); + if (leafCollector != null) { + try { BulkScorer scorer = seedWeight.bulkScorer(ctx); if (scorer != null) { scorer.score( @@ -220,10 +220,10 @@ private DocIdSetIterator executeSeedQuery(LeafReaderContext ctx, Weight seedWeig DocIdSetIterator.NO_MORE_DOCS /* max */); } leafCollector.finish(); + } catch ( + @SuppressWarnings("unused") + CollectionTerminatedException e) { } - } catch ( - @SuppressWarnings("unused") - CollectionTerminatedException e) { } TopDocs seedTopDocs = seedCollector.topDocs(); From 216bfc41846ae12a558291f20c0b12d3639641dc Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Wed, 2 Oct 2024 11:21:30 +0100 Subject: [PATCH 32/41] test fixes --- .../lucene/document/TestManyKnnDocs.java | 5 ++-- .../search/BaseKnnVectorQueryTestCase.java | 29 +++++++++---------- 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java index 05adc70596f9..92d6562dd3f3 100644 --- a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java +++ b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java @@ 
-54,17 +54,16 @@ public static void init_index() throws Exception { mp.setMaxMergeAtOnce(256); // avoid intermediate merges (waste of time with HNSW?) mp.setSegmentsPerTier(256); // only merge once at the end when we ask iwc.setMergePolicy(mp); - VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.EUCLIDEAN; + VectorSimilarityFunction similarityFunction = VectorSimilarityFunction.DOT_PRODUCT; try (Directory dir = FSDirectory.open(testDir = createTempDir("ManyKnnVectorDocs")); IndexWriter iw = new IndexWriter(dir, iwc)) { int numVectors = 2088992; for (int i = 0; i < numVectors; i++) { - float[] vector = new float[128]; + float[] vector = new float[1]; Document doc = new Document(); vector[0] = (i % 256); - vector[1] = (float) (i / 256.); doc.add(new KnnFloatVectorField("field", vector, similarityFunction)); doc.add(new KeywordField("int", "" + i, org.apache.lucene.document.Field.Store.YES)); doc.add(new StoredField("intValue", i)); diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index a5325a009d09..5bf1b9e25d0d 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -619,7 +619,8 @@ public void testRandomWithSeed() throws IOException { int numDocs = 1000; int dimension = atLeast(5); int numIters = atLeast(10); - try (Directory d = newDirectory()) { + int numDocsWithVector = 0; + try (Directory d = newDirectoryForTest()) { // Always use the default kNN format to have predictable behavior around when it hits // visitedLimit. 
This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN // format @@ -628,10 +629,10 @@ public void testRandomWithSeed() throws IOException { RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc); for (int i = 0; i < numDocs; i++) { Document doc = new Document(); - if (random() - .nextBoolean()) { // Randomly skip some vectors to test the mapping from docid to - // ordinals + if (random().nextBoolean()) { + // Randomly skip some vectors to test the mapping from docid to ordinals doc.add(getKnnVectorField("field", randomVector(dimension))); + numDocsWithVector += 1; } doc.add(new NumericDocValuesField("tag", i)); doc.add(new IntPoint("tag", i)); @@ -645,16 +646,18 @@ public void testRandomWithSeed() throws IOException { for (int i = 0; i < numIters; i++) { int k = random().nextInt(80) + 1; int n = random().nextInt(100) + 1; + // we may get fewer results than requested if there are deletions, but this test doesn't + // check that + assert reader.hasDeletions() == false; // All documents as seeds Query seed1 = new MatchAllDocsQuery(); + Query filter = random().nextBoolean() ? 
null : new MatchAllDocsQuery(); AbstractKnnVectorQuery query = - getKnnVectorQuery("field", randomVector(dimension), k, null, seed1); + getKnnVectorQuery("field", randomVector(dimension), k, filter, seed1); TopDocs results = searcher.search(query, n); - int expected = Math.min(Math.min(n, k), reader.numDocs()); - // we may get fewer results than requested if there are deletions, but this test doesn't - // test that - assert reader.hasDeletions() == false; + int expected = Math.min(Math.min(n, k), numDocsWithVector); + assertEquals(expected, results.scoreDocs.length); assertTrue(results.totalHits.value() >= results.scoreDocs.length); // verify the results are in descending score order @@ -669,9 +672,6 @@ public void testRandomWithSeed() throws IOException { query = getKnnVectorQuery("field", randomVector(dimension), k, null, seed2); results = searcher.search(query, n); expected = Math.min(Math.min(n, k), reader.numDocs()); - // we may get fewer results than requested if there are deletions, but this test doesn't - // test that - assert reader.hasDeletions() == false; assertEquals(expected, results.scoreDocs.length); assertTrue(results.totalHits.value() >= results.scoreDocs.length); // verify the results are in descending score order @@ -682,13 +682,10 @@ public void testRandomWithSeed() throws IOException { } // No seed documents -- falls back on full approx search - Query seed3 = new BooleanQuery.Builder().build(); + Query seed3 = new MatchNoDocsQuery(); query = getKnnVectorQuery("field", randomVector(dimension), k, null, seed3); results = searcher.search(query, n); expected = Math.min(Math.min(n, k), reader.numDocs()); - // we may get fewer results than requested if there are deletions, but this test doesn't - // test that - assert reader.hasDeletions() == false; assertEquals(expected, results.scoreDocs.length); assertTrue(results.totalHits.value() >= results.scoreDocs.length); // verify the results are in descending score order From 
b6725c76ee2c71b58bd0c842616a0fc78f46a728 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Wed, 2 Oct 2024 11:57:10 +0100 Subject: [PATCH 33/41] merge resolution --- .../java/org/apache/lucene/index/ByteVectorValues.java | 9 +-------- .../org/apache/lucene/index/FloatVectorValues.java | 10 +--------- 2 files changed, 2 insertions(+), 17 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java index 70fe952f18e3..0368b723ad53 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.List; import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.VectorScorer; /** @@ -64,14 +65,6 @@ public static void checkField(LeafReader in, String field) { } } - /** - * Return a {@link VectorScorer} for the given query vector. - * - * @param query the query vector - * @return a {@link VectorScorer} instance or null - */ - public abstract VectorScorer scorer(byte[] query) throws IOException; - /** * Returns a new iterator that maps the provided docIds to the vector ordinals. 
* diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index 3e37b7a19c9d..279b2774b512 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.util.List; import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.VectorScorer; /** @@ -64,15 +65,6 @@ public static void checkField(LeafReader in, String field) { } } - /** - * Return a {@link VectorScorer} for the given query vector and the current {@link - * FloatVectorValues}. - * - * @param target the query vector - * @return a {@link VectorScorer} instance or null - */ - public abstract VectorScorer scorer(float[] query) throws IOException; - /** * Returns a new iterator that maps the provided docIds to the vector ordinals. * From 0cfd99b5e163c2ea7296420c0f1cad0096e0fd67 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Wed, 2 Oct 2024 12:08:04 +0100 Subject: [PATCH 34/41] merging --- .../src/java/org/apache/lucene/index/ByteVectorValues.java | 6 ++++++ .../java/org/apache/lucene/index/FloatVectorValues.java | 7 +++++++ 2 files changed, 13 insertions(+) diff --git a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java index 0368b723ad53..aefaa7fa684f 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java @@ -77,6 +77,12 @@ public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { return docIds; } + /** + * Return a {@link VectorScorer} for the given query vector. 
+ * + * @param query the query vector + * @return a {@link VectorScorer} instance or null + */ public VectorScorer scorer(byte[] query) throws IOException { throw new UnsupportedOperationException(); } diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index 279b2774b512..9467f8710a94 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -77,6 +77,13 @@ public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { return docIds; } + /** + * Return a {@link VectorScorer} for the given query vector and the current {@link + * FloatVectorValues}. + * + * @param target the query vector + * @return a {@link VectorScorer} instance or null + */ public VectorScorer scorer(float[] target) throws IOException { throw new UnsupportedOperationException(); } From dd63bb0cb1bc82d352edf67a503d27e7cd59e16a Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Wed, 2 Oct 2024 13:58:33 +0100 Subject: [PATCH 35/41] refactor as decorator --- .../lucene/search/KnnByteVectorQuery.java | 2 +- .../apache/lucene/search/KnnCollector.java | 67 +++++++++++++++++ .../lucene/search/KnnFloatVectorQuery.java | 2 +- .../lucene/search/SeededKnnCollector.java | 73 ------------------- 4 files changed, 69 insertions(+), 75 deletions(-) delete mode 100644 lucene/core/src/java/org/apache/lucene/search/SeededKnnCollector.java diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index 35e8bcd2a955..0a3fb83f6cef 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -101,7 +101,7 @@ protected TopDocs approximateSearch( throws IOException { KnnCollector knnCollector = 
knnCollectorManager.newCollector(visitedLimit, context); if (seedDocs != null) { - knnCollector = new SeededKnnCollector(knnCollector, seedDocs); + knnCollector = new KnnCollector.Seeded(knnCollector, seedDocs); } LeafReader reader = context.reader(); ByteVectorValues byteVectorValues = reader.getByteVectorValues(field); diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java index 0850ba82da28..be547c3aa3dc 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java @@ -98,4 +98,71 @@ public interface KnnCollector { default DocIdSetIterator getSeedEntryPoints() { return null; } + + public abstract static class Decorator implements KnnCollector { + private KnnCollector collector; + + public Decorator(KnnCollector collector) { + this.collector = collector; + } + + @Override + public boolean earlyTerminated() { + return collector.earlyTerminated(); + } + + @Override + public void incVisitedCount(int count) { + collector.incVisitedCount(count); + } + + @Override + public long visitedCount() { + return collector.visitedCount(); + } + + @Override + public long visitLimit() { + return collector.visitLimit(); + } + + @Override + public int k() { + return collector.k(); + } + + @Override + public boolean collect(int docId, float similarity) { + return collector.collect(docId, similarity); + } + + @Override + public float minCompetitiveSimilarity() { + return collector.minCompetitiveSimilarity(); + } + + @Override + public TopDocs topDocs() { + return collector.topDocs(); + } + + @Override + public DocIdSetIterator getSeedEntryPoints() { + return collector.getSeedEntryPoints(); + } + } + + public static class Seeded extends Decorator { + private DocIdSetIterator seedEntryPoints; + + public Seeded(KnnCollector collector, DocIdSetIterator seedEntryPoints) { + super(collector); + this.seedEntryPoints = 
seedEntryPoints; + } + + @Override + public DocIdSetIterator getSeedEntryPoints() { + return seedEntryPoints; + } + } } diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index b154421e7a4b..e045db37dd00 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -102,7 +102,7 @@ protected TopDocs approximateSearch( throws IOException { KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context); if (seedDocs != null) { - knnCollector = new SeededKnnCollector(knnCollector, seedDocs); + knnCollector = new KnnCollector.Seeded(knnCollector, seedDocs); } LeafReader reader = context.reader(); FloatVectorValues floatVectorValues = reader.getFloatVectorValues(field); diff --git a/lucene/core/src/java/org/apache/lucene/search/SeededKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/SeededKnnCollector.java deleted file mode 100644 index 319ff0bc3e5b..000000000000 --- a/lucene/core/src/java/org/apache/lucene/search/SeededKnnCollector.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.lucene.search; - -/** A {@link KnnCollector} that provides an initial set of seeds to initialize the search. */ -class SeededKnnCollector implements KnnCollector { - private final KnnCollector collector; - private final DocIdSetIterator seedEntryPoints; - - public SeededKnnCollector(KnnCollector collector, DocIdSetIterator seedEntryPoints) { - this.collector = collector; - this.seedEntryPoints = seedEntryPoints; - } - - @Override - public boolean earlyTerminated() { - return collector.earlyTerminated(); - } - - @Override - public void incVisitedCount(int count) { - collector.incVisitedCount(count); - } - - @Override - public long visitedCount() { - return collector.visitedCount(); - } - - @Override - public long visitLimit() { - return collector.visitLimit(); - } - - @Override - public int k() { - return collector.k(); - } - - @Override - public boolean collect(int docId, float similarity) { - return collector.collect(docId, similarity); - } - - @Override - public float minCompetitiveSimilarity() { - return collector.minCompetitiveSimilarity(); - } - - @Override - public TopDocs topDocs() { - return collector.topDocs(); - } - - @Override - public DocIdSetIterator getSeedEntryPoints() { - return seedEntryPoints; - } -} From 16358701f141e3a2363cbe6ed0b9282dd644c128 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Wed, 2 Oct 2024 15:03:07 +0100 Subject: [PATCH 36/41] javadoc --- .../java/org/apache/lucene/search/KnnCollector.java | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java index be547c3aa3dc..43338fec72e3 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java @@ -99,6 +99,12 @@ default 
DocIdSetIterator getSeedEntryPoints() { return null; } + /** + * KnnCollector.Decorator is the base class for decorators of KnnCollector objects, which extend + * the object with new behaviors. + * + * @lucene.experimental + */ public abstract static class Decorator implements KnnCollector { private KnnCollector collector; @@ -152,6 +158,11 @@ public DocIdSetIterator getSeedEntryPoints() { } } + /** + * KnnCollector.Seeded is a KnnCollector decorator that replaces the seedEntryPoints. + * + * @lucene.experimental + */ public static class Seeded extends Decorator { private DocIdSetIterator seedEntryPoints; From c8a512aa9054fa3436fcc821d99a5ded1e9f0695 Mon Sep 17 00:00:00 2001 From: Sean MacAvaney Date: Wed, 2 Oct 2024 15:15:16 +0100 Subject: [PATCH 37/41] apply decorator elsewhere --- .../TimeLimitingKnnCollectorManager.java | 42 +++---------------- .../hnsw/OrdinalTranslatedKnnCollector.java | 42 +++---------------- 2 files changed, 11 insertions(+), 73 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java index 2a1f312fbc58..2dc2f035b90f 100644 --- a/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java +++ b/lucene/core/src/java/org/apache/lucene/search/TimeLimitingKnnCollectorManager.java @@ -45,51 +45,19 @@ public KnnCollector newCollector(int visitedLimit, LeafReaderContext context) th return new TimeLimitingKnnCollector(collector); } - class TimeLimitingKnnCollector implements KnnCollector { - private final KnnCollector collector; - - TimeLimitingKnnCollector(KnnCollector collector) { - this.collector = collector; + class TimeLimitingKnnCollector extends KnnCollector.Decorator { + public TimeLimitingKnnCollector(KnnCollector collector) { + super(collector); } @Override public boolean earlyTerminated() { - return queryTimeout.shouldExit() || collector.earlyTerminated(); - } - - @Override - 
public void incVisitedCount(int count) { - collector.incVisitedCount(count); - } - - @Override - public long visitedCount() { - return collector.visitedCount(); - } - - @Override - public long visitLimit() { - return collector.visitLimit(); - } - - @Override - public int k() { - return collector.k(); - } - - @Override - public boolean collect(int docId, float similarity) { - return collector.collect(docId, similarity); - } - - @Override - public float minCompetitiveSimilarity() { - return collector.minCompetitiveSimilarity(); + return queryTimeout.shouldExit() || super.earlyTerminated(); } @Override public TopDocs topDocs() { - TopDocs docs = collector.topDocs(); + TopDocs docs = super.topDocs(); // Mark results as partial if timeout is met TotalHits.Relation relation = diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java index ed1a5ffb59fa..5225fe700ab9 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/OrdinalTranslatedKnnCollector.java @@ -24,54 +24,24 @@ /** * Wraps a provided KnnCollector object, translating the provided vectorId ordinal to a documentId */ -public final class OrdinalTranslatedKnnCollector implements KnnCollector { +public final class OrdinalTranslatedKnnCollector extends KnnCollector.Decorator { - private final KnnCollector in; private final IntToIntFunction vectorOrdinalToDocId; - public OrdinalTranslatedKnnCollector(KnnCollector in, IntToIntFunction vectorOrdinalToDocId) { - this.in = in; + public OrdinalTranslatedKnnCollector( + KnnCollector collector, IntToIntFunction vectorOrdinalToDocId) { + super(collector); this.vectorOrdinalToDocId = vectorOrdinalToDocId; } - @Override - public boolean earlyTerminated() { - return in.earlyTerminated(); - } - - @Override - public void incVisitedCount(int count) { - 
in.incVisitedCount(count); - } - - @Override - public long visitedCount() { - return in.visitedCount(); - } - - @Override - public long visitLimit() { - return in.visitLimit(); - } - - @Override - public int k() { - return in.k(); - } - @Override public boolean collect(int vectorId, float similarity) { - return in.collect(vectorOrdinalToDocId.apply(vectorId), similarity); - } - - @Override - public float minCompetitiveSimilarity() { - return in.minCompetitiveSimilarity(); + return super.collect(vectorOrdinalToDocId.apply(vectorId), similarity); } @Override public TopDocs topDocs() { - TopDocs td = in.topDocs(); + TopDocs td = super.topDocs(); return new TopDocs( new TotalHits( visitedCount(), From cb6ab5484eeada6a0f79ad51079c0742a8d32a47 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Fri, 20 Dec 2024 11:34:00 -0500 Subject: [PATCH 38/41] Refactor --- .../lucene/search/AbstractKnnVectorQuery.java | 148 +------------- .../lucene/search/KnnByteVectorQuery.java | 37 +--- .../apache/lucene/search/KnnCollector.java | 39 +--- .../lucene/search/KnnFloatVectorQuery.java | 37 +--- .../search/SeededKnnByteVectorQuery.java | 85 ++++++++ .../search/SeededKnnFloatVectorQuery.java | 85 ++++++++ .../lucene/search/knn/EntryPointProvider.java | 25 +++ .../lucene/search/knn/SeededKnnCollector.java | 34 ++++ .../search/knn/SeededKnnCollectorManager.java | 174 +++++++++++++++++ .../lucene/util/hnsw/HnswGraphSearcher.java | 28 ++- .../lucene/document/TestManyKnnDocs.java | 9 +- .../search/BaseKnnVectorQueryTestCase.java | 91 --------- .../lucene/search/TestKnnByteVectorQuery.java | 22 +-- .../search/TestKnnFloatVectorQuery.java | 20 +- .../search/TestSeededKnnByteVectorQuery.java | 181 ++++++++++++++++++ .../search/TestSeededKnnFloatVectorQuery.java | 167 ++++++++++++++++ ...iversifyingChildrenByteKnnVectorQuery.java | 1 - ...versifyingChildrenFloatKnnVectorQuery.java | 1 - 18 files changed, 790 insertions(+), 394 deletions(-) create 
mode 100644 lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java create mode 100644 lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java create mode 100644 lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java create mode 100644 lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java create mode 100644 lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java create mode 100644 lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java create mode 100644 lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java diff --git a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java index 3822f5f1e92d..e9246a8b5756 100644 --- a/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/AbstractKnnVectorQuery.java @@ -21,7 +21,6 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.Comparator; import java.util.List; import java.util.Objects; @@ -49,10 +48,6 @@ *

  • Otherwise run a kNN search subject to the filter *
  • If the kNN search visits too many vectors without completing, stop and run an exact search * - * - *

    When a seed query is provided, this query is executed first to seed the kNN search (subject to - * the same rules provided by the filter). If the seed query fails to identify any documents, it - * falls back on the strategy above. */ abstract class AbstractKnnVectorQuery extends Query { @@ -60,21 +55,15 @@ abstract class AbstractKnnVectorQuery extends Query { protected final String field; protected final int k; - private final Query filter; - private final Query seed; + protected final Query filter; public AbstractKnnVectorQuery(String field, int k, Query filter) { - this(field, k, filter, null); - } - - public AbstractKnnVectorQuery(String field, int k, Query filter, Query seed) { this.field = Objects.requireNonNull(field, "field"); this.k = k; if (k < 1) { throw new IllegalArgumentException("k must be at least 1, got: " + k); } this.filter = filter; - this.seed = seed; } @Override @@ -94,21 +83,6 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { filterWeight = null; } - final Weight seedWeight; - if (seed != null) { - BooleanQuery.Builder booleanSeedQueryBuilder = - new BooleanQuery.Builder() - .add(seed, BooleanClause.Occur.MUST) - .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER); - if (filter != null) { - booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER); - } - Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build()); - seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f); - } else { - seedWeight = null; - } - TimeLimitingKnnCollectorManager knnCollectorManager = new TimeLimitingKnnCollectorManager( getKnnCollectorManager(k, indexSearcher), indexSearcher.getTimeout()); @@ -116,7 +90,7 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { List leafReaderContexts = reader.leaves(); List> tasks = new ArrayList<>(leafReaderContexts.size()); for (LeafReaderContext context : leafReaderContexts) { - tasks.add(() -> searchLeaf(context, 
filterWeight, seedWeight, knnCollectorManager)); + tasks.add(() -> searchLeaf(context, filterWeight, knnCollectorManager)); } TopDocs[] perLeafResults = taskExecutor.invokeAll(tasks).toArray(TopDocs[]::new); @@ -131,11 +105,9 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException { private TopDocs searchLeaf( LeafReaderContext ctx, Weight filterWeight, - Weight seedWeight, TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) throws IOException { - TopDocs results = - getLeafResults(ctx, filterWeight, seedWeight, timeLimitingKnnCollectorManager); + TopDocs results = getLeafResults(ctx, filterWeight, timeLimitingKnnCollectorManager); if (ctx.docBase > 0) { for (ScoreDoc scoreDoc : results.scoreDocs) { scoreDoc.doc += ctx.docBase; @@ -147,19 +119,13 @@ private TopDocs searchLeaf( private TopDocs getLeafResults( LeafReaderContext ctx, Weight filterWeight, - Weight seedWeight, TimeLimitingKnnCollectorManager timeLimitingKnnCollectorManager) throws IOException { final LeafReader reader = ctx.reader(); final Bits liveDocs = reader.getLiveDocs(); if (filterWeight == null) { - return approximateSearch( - ctx, - liveDocs, - executeSeedQuery(ctx, seedWeight), - Integer.MAX_VALUE, - timeLimitingKnnCollectorManager); + return approximateSearch(ctx, liveDocs, Integer.MAX_VALUE, timeLimitingKnnCollectorManager); } Scorer scorer = filterWeight.scorer(ctx); @@ -179,13 +145,7 @@ private TopDocs getLeafResults( // Perform the approximate kNN search // We pass cost + 1 here to account for the edge case when we explore exactly cost vectors - TopDocs results = - approximateSearch( - ctx, - acceptDocs, - executeSeedQuery(ctx, seedWeight), - cost + 1, - timeLimitingKnnCollectorManager); + TopDocs results = approximateSearch(ctx, acceptDocs, cost + 1, timeLimitingKnnCollectorManager); if (results.totalHits.relation() == TotalHits.Relation.EQUAL_TO // Return partial results only when timeout is met || (queryTimeout != null && queryTimeout.shouldExit())) { @@ 
-196,49 +156,6 @@ private TopDocs getLeafResults( } } - private DocIdSetIterator executeSeedQuery(LeafReaderContext ctx, Weight seedWeight) - throws IOException { - if (seedWeight == null) return null; - // Execute the seed query - TopScoreDocCollector seedCollector = - new TopScoreDocCollectorManager( - k /* numHits */, - null /* after */, - Integer.MAX_VALUE /* totalHitsThreshold */, - false /* supportsConcurrency */) - .newCollector(); - final LeafReader leafReader = ctx.reader(); - final LeafCollector leafCollector = seedCollector.getLeafCollector(ctx); - if (leafCollector != null) { - try { - BulkScorer scorer = seedWeight.bulkScorer(ctx); - if (scorer != null) { - scorer.score( - leafCollector, - leafReader.getLiveDocs(), - 0 /* min */, - DocIdSetIterator.NO_MORE_DOCS /* max */); - } - leafCollector.finish(); - } catch ( - @SuppressWarnings("unused") - CollectionTerminatedException e) { - } - } - - TopDocs seedTopDocs = seedCollector.topDocs(); - return convertDocIdsToVectorOrdinals(leafReader, new TopDocsDISI(seedTopDocs)); - } - - /** - * Returns a new iterator that maps the provided docIds to the vector ordinals. 
- * - * @lucene.internal - * @lucene.experimental - */ - protected abstract DocIdSetIterator convertDocIdsToVectorOrdinals( - LeafReader reader, DocIdSetIterator docIds) throws IOException; - private BitSet createBitSet(DocIdSetIterator iterator, Bits liveDocs, int maxDoc) throws IOException { if (liveDocs == null && iterator instanceof BitSetIterator bitSetIterator) { @@ -264,7 +181,6 @@ protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher search protected abstract TopDocs approximateSearch( LeafReaderContext context, Bits acceptDocs, - DocIdSetIterator seedDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException; @@ -386,15 +302,12 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; AbstractKnnVectorQuery that = (AbstractKnnVectorQuery) o; - return k == that.k - && Objects.equals(field, that.field) - && Objects.equals(filter, that.filter) - && Objects.equals(seed, that.seed); + return k == that.k && Objects.equals(field, that.field) && Objects.equals(filter, that.filter); } @Override public int hashCode() { - return Objects.hash(field, k, filter, seed); + return Objects.hash(field, k, filter); } /** @@ -419,13 +332,6 @@ public Query getFilter() { return filter; } - /** - * @return the query that seeds the kNN search. 
- */ - public Query getSeed() { - return seed; - } - /** Caches the results of a KnnVector search: a list of docs and their scores */ static class DocAndScoreQuery extends Query { @@ -585,44 +491,4 @@ public int hashCode() { classHash(), contextIdentity, Arrays.hashCode(docs), Arrays.hashCode(scores)); } } - - private static class TopDocsDISI extends DocIdSetIterator { - private final List sortedDocIdList; - private int idx = -1; - - public TopDocsDISI(TopDocs topDocs) { - sortedDocIdList = new ArrayList(topDocs.scoreDocs.length); - for (int i = 0; i < topDocs.scoreDocs.length; i++) { - sortedDocIdList.add(topDocs.scoreDocs[i].doc); - } - Collections.sort(sortedDocIdList); - } - - @Override - public int advance(int target) throws IOException { - return slowAdvance(target); - } - - @Override - public long cost() { - return sortedDocIdList.size(); - } - - @Override - public int docID() { - if (idx == -1) { - return -1; - } else if (idx >= sortedDocIdList.size()) { - return DocIdSetIterator.NO_MORE_DOCS; - } else { - return sortedDocIdList.get(idx); - } - } - - @Override - public int nextDoc() throws IOException { - idx += 1; - return docID(); - } - } } diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java index eb637da0d75e..05157ab65cb5 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnByteVectorQuery.java @@ -46,7 +46,7 @@ public class KnnByteVectorQuery extends AbstractKnnVectorQuery { private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS; - private final byte[] target; + protected final byte[] target; /** * Find the k nearest documents to the target vector according to the vectors in the @@ -72,22 +72,7 @@ public KnnByteVectorQuery(String field, byte[] target, int k) { * @throws IllegalArgumentException if k is less than 1 */ public KnnByteVectorQuery(String 
field, byte[] target, int k, Query filter) { - this(field, target, k, filter, null); - } - - /** - * Find the k nearest documents to the target vector according to the vectors in the - * given field. target vector. - * - * @param field a field that has been indexed as a {@link KnnByteVectorField}. - * @param target the target of the search - * @param k the number of documents to find - * @param filter a filter applied before the vector search - * @param seed a query that is executed to seed the vector search - * @throws IllegalArgumentException if k is less than 1 - */ - public KnnByteVectorQuery(String field, byte[] target, int k, Query filter, Query seed) { - super(field, k, filter, seed); + super(field, k, filter); this.target = Objects.requireNonNull(target, "target"); } @@ -95,14 +80,10 @@ public KnnByteVectorQuery(String field, byte[] target, int k, Query filter, Quer protected TopDocs approximateSearch( LeafReaderContext context, Bits acceptDocs, - DocIdSetIterator seedDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context); - if (seedDocs != null) { - knnCollector = new KnnCollector.Seeded(knnCollector, seedDocs); - } LeafReader reader = context.reader(); ByteVectorValues byteVectorValues = reader.getByteVectorValues(field); if (byteVectorValues == null) { @@ -159,18 +140,4 @@ public int hashCode() { public byte[] getTargetCopy() { return ArrayUtil.copyArray(target); } - - /** - * Returns a new iterator that maps the provided docIds to the vector ordinals. - * - *

    This method assumes that all docIds have corresponding ordinals. - * - * @lucene.internal - * @lucene.experimental - */ - @Override - protected DocIdSetIterator convertDocIdsToVectorOrdinals( - LeafReader reader, DocIdSetIterator docIds) throws IOException { - return reader.getByteVectorValues(field).convertDocIdsToVectorOrdinals(docIds); - } } diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java index 43338fec72e3..a05ca6747710 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java @@ -86,26 +86,13 @@ public interface KnnCollector { */ TopDocs topDocs(); - /** - * This method returns a {@link DocIdSetIterator} over entry points that seed the KNN search, or - * {@code null} (default) to perform a full KNN search (without seeds). - * - *

    Note that the entry points should represent ordinals, rather than true document IDs. - * - * @return the seed entry points or {@code null}. - * @lucene.experimental - */ - default DocIdSetIterator getSeedEntryPoints() { - return null; - } - /** * KnnCollector.Decorator is the base class for decorators of KnnCollector objects, which extend * the object with new behaviors. * * @lucene.experimental */ - public abstract static class Decorator implements KnnCollector { + abstract class Decorator implements KnnCollector { private KnnCollector collector; public Decorator(KnnCollector collector) { @@ -151,29 +138,5 @@ public float minCompetitiveSimilarity() { public TopDocs topDocs() { return collector.topDocs(); } - - @Override - public DocIdSetIterator getSeedEntryPoints() { - return collector.getSeedEntryPoints(); - } - } - - /** - * KnnCollector.Seeded is a KnnCollector decorator that replaces the seedEntryPoints. - * - * @lucene.experimental - */ - public static class Seeded extends Decorator { - private DocIdSetIterator seedEntryPoints; - - public Seeded(KnnCollector collector, DocIdSetIterator seedEntryPoints) { - super(collector); - this.seedEntryPoints = seedEntryPoints; - } - - @Override - public DocIdSetIterator getSeedEntryPoints() { - return seedEntryPoints; - } } } diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java index 286607564531..c7d6fdb3608d 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnFloatVectorQuery.java @@ -47,7 +47,7 @@ public class KnnFloatVectorQuery extends AbstractKnnVectorQuery { private static final TopDocs NO_RESULTS = TopDocsCollector.EMPTY_TOPDOCS; - private final float[] target; + protected final float[] target; /** * Find the k nearest documents to the target vector according to the vectors in the @@ -73,22 +73,7 @@ public 
KnnFloatVectorQuery(String field, float[] target, int k) { * @throws IllegalArgumentException if k is less than 1 */ public KnnFloatVectorQuery(String field, float[] target, int k, Query filter) { - this(field, target, k, filter, null); - } - - /** - * Find the k nearest documents to the target vector according to the vectors in the - * given field. target vector. - * - * @param field a field that has been indexed as a {@link KnnFloatVectorField}. - * @param target the target of the search - * @param k the number of documents to find - * @param filter a filter applied before the vector search - * @param seed a query that is executed to seed the vector search - * @throws IllegalArgumentException if k is less than 1 - */ - public KnnFloatVectorQuery(String field, float[] target, int k, Query filter, Query seed) { - super(field, k, filter, seed); + super(field, k, filter); this.target = VectorUtil.checkFinite(Objects.requireNonNull(target, "target")); } @@ -96,14 +81,10 @@ public KnnFloatVectorQuery(String field, float[] target, int k, Query filter, Qu protected TopDocs approximateSearch( LeafReaderContext context, Bits acceptDocs, - DocIdSetIterator seedDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { KnnCollector knnCollector = knnCollectorManager.newCollector(visitedLimit, context); - if (seedDocs != null) { - knnCollector = new KnnCollector.Seeded(knnCollector, seedDocs); - } LeafReader reader = context.reader(); FloatVectorValues floatVectorValues = reader.getFloatVectorValues(field); if (floatVectorValues == null) { @@ -162,18 +143,4 @@ public int hashCode() { public float[] getTargetCopy() { return ArrayUtil.copyArray(target); } - - /** - * Returns a new iterator that maps the provided docIds to the vector ordinals. - * - *

    This method assumes that all docIds have corresponding ordinals. - * - * @lucene.internal - * @lucene.experimental - */ - @Override - protected DocIdSetIterator convertDocIdsToVectorOrdinals( - LeafReader reader, DocIdSetIterator docIds) throws IOException { - return reader.getFloatVectorValues(field).convertDocIdsToVectorOrdinals(docIds); - } } diff --git a/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java new file mode 100644 index 000000000000..93286948b0a0 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Objects; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.search.knn.SeededKnnCollectorManager; + +/** + * This is a version of knn byte vector query that provides a query seed to initiate the vector + * search. 
NOTE: The underlying format is free to ignore the provided seed + * + * @lucene.experimental + */ +public class SeededKnnByteVectorQuery extends KnnByteVectorQuery { + private final Query seed; + private final Weight seedWeight; + + /** + * Construct a new SeededKnnByteVectorQuery instance + * + * @param field knn byte vector field to query + * @param target the query vector + * @param k number of neighbors to return + * @param filter a filter on the neighbors to return + * @param seed a query seed to initiate the vector format search + */ + public SeededKnnByteVectorQuery(String field, byte[] target, int k, Query filter, Query seed) { + super(field, target, k, filter); + this.seed = Objects.requireNonNull(seed); + this.seedWeight = null; + } + + SeededKnnByteVectorQuery(String field, byte[] target, int k, Query filter, Weight seedWeight) { + super(field, target, k, filter); + this.seed = null; + this.seedWeight = Objects.requireNonNull(seedWeight); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + if (seedWeight != null) { + return super.rewrite(indexSearcher); + } + BooleanQuery.Builder booleanSeedQueryBuilder = + new BooleanQuery.Builder() + .add(seed, BooleanClause.Occur.MUST) + .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER); + if (filter != null) { + booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER); + } + Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build()); + Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f); + SeededKnnByteVectorQuery rewritten = + new SeededKnnByteVectorQuery(field, target, k, filter, seedWeight); + return rewritten.rewrite(indexSearcher); + } + + @Override + protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { + if (seedWeight == null) { + throw new UnsupportedOperationException("must be rewritten before constructing manager"); + } + return new SeededKnnCollectorManager( + 
super.getKnnCollectorManager(k, searcher), + seedWeight, + k, + leaf -> leaf.getByteVectorValues(field)); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java new file mode 100644 index 000000000000..f64e0b29bc65 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search; + +import java.io.IOException; +import java.util.Objects; +import org.apache.lucene.search.knn.KnnCollectorManager; +import org.apache.lucene.search.knn.SeededKnnCollectorManager; + +/** + * This is a version of knn float vector query that provides a query seed to initiate the vector + * search. 
NOTE: The underlying format is free to ignore the provided seed + * + * @lucene.experimental + */ +public class SeededKnnFloatVectorQuery extends KnnFloatVectorQuery { + private final Query seed; + private final Weight seedWeight; + + /** + * Construct a new SeededKnnFloatVectorQuery instance + * + * @param field knn float vector field to query + * @param target the query vector + * @param k number of neighbors to return + * @param filter a filter on the neighbors to return + * @param seed a query seed to initiate the vector format search + */ + public SeededKnnFloatVectorQuery(String field, float[] target, int k, Query filter, Query seed) { + super(field, target, k, filter); + this.seed = Objects.requireNonNull(seed); + this.seedWeight = null; + } + + SeededKnnFloatVectorQuery(String field, float[] target, int k, Query filter, Weight seedWeight) { + super(field, target, k, filter); + this.seed = null; + this.seedWeight = Objects.requireNonNull(seedWeight); + } + + @Override + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + if (seedWeight != null) { + return super.rewrite(indexSearcher); + } + BooleanQuery.Builder booleanSeedQueryBuilder = + new BooleanQuery.Builder() + .add(seed, BooleanClause.Occur.MUST) + .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER); + if (filter != null) { + booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER); + } + Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build()); + Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f); + SeededKnnFloatVectorQuery rewritten = + new SeededKnnFloatVectorQuery(field, target, k, filter, seedWeight); + return rewritten.rewrite(indexSearcher); + } + + @Override + protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher searcher) { + if (seedWeight == null) { + throw new UnsupportedOperationException("must be rewritten before constructing manager"); + } + return new 
SeededKnnCollectorManager( + super.getKnnCollectorManager(k, searcher), + seedWeight, + k, + leaf -> leaf.getFloatVectorValues(field)); + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java b/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java new file mode 100644 index 000000000000..40eda94c654a --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search.knn; + +import org.apache.lucene.search.DocIdSetIterator; + +/** Provides entry points for the kNN search */ +public interface EntryPointProvider { + /** Iterator of valid entry points for the kNN search */ + DocIdSetIterator entryPoints(); +} diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java new file mode 100644 index 000000000000..ac0c643eac5c --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search.knn; + +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.KnnCollector; + +public class SeededKnnCollector extends KnnCollector.Decorator implements EntryPointProvider { + private final DocIdSetIterator entryPoints; + + public SeededKnnCollector(KnnCollector collector, DocIdSetIterator entryPoints) { + super(collector); + this.entryPoints = entryPoints; + } + + @Override + public DocIdSetIterator entryPoints() { + return entryPoints; + } +} diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java new file mode 100644 index 000000000000..1cea53c41794 --- /dev/null +++ b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.lucene.search.knn; + +import java.io.IOException; +import java.util.Arrays; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.BulkScorer; +import org.apache.lucene.search.CollectionTerminatedException; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.search.LeafCollector; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TopScoreDocCollector; +import org.apache.lucene.search.TopScoreDocCollectorManager; +import org.apache.lucene.search.Weight; +import org.apache.lucene.util.IOFunction; + +/** A {@link KnnCollectorManager} that seeds the kNN search with hits from a seed query. 
*/ +public class SeededKnnCollectorManager implements KnnCollectorManager { + private final KnnCollectorManager delegate; + private final Weight seedWeight; + private final int k; + private final IOFunction vectorValuesSupplier; + + public SeededKnnCollectorManager( + KnnCollectorManager delegate, + Weight seedWeight, + int k, + IOFunction vectorValuesSupplier) { + this.delegate = delegate; + this.seedWeight = seedWeight; + this.k = k; + this.vectorValuesSupplier = vectorValuesSupplier; + } + + @Override + public KnnCollector newCollector(int visitedLimit, LeafReaderContext ctx) throws IOException { + // Execute the seed query + TopScoreDocCollector seedCollector = + new TopScoreDocCollectorManager( + k /* numHits */, null /* after */, Integer.MAX_VALUE /* totalHitsThreshold */) + .newCollector(); + final LeafReader leafReader = ctx.reader(); + final LeafCollector leafCollector = seedCollector.getLeafCollector(ctx); + if (leafCollector != null) { + try { + BulkScorer scorer = seedWeight.bulkScorer(ctx); + if (scorer != null) { + scorer.score( + leafCollector, + leafReader.getLiveDocs(), + 0 /* min */, + DocIdSetIterator.NO_MORE_DOCS /* max */); + } + leafCollector.finish(); + } catch ( + @SuppressWarnings("unused") + CollectionTerminatedException e) { + } + } + + TopDocs seedTopDocs = seedCollector.topDocs(); + KnnVectorValues vectorValues = vectorValuesSupplier.apply(leafReader); + if (seedTopDocs.totalHits.value() == 0 || vectorValues == null) { + return delegate.newCollector(visitedLimit, ctx); + } + KnnVectorValues.DocIndexIterator indexIterator = vectorValues.iterator(); + DocIdSetIterator seedDocs = new MappedDISI(indexIterator, new TopDocsDISI(seedTopDocs)); + return new SeededKnnCollector(delegate.newCollector(visitedLimit, ctx), seedDocs); + } + + public static class MappedDISI extends DocIdSetIterator { + KnnVectorValues.DocIndexIterator indexedDISI; + DocIdSetIterator sourceDISI; + + public MappedDISI(KnnVectorValues.DocIndexIterator indexedDISI, 
DocIdSetIterator sourceDISI) { + this.indexedDISI = indexedDISI; + this.sourceDISI = sourceDISI; + } + + /** + * Advances the source iterator to the first document number that is greater than or equal to + * the provided target and returns the corresponding index. + */ + @Override + public int advance(int target) throws IOException { + int newTarget = sourceDISI.advance(target); + if (newTarget != NO_MORE_DOCS) { + indexedDISI.advance(newTarget); + } + return docID(); + } + + @Override + public long cost() { + return this.sourceDISI.cost(); + } + + @Override + public int docID() { + if (indexedDISI.docID() == NO_MORE_DOCS || sourceDISI.docID() == NO_MORE_DOCS) { + return NO_MORE_DOCS; + } + return indexedDISI.index(); + } + + /** Advances to the next document in the source iterator and returns the corresponding index. */ + @Override + public int nextDoc() throws IOException { + int newTarget = sourceDISI.nextDoc(); + if (newTarget != NO_MORE_DOCS) { + indexedDISI.advance(newTarget); + } + return docID(); + } + } + + private static class TopDocsDISI extends DocIdSetIterator { + private final int[] sortedDocIds; + private int idx = -1; + + private TopDocsDISI(TopDocs topDocs) { + sortedDocIds = new int[topDocs.scoreDocs.length]; + for (int i = 0; i < topDocs.scoreDocs.length; i++) { + sortedDocIds[i] = topDocs.scoreDocs[i].doc; + } + Arrays.sort(sortedDocIds); + } + + @Override + public int advance(int target) throws IOException { + return slowAdvance(target); + } + + @Override + public long cost() { + return sortedDocIds.length; + } + + @Override + public int docID() { + if (idx == -1) { + return -1; + } else if (idx >= sortedDocIds.length) { + return DocIdSetIterator.NO_MORE_DOCS; + } else { + return sortedDocIds[idx]; + } + } + + @Override + public int nextDoc() { + idx += 1; + return docID(); + } + } +} diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java 
index b3f400372b9d..d08e1165cf59 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -20,10 +20,10 @@ import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import java.io.IOException; -import java.util.ArrayList; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; import org.apache.lucene.search.TopKnnCollector; +import org.apache.lucene.search.knn.EntryPointProvider; import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; @@ -67,25 +67,23 @@ public HnswGraphSearcher(NeighborQueue candidates, BitSet visited) { public static void search( RandomVectorScorer scorer, KnnCollector knnCollector, HnswGraph graph, Bits acceptOrds) throws IOException { - ArrayList entryPointOrdInts = null; - DocIdSetIterator entryPoints = knnCollector.getSeedEntryPoints(); - if (entryPoints != null) { - entryPointOrdInts = new ArrayList(); - int entryPointOrdInt; - while ((entryPointOrdInt = entryPoints.nextDoc()) != NO_MORE_DOCS) { - entryPointOrdInts.add(entryPointOrdInt); - } - } HnswGraphSearcher graphSearcher = new HnswGraphSearcher( new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(getGraphSize(graph))); - if (entryPointOrdInts == null || entryPointOrdInts.isEmpty()) { - search(scorer, knnCollector, graph, graphSearcher, acceptOrds); - } else { - int[] entryPointOrdIntsArr = entryPointOrdInts.stream().mapToInt(Integer::intValue).toArray(); + final int[] entryPoints; + if (knnCollector instanceof EntryPointProvider epp) { + DocIdSetIterator eps = epp.entryPoints(); + entryPoints = new int[(int) eps.cost()]; + int idx = 0; + int entryPointOrdInt; + while ((entryPointOrdInt = eps.nextDoc()) != NO_MORE_DOCS) { + entryPoints[idx++] = entryPointOrdInt; + } // We use provided entry point ordinals to search the complete graph (level 0) 
graphSearcher.searchLevel( - knnCollector, scorer, 0 /* level */, entryPointOrdIntsArr, graph, acceptOrds); + knnCollector, scorer, 0 /* level */, entryPoints, graph, acceptOrds); + } else { + search(scorer, knnCollector, graph, graphSearcher, acceptOrds); } } diff --git a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java index 92d6562dd3f3..1e485515a62b 100644 --- a/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java +++ b/lucene/core/src/test/org/apache/lucene/document/TestManyKnnDocs.java @@ -28,6 +28,7 @@ import org.apache.lucene.search.MatchAllDocsQuery; import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; +import org.apache.lucene.search.SeededKnnFloatVectorQuery; import org.apache.lucene.search.TopDocs; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FSDirectory; @@ -108,7 +109,7 @@ public void testLargeSegmentSeededExact() throws Exception { vector[1] = 1; TopDocs docs = searcher.search( - new KnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); + new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); assertEquals(5, docs.scoreDocs.length); String s = ""; for (int j = 0; j < docs.scoreDocs.length - 1; j++) { @@ -131,7 +132,7 @@ public void testLargeSegmentSeededNearby() throws Exception { vector[1] = 1; TopDocs docs = searcher.search( - new KnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); + new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); assertEquals(5, docs.scoreDocs.length); String s = ""; for (int j = 0; j < docs.scoreDocs.length - 1; j++) { @@ -154,7 +155,7 @@ public void testLargeSegmentSeededDistant() throws Exception { vector[1] = 1; TopDocs docs = searcher.search( - new KnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); + new SeededKnnFloatVectorQuery("field", vector, 10, 
filterQuery, seedQuery), 5); assertEquals(5, docs.scoreDocs.length); Document d = searcher.storedFields().document(docs.scoreDocs[0].doc); String s = ""; @@ -177,7 +178,7 @@ public void testLargeSegmentSeededNone() throws Exception { vector[1] = 1; TopDocs docs = searcher.search( - new KnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); + new SeededKnnFloatVectorQuery("field", vector, 10, filterQuery, seedQuery), 5); assertEquals(5, docs.scoreDocs.length); Document d = searcher.storedFields().document(docs.scoreDocs[0].doc); String s = ""; diff --git a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java index fb75ef9b50e8..8a0d3b65aea9 100644 --- a/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java +++ b/lucene/core/src/test/org/apache/lucene/search/BaseKnnVectorQueryTestCase.java @@ -66,15 +66,9 @@ abstract class BaseKnnVectorQueryTestCase extends LuceneTestCase { abstract AbstractKnnVectorQuery getKnnVectorQuery( String field, float[] query, int k, Query queryFilter); - abstract AbstractKnnVectorQuery getKnnVectorQuery( - String field, float[] query, int k, Query queryFilter, Query seedQuery); - abstract AbstractKnnVectorQuery getThrowingKnnVectorQuery( String field, float[] query, int k, Query queryFilter); - abstract AbstractKnnVectorQuery getThrowingKnnVectorQuery( - String field, float[] query, int k, Query queryFilter, Query seedQuery); - AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k) { return getKnnVectorQuery(field, query, k, null); } @@ -613,91 +607,6 @@ public void testRandomWithFilter() throws IOException { } } - /** Tests with random vectors and a random seed. Uses RandomIndexWriter. 
*/ - public void testRandomWithSeed() throws IOException { - int numDocs = 1000; - int dimension = atLeast(5); - int numIters = atLeast(10); - int numDocsWithVector = 0; - try (Directory d = newDirectoryForTest()) { - // Always use the default kNN format to have predictable behavior around when it hits - // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN - // format - // implementation. - IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()); - RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc); - for (int i = 0; i < numDocs; i++) { - Document doc = new Document(); - if (random().nextBoolean()) { - // Randomly skip some vectors to test the mapping from docid to ordinals - doc.add(getKnnVectorField("field", randomVector(dimension))); - numDocsWithVector += 1; - } - doc.add(new NumericDocValuesField("tag", i)); - doc.add(new IntPoint("tag", i)); - w.addDocument(doc); - } - w.forceMerge(1); - w.close(); - - try (IndexReader reader = DirectoryReader.open(d)) { - IndexSearcher searcher = newSearcher(reader); - for (int i = 0; i < numIters; i++) { - int k = random().nextInt(80) + 1; - int n = random().nextInt(100) + 1; - // we may get fewer results than requested if there are deletions, but this test doesn't - // check that - assert reader.hasDeletions() == false; - - // All documents as seeds - Query seed1 = new MatchAllDocsQuery(); - Query filter = random().nextBoolean() ? 
null : new MatchAllDocsQuery(); - AbstractKnnVectorQuery query = - getKnnVectorQuery("field", randomVector(dimension), k, filter, seed1); - TopDocs results = searcher.search(query, n); - int expected = Math.min(Math.min(n, k), numDocsWithVector); - - assertEquals(expected, results.scoreDocs.length); - assertTrue(results.totalHits.value() >= results.scoreDocs.length); - // verify the results are in descending score order - float last = Float.MAX_VALUE; - for (ScoreDoc scoreDoc : results.scoreDocs) { - assertTrue(scoreDoc.score <= last); - last = scoreDoc.score; - } - - // Restrictive seed query -- 6 documents - Query seed2 = IntPoint.newRangeQuery("tag", 1, 6); - query = getKnnVectorQuery("field", randomVector(dimension), k, null, seed2); - results = searcher.search(query, n); - expected = Math.min(Math.min(n, k), reader.numDocs()); - assertEquals(expected, results.scoreDocs.length); - assertTrue(results.totalHits.value() >= results.scoreDocs.length); - // verify the results are in descending score order - last = Float.MAX_VALUE; - for (ScoreDoc scoreDoc : results.scoreDocs) { - assertTrue(scoreDoc.score <= last); - last = scoreDoc.score; - } - - // No seed documents -- falls back on full approx search - Query seed3 = new MatchNoDocsQuery(); - query = getKnnVectorQuery("field", randomVector(dimension), k, null, seed3); - results = searcher.search(query, n); - expected = Math.min(Math.min(n, k), reader.numDocs()); - assertEquals(expected, results.scoreDocs.length); - assertTrue(results.totalHits.value() >= results.scoreDocs.length); - // verify the results are in descending score order - last = Float.MAX_VALUE; - for (ScoreDoc scoreDoc : results.scoreDocs) { - assertTrue(scoreDoc.score <= last); - last = scoreDoc.score; - } - } - } - } - } - /** Tests filtering when all vectors have the same score. 
*/ public void testFilterWithSameScore() throws IOException { int numDocs = 100; diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java index 651a1219f077..21219e0e1d99 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnByteVectorQuery.java @@ -34,21 +34,9 @@ AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Que return new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter); } - @Override - AbstractKnnVectorQuery getKnnVectorQuery( - String field, float[] query, int k, Query queryFilter, Query seedQuery) { - return new KnnByteVectorQuery(field, floatToBytes(query), k, queryFilter, seedQuery); - } - @Override AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) { - return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query, null); - } - - @Override - AbstractKnnVectorQuery getThrowingKnnVectorQuery( - String field, float[] vec, int k, Query query, Query seedQuery) { - return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query, seedQuery); + return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query); } @Override @@ -73,7 +61,7 @@ Field getKnnVectorField(String name, float[] vector) { return new KnnByteVectorField(name, floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN); } - private static byte[] floatToBytes(float[] query) { + static byte[] floatToBytes(float[] query) { byte[] bytes = new byte[query.length]; for (int i = 0; i < query.length; i++) { assert query[i] <= Byte.MAX_VALUE && query[i] >= Byte.MIN_VALUE && (query[i] % 1) == 0 @@ -121,10 +109,10 @@ public void testVectorEncodingMismatch() throws IOException { } } - private static class ThrowingKnnVectorQuery extends KnnByteVectorQuery { + static class ThrowingKnnVectorQuery extends 
KnnByteVectorQuery { - public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter, Query seed) { - super(field, target, k, filter, seed); + public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter) { + super(field, target, k, filter); } @Override diff --git a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java index 5a4b49b8e2d3..ece2b385654e 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestKnnFloatVectorQuery.java @@ -50,21 +50,9 @@ KnnFloatVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query return new KnnFloatVectorQuery(field, query, k, queryFilter); } - @Override - KnnFloatVectorQuery getKnnVectorQuery( - String field, float[] query, int k, Query queryFilter, Query seedQuery) { - return new KnnFloatVectorQuery(field, query, k, queryFilter, seedQuery); - } - @Override AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) { - return new ThrowingKnnVectorQuery(field, vec, k, query, null); - } - - @Override - AbstractKnnVectorQuery getThrowingKnnVectorQuery( - String field, float[] vec, int k, Query query, Query seedQuery) { - return new ThrowingKnnVectorQuery(field, vec, k, query, seedQuery); + return new ThrowingKnnVectorQuery(field, vec, k, query); } @Override @@ -271,10 +259,10 @@ public void testDocAndScoreQueryBasics() throws IOException { } } - private static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery { + static class ThrowingKnnVectorQuery extends KnnFloatVectorQuery { - public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter, Query seed) { - super(field, target, k, filter, seed); + public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter) { + super(field, target, k, filter); } @Override diff --git 
a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java new file mode 100644 index 000000000000..c4ce074b57fa --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.search; + +import static org.apache.lucene.search.TestKnnByteVectorQuery.floatToBytes; + +import java.io.IOException; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.IntPoint; +import org.apache.lucene.document.KnnByteVectorField; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.TestVectorUtil; + +public class TestSeededKnnByteVectorQuery extends BaseKnnVectorQueryTestCase { + + private static final Query MATCH_NONE = new MatchNoDocsQuery(); + + @Override + AbstractKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) { + return new SeededKnnByteVectorQuery(field, floatToBytes(query), k, queryFilter, MATCH_NONE); + } + + @Override + AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) { + return new ThrowingKnnVectorQuery(field, floatToBytes(vec), k, query, MATCH_NONE); + } + + @Override + float[] randomVector(int dim) { + byte[] b = TestVectorUtil.randomVectorBytes(dim); + float[] v = new float[b.length]; + int vi = 0; + for (int i = 0; i < v.length; i++) { + v[vi++] = b[i]; + } + return v; + } + + @Override + Field getKnnVectorField( + String name, float[] vector, VectorSimilarityFunction similarityFunction) { + return new KnnByteVectorField(name, floatToBytes(vector), similarityFunction); + } + + @Override + Field getKnnVectorField(String name, float[] vector) { + return new KnnByteVectorField(name, 
floatToBytes(vector), VectorSimilarityFunction.EUCLIDEAN); + } + + /** Tests with random vectors and a random seed. Uses RandomIndexWriter. */ + public void testRandomWithSeed() throws IOException { + int numDocs = 1000; + int dimension = atLeast(5); + int numIters = atLeast(10); + int numDocsWithVector = 0; + try (Directory d = newDirectoryForTest()) { + // Always use the default kNN format to have predictable behavior around when it hits + // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN + // format + // implementation. + IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()); + RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc); + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + if (random().nextBoolean()) { + // Randomly skip some vectors to test the mapping from docid to ordinals + doc.add(getKnnVectorField("field", randomVector(dimension))); + numDocsWithVector += 1; + } + doc.add(new NumericDocValuesField("tag", i)); + doc.add(new IntPoint("tag", i)); + w.addDocument(doc); + } + w.forceMerge(1); + w.close(); + + try (IndexReader reader = DirectoryReader.open(d)) { + IndexSearcher searcher = newSearcher(reader); + for (int i = 0; i < numIters; i++) { + int k = random().nextInt(80) + 1; + int n = random().nextInt(100) + 1; + // we may get fewer results than requested if there are deletions, but this test doesn't + // check that + assert reader.hasDeletions() == false; + + // All documents as seeds + Query seed1 = new MatchAllDocsQuery(); + Query filter = random().nextBoolean() ? 
null : new MatchAllDocsQuery(); + SeededKnnByteVectorQuery query = + new SeededKnnByteVectorQuery( + "field", floatToBytes(randomVector(dimension)), k, filter, seed1); + TopDocs results = searcher.search(query, n); + int expected = Math.min(Math.min(n, k), numDocsWithVector); + + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + float last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + + // Restrictive seed query -- 6 documents + Query seed2 = IntPoint.newRangeQuery("tag", 1, 6); + query = + new SeededKnnByteVectorQuery( + "field", floatToBytes(randomVector(dimension)), k, null, seed2); + results = searcher.search(query, n); + expected = Math.min(Math.min(n, k), reader.numDocs()); + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + + // No seed documents -- falls back on full approx search + Query seed3 = new MatchNoDocsQuery(); + query = + new SeededKnnByteVectorQuery( + "field", floatToBytes(randomVector(dimension)), k, null, seed3); + results = searcher.search(query, n); + expected = Math.min(Math.min(n, k), reader.numDocs()); + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + } + } + } + } + + private static class ThrowingKnnVectorQuery extends SeededKnnByteVectorQuery { + + public ThrowingKnnVectorQuery(String field, byte[] 
target, int k, Query filter, Query seed) { + super(field, target, k, filter, seed); + } + + @Override + protected TopDocs exactSearch( + LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) { + throw new UnsupportedOperationException("exact search is not supported"); + } + + @Override + public String toString(String field) { + return null; + } + } +} diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java new file mode 100644 index 000000000000..268e7d806379 --- /dev/null +++ b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.lucene.search; + +import java.io.IOException; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.IntPoint; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.QueryTimeout; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.tests.util.TestUtil; +import org.apache.lucene.util.TestVectorUtil; + +public class TestSeededKnnFloatVectorQuery extends BaseKnnVectorQueryTestCase { + private static final Query MATCH_NONE = new MatchNoDocsQuery(); + + @Override + KnnFloatVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter) { + return new SeededKnnFloatVectorQuery(field, query, k, queryFilter, MATCH_NONE); + } + + @Override + AbstractKnnVectorQuery getThrowingKnnVectorQuery(String field, float[] vec, int k, Query query) { + return new ThrowingKnnVectorQuery(field, vec, k, query, MATCH_NONE); + } + + @Override + float[] randomVector(int dim) { + return TestVectorUtil.randomVector(dim); + } + + @Override + Field getKnnVectorField( + String name, float[] vector, VectorSimilarityFunction similarityFunction) { + return new KnnFloatVectorField(name, vector, similarityFunction); + } + + @Override + Field getKnnVectorField(String name, float[] vector) { + return new KnnFloatVectorField(name, vector); + } + + /** Tests with random vectors and a random seed. Uses RandomIndexWriter. 
*/ + public void testRandomWithSeed() throws IOException { + int numDocs = 1000; + int dimension = atLeast(5); + int numIters = atLeast(10); + int numDocsWithVector = 0; + try (Directory d = newDirectoryForTest()) { + // Always use the default kNN format to have predictable behavior around when it hits + // visitedLimit. This is fine since the test targets AbstractKnnVectorQuery logic, not the kNN + // format + // implementation. + IndexWriterConfig iwc = new IndexWriterConfig().setCodec(TestUtil.getDefaultCodec()); + RandomIndexWriter w = new RandomIndexWriter(random(), d, iwc); + for (int i = 0; i < numDocs; i++) { + Document doc = new Document(); + if (random().nextBoolean()) { + // Randomly skip some vectors to test the mapping from docid to ordinals + doc.add(getKnnVectorField("field", randomVector(dimension))); + numDocsWithVector += 1; + } + doc.add(new NumericDocValuesField("tag", i)); + doc.add(new IntPoint("tag", i)); + w.addDocument(doc); + } + w.forceMerge(1); + w.close(); + + try (IndexReader reader = DirectoryReader.open(d)) { + IndexSearcher searcher = newSearcher(reader); + for (int i = 0; i < numIters; i++) { + int k = random().nextInt(80) + 1; + int n = random().nextInt(100) + 1; + // we may get fewer results than requested if there are deletions, but this test doesn't + // check that + assert reader.hasDeletions() == false; + + // All documents as seeds + Query seed1 = new MatchAllDocsQuery(); + Query filter = random().nextBoolean() ? 
null : new MatchAllDocsQuery(); + AbstractKnnVectorQuery query = + new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, filter, seed1); + TopDocs results = searcher.search(query, n); + int expected = Math.min(Math.min(n, k), numDocsWithVector); + + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + float last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + + // Restrictive seed query -- 6 documents + Query seed2 = IntPoint.newRangeQuery("tag", 1, 6); + query = new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, null, seed2); + results = searcher.search(query, n); + expected = Math.min(Math.min(n, k), reader.numDocs()); + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + + // No seed documents -- falls back on full approx search + Query seed3 = new MatchNoDocsQuery(); + query = new SeededKnnFloatVectorQuery("field", randomVector(dimension), k, null, seed3); + results = searcher.search(query, n); + expected = Math.min(Math.min(n, k), reader.numDocs()); + assertEquals(expected, results.scoreDocs.length); + assertTrue(results.totalHits.value() >= results.scoreDocs.length); + // verify the results are in descending score order + last = Float.MAX_VALUE; + for (ScoreDoc scoreDoc : results.scoreDocs) { + assertTrue(scoreDoc.score <= last); + last = scoreDoc.score; + } + } + } + } + } + + private static class ThrowingKnnVectorQuery extends SeededKnnFloatVectorQuery { + + public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter, Query seed) { + 
super(field, target, k, filter, seed); + } + + @Override + protected TopDocs exactSearch( + LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) { + throw new UnsupportedOperationException("exact search is not supported"); + } + + @Override + public String toString(String field) { + return null; + } + } +} diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java index ab140be7113e..45cb8b9c88fa 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenByteKnnVectorQuery.java @@ -140,7 +140,6 @@ protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher search protected TopDocs approximateSearch( LeafReaderContext context, Bits acceptDocs, - DocIdSetIterator seedDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { diff --git a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java index ab2e1462c4c4..9c44a2f78566 100644 --- a/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java +++ b/lucene/join/src/java/org/apache/lucene/search/join/DiversifyingChildrenFloatKnnVectorQuery.java @@ -139,7 +139,6 @@ protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher search protected TopDocs approximateSearch( LeafReaderContext context, Bits acceptDocs, - DocIdSetIterator seedDocs, int visitedLimit, KnnCollectorManager knnCollectorManager) throws IOException { From 6d0cb4fe32a64c920725d3b2ef34afa03a95a947 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Fri, 20 Dec 2024 11:40:07 -0500 
Subject: [PATCH 39/41] removing unnecessary changes --- .../backward_codecs/lucene80/IndexedDISI.java | 55 ------------------- .../lucene92/OffHeapFloatVectorValues.java | 10 ---- .../lucene94/OffHeapByteVectorValues.java | 10 ---- .../lucene94/OffHeapFloatVectorValues.java | 10 ---- .../lucene/codecs/lucene90/IndexedDISI.java | 55 ------------------- .../lucene95/OffHeapByteVectorValues.java | 10 ---- .../lucene95/OffHeapFloatVectorValues.java | 10 ---- .../apache/lucene/index/ByteVectorValues.java | 13 ----- .../lucene/index/FloatVectorValues.java | 13 ----- 9 files changed, 186 deletions(-) diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene80/IndexedDISI.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene80/IndexedDISI.java index e50376ca706a..639bdbd73339 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene80/IndexedDISI.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene80/IndexedDISI.java @@ -706,59 +706,4 @@ private static void rankSkip(IndexedDISI disi, int targetInBlock) throws IOExcep disi.word = rankWord; disi.numberOfOnes = disi.denseOrigoIndex + denseNOO; } - - /** - * Implementation of a {@link DocIdSetIterator} which maps a source iterator to the indexes - * produced by an {@link IndexedDISI}. - * - *

    This implementation assumes that all IDs produced by the source iterator are also present in - * the indexed iterator. - * - * @lucene.internal - */ - public static class MappedDISI extends DocIdSetIterator { - IndexedDISI indexedDISI; - DocIdSetIterator sourceDISI; - - public MappedDISI(IndexedDISI indexedDISI, DocIdSetIterator sourceDISI) { - this.indexedDISI = indexedDISI; - this.sourceDISI = sourceDISI; - } - - /** - * Advances the source iterator to the first document number that is greater than or equal to - * the provided target and returns the corresponding index. - */ - @Override - public int advance(int target) throws IOException { - int newTarget = sourceDISI.advance(target); - if (newTarget != NO_MORE_DOCS) { - indexedDISI.advance(newTarget); - } - return docID(); - } - - @Override - public long cost() { - return this.sourceDISI.cost(); - } - - @Override - public int docID() { - if (indexedDISI.docID() == NO_MORE_DOCS || sourceDISI.docID() == NO_MORE_DOCS) { - return NO_MORE_DOCS; - } - return indexedDISI.index(); - } - - /** Advances to the next document in the source iterator and returns the corresponding index. 
*/ - @Override - public int nextDoc() throws IOException { - int newTarget = sourceDISI.nextDoc(); - if (newTarget != NO_MORE_DOCS) { - indexedDISI.advance(newTarget); - } - return docID(); - } - } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java index bcebd8823db0..7c87bac5e54a 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene92/OffHeapFloatVectorValues.java @@ -216,11 +216,6 @@ public DocIdSetIterator iterator() { } }; } - - @Override - public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { - return new IndexedDISI.MappedDISI(disi, docIds); - } } private static class EmptyOffHeapVectorValues extends OffHeapFloatVectorValues { @@ -268,10 +263,5 @@ public Bits getAcceptOrds(Bits acceptDocs) { public VectorScorer scorer(float[] query) { return null; } - - @Override - public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { - return DocIdSetIterator.empty(); - } } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java index 336287d700cf..0c428bb169f3 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapByteVectorValues.java @@ -230,11 +230,6 @@ public DocIdSetIterator iterator() { } }; } - - @Override - public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { - return new IndexedDISI.MappedDISI(disi, docIds); - } } private static class EmptyOffHeapVectorValues 
extends OffHeapByteVectorValues { @@ -282,10 +277,5 @@ public Bits getAcceptOrds(Bits acceptDocs) { public VectorScorer scorer(byte[] query) { return null; } - - @Override - public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { - return DocIdSetIterator.empty(); - } } } diff --git a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java index 2fa6daad4177..b21df901ddb6 100644 --- a/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java +++ b/lucene/backward-codecs/src/java/org/apache/lucene/backward_codecs/lucene94/OffHeapFloatVectorValues.java @@ -229,11 +229,6 @@ public DocIdSetIterator iterator() { } }; } - - @Override - public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { - return new IndexedDISI.MappedDISI(disi, docIds); - } } private static class EmptyOffHeapVectorValues extends OffHeapFloatVectorValues { @@ -281,10 +276,5 @@ public Bits getAcceptOrds(Bits acceptDocs) { public VectorScorer scorer(float[] query) { return null; } - - @Override - public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { - return DocIdSetIterator.empty(); - } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java index e6410595e059..dbd56125fcd1 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene90/IndexedDISI.java @@ -758,59 +758,4 @@ private static void rankSkip(IndexedDISI disi, int targetInBlock) throws IOExcep disi.word = rankWord; disi.numberOfOnes = disi.denseOrigoIndex + denseNOO; } - - /** - * Implementation of a {@link DocIdSetIterator} which maps a source iterator to the indexes - * produced 
by an {@link IndexedDISI}. - * - *

    This implementation assumes that all IDs produced by the source iterator are also present in - * the indexed iterator. - * - * @lucene.internal - */ - public static class MappedDISI extends DocIdSetIterator { - IndexedDISI indexedDISI; - DocIdSetIterator sourceDISI; - - public MappedDISI(IndexedDISI indexedDISI, DocIdSetIterator sourceDISI) { - this.indexedDISI = indexedDISI; - this.sourceDISI = sourceDISI; - } - - /** - * Advances the source iterator to the first document number that is greater than or equal to - * the provided target and returns the corresponding index. - */ - @Override - public int advance(int target) throws IOException { - int newTarget = sourceDISI.advance(target); - if (newTarget != NO_MORE_DOCS) { - indexedDISI.advance(newTarget); - } - return docID(); - } - - @Override - public long cost() { - return this.sourceDISI.cost(); - } - - @Override - public int docID() { - if (indexedDISI.docID() == NO_MORE_DOCS || sourceDISI.docID() == NO_MORE_DOCS) { - return NO_MORE_DOCS; - } - return indexedDISI.index(); - } - - /** Advances to the next document in the source iterator and returns the corresponding index. 
*/ - @Override - public int nextDoc() throws IOException { - int newTarget = sourceDISI.nextDoc(); - if (newTarget != NO_MORE_DOCS) { - indexedDISI.advance(newTarget); - } - return docID(); - } - } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java index 4dab8f49e534..1e78c8ea7aa2 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapByteVectorValues.java @@ -272,11 +272,6 @@ public DocIdSetIterator iterator() { } }; } - - @Override - public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { - return new IndexedDISI.MappedDISI(disi, docIds); - } } private static class EmptyOffHeapVectorValues extends OffHeapByteVectorValues { @@ -327,10 +322,5 @@ public Bits getAcceptOrds(Bits acceptDocs) { public VectorScorer scorer(byte[] query) { return null; } - - @Override - public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { - return DocIdSetIterator.empty(); - } } } diff --git a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java index 310a42e18015..2384657e93e1 100644 --- a/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/codecs/lucene95/OffHeapFloatVectorValues.java @@ -268,11 +268,6 @@ public DocIdSetIterator iterator() { } }; } - - @Override - public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { - return new IndexedDISI.MappedDISI(disi, docIds); - } } private static class EmptyOffHeapVectorValues extends OffHeapFloatVectorValues { @@ -318,10 +313,5 @@ public Bits getAcceptOrds(Bits acceptDocs) { public VectorScorer scorer(float[] query) { return null; } - 
- @Override - public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { - return DocIdSetIterator.empty(); - } } } diff --git a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java index aefaa7fa684f..e9be3423c181 100644 --- a/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/ByteVectorValues.java @@ -19,7 +19,6 @@ import java.io.IOException; import java.util.List; import org.apache.lucene.document.KnnByteVectorField; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.VectorScorer; /** @@ -65,18 +64,6 @@ public static void checkField(LeafReader in, String field) { } } - /** - * Returns a new iterator that maps the provided docIds to the vector ordinals. - * - *

    This method assumes that all docIds have corresponding ordinals. - * - * @lucene.internal - * @lucene.experimental - */ - public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { - return docIds; - } - /** * Return a {@link VectorScorer} for the given query vector. * diff --git a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java index 9467f8710a94..aa840fc39319 100644 --- a/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java +++ b/lucene/core/src/java/org/apache/lucene/index/FloatVectorValues.java @@ -19,7 +19,6 @@ import java.io.IOException; import java.util.List; import org.apache.lucene.document.KnnFloatVectorField; -import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.VectorScorer; /** @@ -65,18 +64,6 @@ public static void checkField(LeafReader in, String field) { } } - /** - * Returns a new iterator that maps the provided docIds to the vector ordinals. - * - *

    This method assumes that all docIds have corresponding ordinals. - * - * @lucene.internal - * @lucene.experimental - */ - public DocIdSetIterator convertDocIdsToVectorOrdinals(DocIdSetIterator docIds) { - return docIds; - } - /** * Return a {@link VectorScorer} for the given query vector and the current {@link * FloatVectorValues}. From 04289ac7b8902b24f3d27e96a3d5e4c212f8c794 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Mon, 6 Jan 2025 09:55:34 -0500 Subject: [PATCH 40/41] adding changes & address PR comments & fixing tests --- lucene/CHANGES.txt | 6 ++++- .../apache/lucene/search/KnnCollector.java | 2 +- .../search/SeededKnnByteVectorQuery.java | 11 +++++--- .../search/SeededKnnFloatVectorQuery.java | 11 +++++--- .../lucene/search/knn/EntryPointProvider.java | 3 +++ .../lucene/search/knn/SeededKnnCollector.java | 18 +++++++++++-- .../search/knn/SeededKnnCollectorManager.java | 23 +++++++++------- .../lucene/util/hnsw/HnswGraphSearcher.java | 24 ++++++++++++----- .../search/TestSeededKnnByteVectorQuery.java | 24 +++++++++++++++++ .../search/TestSeededKnnFloatVectorQuery.java | 26 ++++++++++++++++++- 10 files changed, 120 insertions(+), 28 deletions(-) diff --git a/lucene/CHANGES.txt b/lucene/CHANGES.txt index 6b01271c46b7..dc90bfa477cb 100644 --- a/lucene/CHANGES.txt +++ b/lucene/CHANGES.txt @@ -42,7 +42,11 @@ API Changes New Features --------------------- -(No changes) + +* GITHUB#14084, GITHUB#13635, GITHUB#13634: Adds new `SeededKnnByteVectorQuery` and `SeededKnnFloatVectorQuery` + queries. These queries allow for the vector search entry points to be initialized via a `seed` query. This follows + the research provided via https://arxiv.org/abs/2307.16779. (Sean MacAvaney, Ben Trent). 
+ Improvements --------------------- diff --git a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java index a05ca6747710..f694d8f7085c 100644 --- a/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/KnnCollector.java @@ -93,7 +93,7 @@ public interface KnnCollector { * @lucene.experimental */ abstract class Decorator implements KnnCollector { - private KnnCollector collector; + private final KnnCollector collector; public Decorator(KnnCollector collector) { this.collector = collector; diff --git a/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java index 93286948b0a0..1050bdf70e11 100644 --- a/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java @@ -25,14 +25,19 @@ * This is a version of knn byte vector query that provides a query seed to initiate the vector * search. NOTE: The underlying format is free to ignore the provided seed * + *

    See "Lexically-Accelerated Dense + * Retrieval" (Kulkarni, Hrishikesh and MacAvaney, Sean and Goharian, Nazli and Frieder, Ophir). + * In SIGIR '23: Proceedings of the 46th International ACM SIGIR Conference on Research and + * Development in Information Retrieval Pages 152 - 162 + * * @lucene.experimental */ public class SeededKnnByteVectorQuery extends KnnByteVectorQuery { - private final Query seed; - private final Weight seedWeight; + final Query seed; + final Weight seedWeight; /** - * Construct a new SeededKnnFloatVectorQuery instance + * Construct a new SeededKnnByteVectorQuery instance * * @param field knn byte vector field to query * @param target the query vector diff --git a/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java index f64e0b29bc65..fc2750677b92 100644 --- a/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java @@ -23,13 +23,18 @@ /** * This is a version of knn float vector query that provides a query seed to initiate the vector - * search. NOTE: The underlying format is free to ignore the provided seed + * search. NOTE: The underlying format is free to ignore the provided seed. + * + *

    See "Lexically-Accelerated Dense + * Retrieval" (Kulkarni, Hrishikesh and MacAvaney, Sean and Goharian, Nazli and Frieder, Ophir). + * In SIGIR '23: Proceedings of the 46th International ACM SIGIR Conference on Research and + * Development in Information Retrieval Pages 152 - 162 * * @lucene.experimental */ public class SeededKnnFloatVectorQuery extends KnnFloatVectorQuery { - private final Query seed; - private final Weight seedWeight; + final Query seed; + final Weight seedWeight; /** * Construct a new SeededKnnFloatVectorQuery instance diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java b/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java index 40eda94c654a..9e7b44b571df 100644 --- a/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java +++ b/lucene/core/src/java/org/apache/lucene/search/knn/EntryPointProvider.java @@ -22,4 +22,7 @@ public interface EntryPointProvider { /** Iterator of valid entry points for the kNN search */ DocIdSetIterator entryPoints(); + + /** Number of valid entry points for the kNN search */ + int numberOfEntryPoints(); } diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java index ac0c643eac5c..c3c4f62901ee 100644 --- a/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java +++ b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollector.java @@ -19,16 +19,30 @@ import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.KnnCollector; -public class SeededKnnCollector extends KnnCollector.Decorator implements EntryPointProvider { +/** + * A {@link KnnCollector} that provides seeded knn collection. See usage in {@link + * SeededKnnCollectorManager}. 
+ * + * @lucene.experimental + */ +class SeededKnnCollector extends KnnCollector.Decorator implements EntryPointProvider { private final DocIdSetIterator entryPoints; + private final int numberOfEntryPoints; - public SeededKnnCollector(KnnCollector collector, DocIdSetIterator entryPoints) { + SeededKnnCollector( + KnnCollector collector, DocIdSetIterator entryPoints, int numberOfEntryPoints) { super(collector); this.entryPoints = entryPoints; + this.numberOfEntryPoints = numberOfEntryPoints; } @Override public DocIdSetIterator entryPoints() { return entryPoints; } + + @Override + public int numberOfEntryPoints() { + return numberOfEntryPoints; + } } diff --git a/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java index 1cea53c41794..7631db6e3022 100644 --- a/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java +++ b/lucene/core/src/java/org/apache/lucene/search/knn/SeededKnnCollectorManager.java @@ -32,7 +32,11 @@ import org.apache.lucene.search.Weight; import org.apache.lucene.util.IOFunction; -/** A {@link KnnCollectorManager} that collects results with a timeout. */ +/** + * A {@link KnnCollectorManager} that provides seeded knn collection. See usage in {@link + * org.apache.lucene.search.SeededKnnFloatVectorQuery} and {@link + * org.apache.lucene.search.SeededKnnByteVectorQuery}. 
+ */ public class SeededKnnCollectorManager implements KnnCollectorManager { private final KnnCollectorManager delegate; private final Weight seedWeight; @@ -54,9 +58,7 @@ public SeededKnnCollectorManager( public KnnCollector newCollector(int visitedLimit, LeafReaderContext ctx) throws IOException { // Execute the seed query TopScoreDocCollector seedCollector = - new TopScoreDocCollectorManager( - k /* numHits */, null /* after */, Integer.MAX_VALUE /* totalHitsThreshold */) - .newCollector(); + new TopScoreDocCollectorManager(k, null, Integer.MAX_VALUE).newCollector(); final LeafReader leafReader = ctx.reader(); final LeafCollector leafCollector = seedCollector.getLeafCollector(ctx); if (leafCollector != null) { @@ -69,28 +71,29 @@ public KnnCollector newCollector(int visitedLimit, LeafReaderContext ctx) throws 0 /* min */, DocIdSetIterator.NO_MORE_DOCS /* max */); } - leafCollector.finish(); } catch ( @SuppressWarnings("unused") CollectionTerminatedException e) { } + leafCollector.finish(); } TopDocs seedTopDocs = seedCollector.topDocs(); KnnVectorValues vectorValues = vectorValuesSupplier.apply(leafReader); + final KnnCollector delegateCollector = delegate.newCollector(visitedLimit, ctx); if (seedTopDocs.totalHits.value() == 0 || vectorValues == null) { - return delegate.newCollector(visitedLimit, ctx); + return delegateCollector; } KnnVectorValues.DocIndexIterator indexIterator = vectorValues.iterator(); DocIdSetIterator seedDocs = new MappedDISI(indexIterator, new TopDocsDISI(seedTopDocs)); - return new SeededKnnCollector(delegate.newCollector(visitedLimit, ctx), seedDocs); + return new SeededKnnCollector(delegateCollector, seedDocs, seedTopDocs.scoreDocs.length); } - public static class MappedDISI extends DocIdSetIterator { + private static class MappedDISI extends DocIdSetIterator { KnnVectorValues.DocIndexIterator indexedDISI; DocIdSetIterator sourceDISI; - public MappedDISI(KnnVectorValues.DocIndexIterator indexedDISI, DocIdSetIterator sourceDISI) { + 
private MappedDISI(KnnVectorValues.DocIndexIterator indexedDISI, DocIdSetIterator sourceDISI) { this.indexedDISI = indexedDISI; this.sourceDISI = sourceDISI; } @@ -110,7 +113,7 @@ public int advance(int target) throws IOException { @Override public long cost() { - return this.sourceDISI.cost(); + return sourceDISI.cost(); } @Override diff --git a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java index d08e1165cf59..e8f0d316fd81 100644 --- a/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java +++ b/lucene/core/src/java/org/apache/lucene/util/hnsw/HnswGraphSearcher.java @@ -54,8 +54,9 @@ public HnswGraphSearcher(NeighborQueue candidates, BitSet visited) { } /** - * Searches the HNSW graph for the nearest neighbors of a query vector, starting from the provided - * entry points. + * Searches the HNSW graph for the nearest neighbors of a query vector. If entry points are + * directly provided via the knnCollector, then the search will be initialized at those points. + * Otherwise, the search will discover the best entry point per the normal HNSW search algorithm. 
* * @param scorer the scorer to compare the query with the nodes * @param knnCollector a collector of top knn results to be returned @@ -72,16 +73,25 @@ public static void search( new NeighborQueue(knnCollector.k(), true), new SparseFixedBitSet(getGraphSize(graph))); final int[] entryPoints; if (knnCollector instanceof EntryPointProvider epp) { + if (epp.numberOfEntryPoints() <= 0) { + throw new IllegalArgumentException("The number of entry points must be > 0"); + } DocIdSetIterator eps = epp.entryPoints(); - entryPoints = new int[(int) eps.cost()]; + entryPoints = new int[epp.numberOfEntryPoints()]; int idx = 0; - int entryPointOrdInt; - while ((entryPointOrdInt = eps.nextDoc()) != NO_MORE_DOCS) { + while (idx < entryPoints.length) { + int entryPointOrdInt = eps.nextDoc(); + if (entryPointOrdInt == NO_MORE_DOCS) { + throw new IllegalArgumentException( + "The number of entry points provided is less than the number of entry points requested"); + } + assert entryPointOrdInt < getGraphSize(graph); entryPoints[idx++] = entryPointOrdInt; } + // This is an invalid case, but we should check it + assert entryPoints.length > 0; // We use provided entry point ordinals to search the complete graph (level 0) - graphSearcher.searchLevel( - knnCollector, scorer, 0 /* level */, entryPoints, graph, acceptOrds); + graphSearcher.searchLevel(knnCollector, scorer, 0, entryPoints, graph, acceptOrds); } else { search(scorer, knnCollector, graph, graphSearcher, acceptOrds); } diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java index c4ce074b57fa..d0fb8c95e035 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnByteVectorQuery.java @@ -167,6 +167,30 @@ public ThrowingKnnVectorQuery(String field, byte[] target, int k, Query filter, super(field, target, k, filter, 
seed); } + private ThrowingKnnVectorQuery( + String field, byte[] target, int k, Query filter, Weight seedWeight) { + super(field, target, k, filter, seedWeight); + } + + @Override + // This is test only and we need to overwrite the inner rewrite to throw + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + if (seedWeight != null) { + return super.rewrite(indexSearcher); + } + BooleanQuery.Builder booleanSeedQueryBuilder = + new BooleanQuery.Builder() + .add(seed, BooleanClause.Occur.MUST) + .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER); + if (filter != null) { + booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER); + } + Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build()); + Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f); + return new ThrowingKnnVectorQuery(field, target, k, filter, seedWeight) + .rewrite(indexSearcher); + } + @Override protected TopDocs exactSearch( LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) { diff --git a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java index 268e7d806379..d5630037ef74 100644 --- a/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java +++ b/lucene/core/src/test/org/apache/lucene/search/TestSeededKnnFloatVectorQuery.java @@ -149,10 +149,34 @@ public void testRandomWithSeed() throws IOException { private static class ThrowingKnnVectorQuery extends SeededKnnFloatVectorQuery { - public ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter, Query seed) { + private ThrowingKnnVectorQuery(String field, float[] target, int k, Query filter, Query seed) { super(field, target, k, filter, seed); } + private ThrowingKnnVectorQuery( + String field, float[] target, int k, Query filter, Weight seedWeight) { + super(field, target, 
k, filter, seedWeight); + } + + @Override + // This is test only and we need to overwrite the inner rewrite to throw + public Query rewrite(IndexSearcher indexSearcher) throws IOException { + if (seedWeight != null) { + return super.rewrite(indexSearcher); + } + BooleanQuery.Builder booleanSeedQueryBuilder = + new BooleanQuery.Builder() + .add(seed, BooleanClause.Occur.MUST) + .add(new FieldExistsQuery(field), BooleanClause.Occur.FILTER); + if (filter != null) { + booleanSeedQueryBuilder.add(filter, BooleanClause.Occur.FILTER); + } + Query seedRewritten = indexSearcher.rewrite(booleanSeedQueryBuilder.build()); + Weight seedWeight = indexSearcher.createWeight(seedRewritten, ScoreMode.TOP_SCORES, 1f); + return new ThrowingKnnVectorQuery(field, target, k, filter, seedWeight) + .rewrite(indexSearcher); + } + @Override protected TopDocs exactSearch( LeafReaderContext context, DocIdSetIterator acceptIterator, QueryTimeout queryTimeout) { From 153ce6c313b647cdfcd1a0aa56d3b9d9bc285feb Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 8 Jan 2025 09:12:32 -0500 Subject: [PATCH 41/41] adding checks for vector value types --- .../apache/lucene/search/SeededKnnByteVectorQuery.java | 9 ++++++++- .../apache/lucene/search/SeededKnnFloatVectorQuery.java | 9 ++++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java index 1050bdf70e11..980b6869c34f 100644 --- a/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/SeededKnnByteVectorQuery.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.util.Objects; +import org.apache.lucene.index.ByteVectorValues; import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.search.knn.SeededKnnCollectorManager; @@ -85,6 
+86,12 @@ protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher search super.getKnnCollectorManager(k, searcher), seedWeight, k, - leaf -> leaf.getFloatVectorValues(field)); + leaf -> { + ByteVectorValues vv = leaf.getByteVectorValues(field); + if (vv == null) { + ByteVectorValues.checkField(leaf.getContext().reader(), field); + } + return vv; + }); } } diff --git a/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java b/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java index fc2750677b92..02a33bdcdef7 100644 --- a/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java +++ b/lucene/core/src/java/org/apache/lucene/search/SeededKnnFloatVectorQuery.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.util.Objects; +import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.search.knn.KnnCollectorManager; import org.apache.lucene.search.knn.SeededKnnCollectorManager; @@ -85,6 +86,12 @@ protected KnnCollectorManager getKnnCollectorManager(int k, IndexSearcher search super.getKnnCollectorManager(k, searcher), seedWeight, k, - leaf -> leaf.getFloatVectorValues(field)); + leaf -> { + FloatVectorValues vv = leaf.getFloatVectorValues(field); + if (vv == null) { + FloatVectorValues.checkField(leaf.getContext().reader(), field); + } + return vv; + }); } }