Speed up advancing within a block, take 2. (#13958)
PR #13692 tried to speed up advancing by using a branchless binary search, but while this yielded a speedup on my machine, it caused a slowdown on the nightly benchmarks.

This PR tries a different approach using vectorization. Experimentation suggests that it speeds up queries that advance to one of the next few doc IDs, such as `AndHighHigh`.
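For illustration, here is a minimal sketch of the kind of vectorized search this relies on, written against the JDK's Panama Vector API (`jdk.incubator.vector`). The class and method names are mine, and this is a sketch of the technique, not the code this commit adds — the commit routes the call through `VectorUtilSupport`, whose scalar fallback appears further down in the diff:

```java
import jdk.incubator.vector.LongVector;
import jdk.incubator.vector.VectorMask;
import jdk.incubator.vector.VectorOperators;
import jdk.incubator.vector.VectorSpecies;

// Hypothetical sketch, not part of Lucene.
final class FindNextGEQSketch {
  private static final VectorSpecies<Long> SPECIES = LongVector.SPECIES_PREFERRED;

  /** First index i in [from, length) such that buffer[i] >= target, or length if none. */
  static int findNextGEQ(long[] buffer, int length, long target, int from) {
    int i = from;
    // Compare one full vector of lanes per iteration; firstTrue() yields the matching lane.
    for (int bound = length - SPECIES.length(); i <= bound; i += SPECIES.length()) {
      LongVector vec = LongVector.fromArray(SPECIES, buffer, i);
      VectorMask<Long> geq = vec.compare(VectorOperators.GE, target);
      if (geq.anyTrue()) {
        return i + geq.firstTrue();
      }
    }
    // Scalar tail for the last few entries.
    for (; i < length; ++i) {
      if (buffer[i] >= target) {
        return i;
      }
    }
    return length;
  }
}
```

On 128-entry postings blocks this replaces a data-dependent scalar loop with a handful of wide comparisons, which is consistent with the observation that it mostly helps queries that advance by only a few doc IDs at a time.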
jpountz committed Oct 31, 2024
1 parent 8b87527 commit cff28d5
Showing 9 changed files with 292 additions and 17 deletions.
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
@@ -54,6 +54,8 @@ Optimizations
* GITHUB#13963: Speed up nextDoc() implementations in Lucene912PostingsReader.
(Adrien Grand)

* GITHUB#13958: Speed up advancing within a block. (Adrien Grand)

Bug Fixes
---------------------
* GITHUB#13832: Fixed an issue where the DefaultPassageFormatter.format method did not format passages as intended
180 changes: 180 additions & 0 deletions lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/AdvanceBenchmark.java
@@ -0,0 +1,180 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.lucene.benchmark.jmh;

import java.util.Arrays;
import java.util.Random;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.util.VectorUtil;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.CompilerControl;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Level;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
@Warmup(iterations = 5, time = 1)
@Measurement(iterations = 5, time = 1)
@Fork(
value = 3,
jvmArgsAppend = {
"-Xmx1g",
"-Xms1g",
"-XX:+AlwaysPreTouch",
"--add-modules",
"jdk.incubator.vector"
})
public class AdvanceBenchmark {

private final long[] values = new long[129];
private final int[] startIndexes = new int[1_000];
private final long[] targets = new long[startIndexes.length];

@Setup(Level.Trial)
public void setup() throws Exception {
for (int i = 0; i < 128; ++i) {
values[i] = i;
}
values[128] = DocIdSetIterator.NO_MORE_DOCS;
Random r = new Random(0);
for (int i = 0; i < startIndexes.length; ++i) {
startIndexes[i] = r.nextInt(64);
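// Target is 1..64 docs ahead, with small gaps much more likely than large ones.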
targets[i] = startIndexes[i] + 1 + r.nextInt(1 << r.nextInt(7));
}
}

@Benchmark
public void binarySearch() {
for (int i = 0; i < startIndexes.length; ++i) {
binarySearch(values, targets[i], startIndexes[i]);
}
}

@CompilerControl(CompilerControl.Mode.DONT_INLINE)
private static int binarySearch(long[] values, long target, int startIndex) {
// Standard binary search
int i = Arrays.binarySearch(values, startIndex, values.length, target);
if (i < 0) {
i = -1 - i;
}
return i;
}

@Benchmark
public void inlinedBranchlessBinarySearch() {
for (int i = 0; i < targets.length; ++i) {
inlinedBranchlessBinarySearch(values, targets[i]);
}
}

@CompilerControl(CompilerControl.Mode.DONT_INLINE)
private static int inlinedBranchlessBinarySearch(long[] values, long target) {
// This compiles to cmov instructions.
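// Unlike the other variants, this always searches the full 128-value block from index 0.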
int start = 0;

if (values[63] < target) {
start += 64;
}
if (values[start + 31] < target) {
start += 32;
}
if (values[start + 15] < target) {
start += 16;
}
if (values[start + 7] < target) {
start += 8;
}
if (values[start + 3] < target) {
start += 4;
}
if (values[start + 1] < target) {
start += 2;
}
if (values[start] < target) {
start += 1;
}

return start;
}

@Benchmark
public void linearSearch() {
for (int i = 0; i < startIndexes.length; ++i) {
linearSearch(values, targets[i], startIndexes[i]);
}
}

@CompilerControl(CompilerControl.Mode.DONT_INLINE)
private static int linearSearch(long[] values, long target, int startIndex) {
// Naive linear search.
for (int i = startIndex; i < values.length; ++i) {
if (values[i] >= target) {
return i;
}
}
return values.length;
}

@Benchmark
public void vectorUtilSearch() {
for (int i = 0; i < startIndexes.length; ++i) {
vectorUtilSearch(values, targets[i], startIndexes[i]);
}
}

@CompilerControl(CompilerControl.Mode.DONT_INLINE)
private static int vectorUtilSearch(long[] values, long target, int startIndex) {
return VectorUtil.findNextGEQ(values, 128, target, startIndex);
}

private static void assertEquals(int expected, int actual) {
if (expected != actual) {
throw new AssertionError("Expected: " + expected + ", got " + actual);
}
}

public static void main(String[] args) {
// Self-test: verify that all implementations agree on every (start, target) pair.
long[] values = new long[129];
for (int i = 0; i < 128; ++i) {
values[i] = i;
}
values[128] = DocIdSetIterator.NO_MORE_DOCS;
for (int start = 0; start < 128; ++start) {
for (int targetIndex = start; targetIndex < 128; ++targetIndex) {
int actualIndex = binarySearch(values, values[targetIndex], start);
assertEquals(targetIndex, actualIndex);
actualIndex = inlinedBranchlessBinarySearch(values, values[targetIndex]);
assertEquals(targetIndex, actualIndex);
actualIndex = linearSearch(values, values[targetIndex], start);
assertEquals(targetIndex, actualIndex);
actualIndex = vectorUtilSearch(values, values[targetIndex], start);
assertEquals(targetIndex, actualIndex);
}
}
}
}
@@ -46,6 +46,7 @@
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SlowImpactsEnum;
import org.apache.lucene.internal.vectorization.PostingDecodingUtil;
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
import org.apache.lucene.internal.vectorization.VectorizationProvider;
import org.apache.lucene.store.ByteArrayDataInput;
import org.apache.lucene.store.ChecksumIndexInput;
@@ -65,6 +66,8 @@
public final class Lucene912PostingsReader extends PostingsReaderBase {

static final VectorizationProvider VECTORIZATION_PROVIDER = VectorizationProvider.getInstance();
private static final VectorUtilSupport VECTOR_SUPPORT =
VECTORIZATION_PROVIDER.getVectorUtilSupport();
// Dummy impacts, composed of the maximum possible term frequency and the lowest possible
// (unsigned) norm value. This is typically used on tail blocks, which don't actually record
// impacts as the storage overhead would not be worth any query evaluation speedup, since there's
@@ -215,15 +218,6 @@ static void prefixSum(long[] buffer, int count, long base) {
}
}

static int findFirstGreater(long[] buffer, int target, int from) {
for (int i = from; i < BLOCK_SIZE; ++i) {
if (buffer[i] >= target) {
return i;
}
}
return BLOCK_SIZE;
}

@Override
public BlockTermState newTermState() {
return new IntBlockTermState();
@@ -357,6 +351,7 @@ private abstract class AbstractPostingsEnum extends PostingsEnum {
protected int docCountUpto; // number of docs in or before the current block
protected long prevDocID; // last doc ID of the previous block

protected int docBufferSize; // number of valid entries in docBuffer; smaller than BLOCK_SIZE only for tail blocks
protected int docBufferUpto;

protected IndexInput docIn;
@@ -402,6 +397,7 @@ protected PostingsEnum resetIdsAndLevelParams(IntBlockTermState termState) throw
level1DocEndFP = termState.docStartFP;
}
level1DocCountUpto = 0;
docBufferSize = BLOCK_SIZE;
docBufferUpto = BLOCK_SIZE;
return this;
}
@@ -487,7 +483,7 @@ private void refillFullBlock() throws IOException {
docCountUpto += BLOCK_SIZE;
prevDocID = docBuffer[BLOCK_SIZE - 1];
docBufferUpto = 0;
assert docBuffer[BLOCK_SIZE] == NO_MORE_DOCS;
assert docBuffer[docBufferSize] == NO_MORE_DOCS;
}

private void refillRemainder() throws IOException {
@@ -508,6 +504,7 @@ private void refillRemainder() throws IOException {
docCountUpto += left;
}
docBufferUpto = 0;
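// Remember how many entries are valid so that searches stop at the end of this tail block.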
docBufferSize = left;
freqFP = -1;
}

@@ -604,7 +601,7 @@ public int advance(int target) throws IOException {
}
}

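// Find the first doc >= target in the valid prefix of the block (vectorized when the Panama provider is available).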
int next = findFirstGreater(docBuffer, target, docBufferUpto);
int next = VECTOR_SUPPORT.findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
this.doc = (int) docBuffer[next];
docBufferUpto = next + 1;
return doc;
@@ -782,16 +779,18 @@ private void refillDocs() throws IOException {
freqBuffer[0] = totalTermFreq;
docBuffer[1] = NO_MORE_DOCS;
docCountUpto++;
docBufferSize = 1;
} else {
// Read vInts:
PostingsUtil.readVIntBlock(docIn, docBuffer, freqBuffer, left, indexHasFreq, true);
prefixSum(docBuffer, left, prevDocID);
docBuffer[left] = NO_MORE_DOCS;
docCountUpto += left;
docBufferSize = left;
}
prevDocID = docBuffer[BLOCK_SIZE - 1];
docBufferUpto = 0;
assert docBuffer[BLOCK_SIZE] == NO_MORE_DOCS;
assert docBuffer[docBufferSize] == NO_MORE_DOCS;
}

private void skipLevel1To(int target) throws IOException {
@@ -951,7 +950,7 @@ public int advance(int target) throws IOException {
refillDocs();
}

int next = findFirstGreater(docBuffer, target, docBufferUpto);
int next = VECTOR_SUPPORT.findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
posPendingCount += sumOverRange(freqBuffer, docBufferUpto, next + 1);
this.freq = (int) freqBuffer[next];
this.docBufferUpto = next + 1;
@@ -1155,6 +1154,7 @@ private abstract class BlockImpactsEnum extends ImpactsEnum {
protected int docCountUpto; // number of docs in or before the current block
protected int doc = -1; // doc we last read
protected long prevDocID = -1; // last doc ID of the previous block
protected int docBufferSize = BLOCK_SIZE; // number of valid entries in docBuffer
protected int docBufferUpto = BLOCK_SIZE;

// true if we shallow-advanced to a new block that we have not decoded yet
@@ -1306,10 +1306,11 @@ private void refillDocs() throws IOException {
docBuffer[left] = NO_MORE_DOCS;
freqFP = -1;
docCountUpto += left;
docBufferSize = left;
}
prevDocID = docBuffer[BLOCK_SIZE - 1];
docBufferUpto = 0;
assert docBuffer[BLOCK_SIZE] == NO_MORE_DOCS;
assert docBuffer[docBufferSize] == NO_MORE_DOCS;
}

private void skipLevel1To(int target) throws IOException {
@@ -1437,7 +1438,7 @@ public int advance(int target) throws IOException {
needsRefilling = false;
}

int next = findFirstGreater(docBuffer, target, docBufferUpto);
int next = VECTOR_SUPPORT.findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
this.doc = (int) docBuffer[next];
docBufferUpto = next + 1;
return doc;
@@ -1535,10 +1536,11 @@ private void refillDocs() throws IOException {
prefixSum(docBuffer, left, prevDocID);
docBuffer[left] = NO_MORE_DOCS;
docCountUpto += left;
docBufferSize = left;
}
prevDocID = docBuffer[BLOCK_SIZE - 1];
docBufferUpto = 0;
assert docBuffer[BLOCK_SIZE] == NO_MORE_DOCS;
assert docBuffer[docBufferSize] == NO_MORE_DOCS;
}

private void skipLevel1To(int target) throws IOException {
Expand Down Expand Up @@ -1669,7 +1671,7 @@ public int advance(int target) throws IOException {
needsRefilling = false;
}

int next = findFirstGreater(docBuffer, target, docBufferUpto);
int next = VECTOR_SUPPORT.findNextGEQ(docBuffer, docBufferSize, target, docBufferUpto);
posPendingCount += sumOverRange(freqBuffer, docBufferUpto, next + 1);
freq = (int) freqBuffer[next];
docBufferUpto = next + 1;
10 changes: 10 additions & 0 deletions lucene/core/src/java/org/apache/lucene/internal/vectorization/DefaultVectorUtilSupport.java
@@ -197,4 +197,14 @@ public int squareDistance(byte[] a, byte[] b) {
}
return squareSum;
}

@Override
public int findNextGEQ(long[] buffer, int length, long target, int from) {
for (int i = from; i < length; ++i) {
if (buffer[i] >= target) {
return i;
}
}
return length;
}
}
8 changes: 8 additions & 0 deletions lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorUtilSupport.java
@@ -44,4 +44,12 @@ public interface VectorUtilSupport {

/** Returns the sum of squared differences of the two byte vectors. */
int squareDistance(byte[] a, byte[] b);

/**
* Given an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code
* length} exclusive, find the first array index whose value is greater than or equal to {@code
* target}. This index is guaranteed to be at least {@code from}. If there is no such array index,
* {@code length} is returned.
*/
int findNextGEQ(long[] buffer, int length, long target, int from);
}
@@ -106,6 +106,10 @@ public int advance(int target) throws IOException {

@Override
public int nextDoc() throws IOException {
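// Fast path: within the current range, plain nextDoc() on the delegate is cheaper than advance(docID() + 1).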
DocIdSetIterator in = this.in;
if (in.docID() < upTo) {
return in.nextDoc();
}
return advance(in.docID() + 1);
}

10 changes: 10 additions & 0 deletions lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
@@ -307,4 +307,14 @@ public static float[] checkFinite(float[] v) {
}
return v;
}

/**
* Given an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code
* length} exclusive, find the first array index whose value is greater than or equal to {@code
* target}. This index is guaranteed to be at least {@code from}. If there is no such array index,
* {@code length} is returned.
*/
public static int findNextGEQ(long[] buffer, int length, long target, int from) {
return IMPL.findNextGEQ(buffer, length, target, from);
}
}
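As a usage note (an illustration of the contract above, not part of the commit):

```java
long[] buffer = {3, 5, 5, 9, Long.MAX_VALUE}; // sorted in [0, length)
VectorUtil.findNextGEQ(buffer, 4, 5, 0);  // -> 1: first index whose value is >= 5
VectorUtil.findNextGEQ(buffer, 4, 4, 2);  // -> 2: the result is never less than `from`
VectorUtil.findNextGEQ(buffer, 4, 10, 0); // -> 4: no match, so `length` is returned
```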