Skip to content

Commit

Permalink
Allocate offHeap memory in dotProduct(byte[], byte[]) for unit tests …
Browse files Browse the repository at this point in the history
…if native dot-product is enabled. Simplify JMH benchmark code that tests native dot product. Incorporate other review feedback
  • Loading branch information
Ankur Goel committed Oct 31, 2024
1 parent 4579dea commit 3079c5d
Show file tree
Hide file tree
Showing 13 changed files with 325 additions and 264 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.lang.invoke.MethodType;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
import org.apache.lucene.internal.vectorization.VectorizationProvider;
import org.apache.lucene.util.VectorUtil;
import org.openjdk.jmh.annotations.*;

Expand All @@ -36,6 +38,81 @@
value = 3,
jvmArgsAppend = {"-Xmx2g", "-Xms2g", "-XX:+AlwaysPreTouch"})
public class VectorUtilBenchmark {

/**
 * Resolves a {@link MethodHandle} for the static method {@code methodName} declared on the active
 * {@link VectorUtilSupport} implementation, looked up with the signature {@code int
 * (MemorySegment, MemorySegment)}.
 *
 * <p>The target is expected to be a method of {@code PanamaVectorUtilSupport}; per that class's
 * contract it uses a native C dot-product when {@link
 * org.apache.lucene.util.Constants#NATIVE_DOT_PRODUCT_ENABLED} is set AND both segments are backed
 * by off-heap memory. Everything is resolved reflectively so that this benchmark does not take a
 * compile-time dependency on Panama preview APIs.
 *
 * @param methodName name of the static method to resolve on the support class
 * @return a handle whose type is {@code (Object, Object)int}, or {@code null} when the runtime is
 *     older than Java 21 or the active support class name does not end with {@code
 *     PanamaVectorUtilSupport}
 * @throws RuntimeException wrapping any failure raised while resolving the handle
 */
private static MethodHandle nativeDotProductHandle(String methodName) {
  final boolean hasForeignMemoryApi = Runtime.version().feature() >= 21;
  if (hasForeignMemoryApi == false) {
    return null;
  }
  try {
    final Class<?> supportClass =
        VectorizationProvider.getInstance().getVectorUtilSupport().getClass();
    if (supportClass.getName().endsWith("PanamaVectorUtilSupport") == false) {
      return null;
    }
    final MethodHandles.Lookup lookup = MethodHandles.lookup();
    // Signature of the target: dot-product of two MemorySegment vectors, returning an int score.
    // MemorySegment is loaded by name to keep preview types out of this class's constant pool.
    final Class<?> segmentClass = lookup.findClass("java.lang.foreign.MemorySegment");
    final MethodType segmentSignature =
        MethodType.methodType(int.class, segmentClass, segmentClass);
    final MethodHandle target = lookup.findStatic(supportClass, methodName, segmentSignature);
    // Erase both parameter types to Object so that target.invokeExact(a, b), where 'a' and 'b'
    // are held in Object-typed fields, does not throw WrongMethodTypeException. Adapting the
    // handle once here keeps the per-call overhead close to a direct method invocation.
    final MethodType erasedSignature =
        target
            .type()
            .changeParameterType(0, Object.class)
            .changeParameterType(1, Object.class);
    return target.asType(erasedSignature);
  } catch (Throwable t) {
    throw new RuntimeException(t);
  }
}

/**
 * Obtains a byte vector of the requested dimension held in an off-heap {@code MemorySegment} by
 * reflectively calling the static method {@code offHeapByteVector(int)} on the active
 * VectorUtilSupport implementation class. That helper is expected to return a segment whose bytes
 * are randomly initialized.
 *
 * @param size dimension of the byte vector, passed through as the number of bytes to allocate
 * @return the {@code MemorySegment} typed as {@code Object}, or {@code null} when the active
 *     provider's class name does not end with {@code PanamaVectorizationProvider}
 * @throws RuntimeException wrapping any failure raised during lookup or invocation
 */
private static Object getOffHeapByteVector(int size) {
  try {
    final VectorizationProvider provider = VectorizationProvider.getInstance();
    final String providerName = provider.getClass().getName();
    if (providerName.endsWith("PanamaVectorizationProvider") == false) {
      return null;
    }
    final MethodHandles.Lookup lookup = MethodHandles.lookup();
    // Shape of the allocator: takes numBytes, returns a MemorySegment of that many bytes.
    final Class<?> segmentClass = lookup.findClass("java.lang.foreign.MemorySegment");
    final MethodType allocatorSignature = MethodType.methodType(segmentClass, int.class);
    // The support class is expected to be "PanamaVectorUtilSupport", exposing a static
    // "MemorySegment offHeapByteVector(int numBytes)".
    final Class<?> supportClass = provider.getVectorUtilSupport().getClass();
    final MethodHandle allocator =
        lookup.findStatic(supportClass, "offHeapByteVector", allocatorSignature);
    return allocator.invoke(size);
  } catch (Throwable t) {
    throw new RuntimeException(t);
  }
}

// Handles resolved once at class-initialization time via nativeDotProductHandle(String).
// Either may be null: the resolver returns null on runtimes older than Java 21 or when the
// active VectorUtilSupport implementation is not PanamaVectorUtilSupport.
private static final MethodHandle NATIVE_DOT_PRODUCT = nativeDotProductHandle("dotProduct");
// Same lookup as above, but targeting the static method named "simpleNativeDotProduct".
private static final MethodHandle SIMPLE_NATIVE_DOT_PRODUCT =
nativeDotProductHandle("simpleNativeDotProduct");

static void compressBytes(byte[] raw, byte[] compressed) {
for (int i = 0; i < compressed.length; ++i) {
int v = (raw[i] << 4) | raw[compressed.length + i];
Expand All @@ -52,8 +129,8 @@ static void compressBytes(byte[] raw, byte[] compressed) {
private float[] floatsB;
private int expectedhalfByteDotProduct;

private Object nativeBytesA;
private Object nativeBytesB;
private Object offHeapBytesA;
private Object offHeapBytesB;

/** private Object nativeBytesA; private Object nativeBytesB; */
@Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
Expand Down Expand Up @@ -94,70 +171,26 @@ public void init() {
// Java 21+ specific initialization
final int runtimeVersion = Runtime.version().feature();
if (runtimeVersion >= 21) {
// Reflection based code to eliminate the use of Preview classes in JMH benchmarks
try {
final Class<?> vectorUtilSupportClass = VectorUtil.getVectorUtilSupportClass();
final var className = "org.apache.lucene.internal.vectorization.PanamaVectorUtilSupport";
if (vectorUtilSupportClass.getName().equals(className) == false) {
nativeBytesA = null;
nativeBytesB = null;
} else {
MethodHandles.Lookup lookup = MethodHandles.lookup();
final var MemorySegment = "java.lang.foreign.MemorySegment";
final var methodType =
MethodType.methodType(lookup.findClass(MemorySegment), byte[].class);
MethodHandle nativeMemorySegment =
lookup.findStatic(vectorUtilSupportClass, "nativeMemorySegment", methodType);
byte[] a = new byte[size];
byte[] b = new byte[size];
for (int i = 0; i < size; ++i) {
a[i] = (byte) random.nextInt(128);
b[i] = (byte) random.nextInt(128);
}
nativeBytesA = nativeMemorySegment.invoke(a);
nativeBytesB = nativeMemorySegment.invoke(b);
}
} catch (Throwable e) {
throw new RuntimeException(e);
}
/*
Arena offHeap = Arena.ofAuto();
nativeBytesA = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
nativeBytesB = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
for (int i = 0; i < size; ++i) {
nativeBytesA.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
nativeBytesB.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
}*/
offHeapBytesA = getOffHeapByteVector(size);
offHeapBytesB = getOffHeapByteVector(size);
}
}

/**
 * Benchmarks the {@code dotProduct} method handle held in {@code NATIVE_DOT_PRODUCT} on the two
 * off-heap byte vectors prepared during setup.
 *
 * @return the int score produced by the handle
 * @throws RuntimeException wrapping anything thrown by the invocation, including the
 *     NullPointerException raised when {@code NATIVE_DOT_PRODUCT} is {@code null}
 */
@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int dot8s() {
try {
// invokeExact is signature-polymorphic: the (int) cast and the Object-typed arguments form the
// call-site type (Object, Object)int, which must match the handle's adapted type exactly.
return (int) NATIVE_DOT_PRODUCT.invokeExact(offHeapBytesA, offHeapBytesB);
} catch (Throwable e) {
throw new RuntimeException(e);
}
}

/**
* High overhead (lower score) from using NATIVE_DOT_PRODUCT.invoke(nativeBytesA, nativeBytesB).
* Both nativeBytesA and nativeBytesB are offHeap MemorySegments created by invoking the method
* PanamaVectorUtilSupport.nativeMemorySegment(byte[]) which allocated these segments and copies
* bytes from the supplied byte[] to offHeap memory. The benchmark output below shows
* significantly more overhead. <b>NOTE:</b> Return type of dots8s() was set to void for the
* benchmark run to avoid boxing/unboxing overhead.
*
* <pre>
* Benchmark (size) Mode Cnt Score Error Units
* VectorUtilBenchmark.dot8s 768 thrpt 15 36.406 ± 0.496 ops/us
* </pre>
*
* Much lower overhead was observed when preview APIs were used directly in JMH benchmarking code
* and exact method invocation was made as shown below <b>return (int)
* VectorUtil.NATIVE_DOT_PRODUCT.invokeExact(nativeBytesA, nativeBytesB);</b>
*
* <pre>
* Benchmark (size) Mode Cnt Score Error Units
* VectorUtilBenchmark.dot8s 768 thrpt 15 43.662 ± 0.818 ops/us
* </pre>
*/
@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public void dot8s() {
public int simpleDot8s() {
try {
VectorUtil.NATIVE_DOT_PRODUCT.invoke(nativeBytesA, nativeBytesB);
return (int) SIMPLE_NATIVE_DOT_PRODUCT.invokeExact(offHeapBytesA, offHeapBytesB);
} catch (Throwable e) {
throw new RuntimeException(e);
}
Expand Down
2 changes: 2 additions & 0 deletions lucene/core/src/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@

exports org.apache.lucene.util.quantization;
exports org.apache.lucene.codecs.hnsw;
exports org.apache.lucene.internal.vectorization to
org.apache.lucene.benchmark.jmh;

provides org.apache.lucene.analysis.TokenizerFactory with
org.apache.lucene.analysis.standard.StandardTokenizerFactory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ public float getScoreCorrectionConstant(int targetOrd) throws IOException {
}
slice.seek(((long) targetOrd * byteSize) + numBytes);
slice.readFloats(scoreCorrectionConstant, 0, 1);
lastOrd = targetOrd;
return scoreCorrectionConstant[0];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ private static Optional<Module> lookupVectorModule() {
// add all possible callers here as FQCN:
private static final Set<String> VALID_CALLERS =
Set.of(
"org.apache.lucene.benchmark.jmh.VectorUtilBenchmark",
"org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil",
"org.apache.lucene.util.VectorUtil",
"org.apache.lucene.codecs.lucene912.Lucene912PostingsReader",
Expand Down
13 changes: 13 additions & 0 deletions lucene/core/src/java/org/apache/lucene/util/Constants.java
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,19 @@ private static boolean is64Bit() {
// "False")
public static final boolean NATIVE_DOT_PRODUCT_ENABLED = OS_ARCH.equalsIgnoreCase("aarch64");

/**
 * Value of the system property {@code test.native.dotProduct}, parsed with {@link
 * Boolean#parseBoolean(String)} after trimming; {@code false} when the property is not set.
 * A property value that is empty after trimming makes class initialization fail with an {@link
 * IllegalArgumentException}.
 */
public static final boolean TEST_NATIVE_DOT_PRODUCT;

static {
  final String configured = System.getProperty("test.native.dotProduct", "false").trim();
  // Reject a blank value up front; any non-blank text is handed to parseBoolean.
  if (configured.isEmpty()) {
    throw new IllegalArgumentException(
        "Boolean value expected for property - test.native.dotProduct");
  }
  TEST_NATIVE_DOT_PRODUCT = Boolean.parseBoolean(configured);
}

/** true iff we know FMA has faster throughput than separate mul/add. */
public static final boolean HAS_FAST_SCALAR_FMA = hasFastScalarFMA();

Expand Down
34 changes: 0 additions & 34 deletions lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@

package org.apache.lucene.util;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
import org.apache.lucene.internal.vectorization.VectorizationProvider;

Expand Down Expand Up @@ -55,39 +52,8 @@ public final class VectorUtil {
public static final VectorUtilSupport IMPL =
VectorizationProvider.getInstance().getVectorUtilSupport();

// TODO: Harden this implementation and maybe find a new home for this
/**
 * Looks up the static method {@code nativeDotProduct(MemorySegment, MemorySegment)} returning
 * {@code int} on the class of {@code IMPL}.
 *
 * @return the resolved handle, or {@code null} when the class of {@code IMPL} is not {@code
 *     org.apache.lucene.internal.vectorization.PanamaVectorUtilSupport}
 * @throws RuntimeException wrapping any exception raised during the lookup
 */
private static MethodHandle nativeDotProduct() {
  try {
    final Class<?> implClass = IMPL.getClass();
    final boolean isPanamaImpl =
        implClass
            .getName()
            .equals("org.apache.lucene.internal.vectorization.PanamaVectorUtilSupport");
    if (isPanamaImpl == false) {
      return null;
    }
    final MethodHandles.Lookup lookup = MethodHandles.lookup();
    // MemorySegment is resolved by name rather than referenced directly.
    final Class<?> segmentClass = lookup.findClass("java.lang.foreign.MemorySegment");
    final MethodType signature = MethodType.methodType(int.class, segmentClass, segmentClass);
    return lookup.findStatic(implClass, "nativeDotProduct", signature);
  } catch (Exception ex) {
    throw new RuntimeException(ex);
  }
}

// For use in JMH benchmark
public static final MethodHandle NATIVE_DOT_PRODUCT = nativeDotProduct();

private VectorUtil() {}

/*
Returns the runtime class of the VectorUtilSupport instance held in IMPL.
Used in o.a.l.benchmark.jmh.VectorUtilBenchmark to decide whether test vectors
can be created in off-heap MemorySegments, i.e. whether the VectorUtilSupport
instance supports Panama APIs.
*/
public static Class<?> getVectorUtilSupportClass() {
return IMPL.getClass();
}

/**
* Returns the vector dot product of the two vectors.
*
Expand Down
Loading

0 comments on commit 3079c5d

Please sign in to comment.