Skip to content

Commit

Permalink
Allocate offHeap memory in dotProduct(byte[], byte[]) for unit tests …
Browse files Browse the repository at this point in the history
…if native dot-product is enabled. Simplifyy JMH benchmark code that tests native dot product. Incorporate other review feedback
  • Loading branch information
Ankur Goel committed Nov 3, 2024
1 parent 0344720 commit 78b7c1f
Show file tree
Hide file tree
Showing 22 changed files with 389 additions and 331 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/run-checks-all.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ jobs:
matrix:
os: [ ubuntu-latest ]
java: [ '21' ]
compiler: [ gcc ]

runs-on: ${{ matrix.os }}

Expand All @@ -38,6 +39,8 @@ jobs:
- uses: ./.github/actions/prepare-for-build

- name: Run gradle check (without tests)
env:
CC: ${{ matrix.compiler }}
run: ./gradlew check -x test -Ptask.times=true --max-workers 2


Expand All @@ -53,6 +56,7 @@ jobs:
# macos-latest: a tad slower than ubuntu and pretty much the same (?) so leaving out.
os: [ ubuntu-latest ]
java: [ '21' ]
compiler: [ gcc ]

runs-on: ${{ matrix.os }}

Expand All @@ -61,6 +65,8 @@ jobs:
- uses: ./.github/actions/prepare-for-build

- name: Run gradle tests
env:
CC: ${{ matrix.compiler }}
run: ./gradlew test "-Ptask.times=true" --max-workers 2

- name: List automatically-initialized gradle.properties
Expand Down
3 changes: 3 additions & 0 deletions gradle/testing/randomization.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ allprojects {
[propName: 'tests.forceintegervectors',
value: { -> testsDefaultVectorizationRequested() ? false : (randomVectorSize != 'default') },
description: "Forces use of integer vectors even when slow."],
// test native dot-product when running with Java 21 or greater and 'default' vector size (chosen by randomized testing)
[propName: 'test.native.dotProduct',
value: { -> testsDefaultVectorizationRequested() ? false : (randomVectorSize == 'default' && rootProject.vectorIncubatorJavaVersions.contains(rootProject.runtimeJavaVersion))}],
[propName: 'tests.defaultvectorization', value: false,
description: "Uses defaults for running tests with correct JVM settings to test Panama vectorization (tests.jvmargs, tests.vectorsize, tests.forceintegervectors)."],
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
import java.lang.invoke.MethodType;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
import org.apache.lucene.internal.vectorization.VectorizationProvider;
import org.apache.lucene.util.VectorUtil;
import org.openjdk.jmh.annotations.*;

Expand All @@ -36,6 +38,81 @@
value = 3,
jvmArgsAppend = {"-Xmx2g", "-Xms2g", "-XX:+AlwaysPreTouch"})
public class VectorUtilBenchmark {

/**
* Used to get a MethodHandle of PanamaVectorUtilSupport.dotProduct(MemorySegment a, MemorySegment
* b). The method above will use a native C implementation of dotProduct if it is enabled via
* {@link org.apache.lucene.util.Constants#NATIVE_DOT_PRODUCT_ENABLED} AND both MemorySegment
* arguments are backed by off-heap memory. A reflection based approach is necessary to avoid
* taking a direct dependency on preview APIs in Panama which may be blocked at compile time.
*
* @return MethodHandle PanamaVectorUtilSupport.DotProduct(MemorySegment a, MemorySegment b)
*/
private static MethodHandle nativeDotProductHandle(String methodName) {
if (Runtime.version().feature() < 21) {
return null;
}
try {
final VectorUtilSupport vectorUtilSupport =
VectorizationProvider.getInstance().getVectorUtilSupport();
if (vectorUtilSupport.getClass().getName().endsWith("PanamaVectorUtilSupport")) {
MethodHandles.Lookup lookup = MethodHandles.lookup();
// A method type that computes dot-product between two off-heap vectors
// provided as native MemorySegment and returns an int score.
final var MemorySegment = "java.lang.foreign.MemorySegment";
final var methodType =
MethodType.methodType(
int.class, lookup.findClass(MemorySegment), lookup.findClass(MemorySegment));
var mh = lookup.findStatic(vectorUtilSupport.getClass(), methodName, methodType);
// Erase the type of receiver to Object so that mh.invokeExact(a, b) does not throw
// WrongMethodException.
// Here 'a' and 'b' are off-heap vectors of type MemorySegment constructed via reflection
// API.
// This minimizes the reflection overhead and brings us very close to the performance of
// direct method invocation.
mh = mh.asType(mh.type().changeParameterType(0, Object.class));
mh = mh.asType(mh.type().changeParameterType(1, Object.class));
return mh;
}
} catch (Throwable e) {
throw new RuntimeException(e);
}
return null;
}

/**
* Get randomly initialized byte-vectors of given size in off-heap MemorySegment
*
* @param size dimension of byte-vector
* @return Object MemorySegment
*/
private static Object getOffHeapByteVector(int size) {
try {
VectorizationProvider vectorizationProvider = VectorizationProvider.getInstance();
if (vectorizationProvider.getClass().getName().endsWith("PanamaVectorizationProvider")) {
MethodHandles.Lookup lookup = MethodHandles.lookup();
// A method type that accepts numBytes and returns an off-heap vector of size 'numBytes'
// where each byte is randomly initialized
final var methodType =
MethodType.methodType(lookup.findClass("java.lang.foreign.MemorySegment"), int.class);
// The class is expected to be "PanamaVectorUtilSupport" with a static method
// "MemorySegment offHeapByteVector(int numBytes)" that returns the off-heap vector as a
// MemorySegment
Class<?> vectorUtilSupportClass = vectorizationProvider.getVectorUtilSupport().getClass();
final MethodHandle offHeapByteVector =
lookup.findStatic(vectorUtilSupportClass, "offHeapByteVector", methodType);
return offHeapByteVector.invoke(size);
}
} catch (Throwable e) {
throw new RuntimeException(e);
}
return null;
}

private static final MethodHandle NATIVE_DOT_PRODUCT = nativeDotProductHandle("dotProduct");
private static final MethodHandle SIMPLE_NATIVE_DOT_PRODUCT =
nativeDotProductHandle("simpleNativeDotProduct");

static void compressBytes(byte[] raw, byte[] compressed) {
for (int i = 0; i < compressed.length; ++i) {
int v = (raw[i] << 4) | raw[compressed.length + i];
Expand All @@ -52,8 +129,8 @@ static void compressBytes(byte[] raw, byte[] compressed) {
private float[] floatsB;
private int expectedhalfByteDotProduct;

private Object nativeBytesA;
private Object nativeBytesB;
private Object offHeapBytesA;
private Object offHeapBytesB;

/** private Object nativeBytesA; private Object nativeBytesB; */
@Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
Expand Down Expand Up @@ -94,70 +171,26 @@ public void init() {
// Java 21+ specific initialization
final int runtimeVersion = Runtime.version().feature();
if (runtimeVersion >= 21) {
// Reflection based code to eliminate the use of Preview classes in JMH benchmarks
try {
final Class<?> vectorUtilSupportClass = VectorUtil.getVectorUtilSupportClass();
final var className = "org.apache.lucene.internal.vectorization.PanamaVectorUtilSupport";
if (vectorUtilSupportClass.getName().equals(className) == false) {
nativeBytesA = null;
nativeBytesB = null;
} else {
MethodHandles.Lookup lookup = MethodHandles.lookup();
final var MemorySegment = "java.lang.foreign.MemorySegment";
final var methodType =
MethodType.methodType(lookup.findClass(MemorySegment), byte[].class);
MethodHandle nativeMemorySegment =
lookup.findStatic(vectorUtilSupportClass, "nativeMemorySegment", methodType);
byte[] a = new byte[size];
byte[] b = new byte[size];
for (int i = 0; i < size; ++i) {
a[i] = (byte) random.nextInt(128);
b[i] = (byte) random.nextInt(128);
}
nativeBytesA = nativeMemorySegment.invoke(a);
nativeBytesB = nativeMemorySegment.invoke(b);
}
} catch (Throwable e) {
throw new RuntimeException(e);
}
/*
Arena offHeap = Arena.ofAuto();
nativeBytesA = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
nativeBytesB = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
for (int i = 0; i < size; ++i) {
nativeBytesA.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
nativeBytesB.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
}*/
offHeapBytesA = getOffHeapByteVector(size);
offHeapBytesB = getOffHeapByteVector(size);
}
}

@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int dot8s() {
try {
return (int) NATIVE_DOT_PRODUCT.invokeExact(offHeapBytesA, offHeapBytesB);
} catch (Throwable e) {
throw new RuntimeException(e);
}
}

/**
* High overhead (lower score) from using NATIVE_DOT_PRODUCT.invoke(nativeBytesA, nativeBytesB).
* Both nativeBytesA and nativeBytesB are offHeap MemorySegments created by invoking the method
* PanamaVectorUtilSupport.nativeMemorySegment(byte[]) which allocated these segments and copies
* bytes from the supplied byte[] to offHeap memory. The benchmark output below shows
* significantly more overhead. <b>NOTE:</b> Return type of dots8s() was set to void for the
* benchmark run to avoid boxing/unboxing overhead.
*
* <pre>
* Benchmark (size) Mode Cnt Score Error Units
* VectorUtilBenchmark.dot8s 768 thrpt 15 36.406 ± 0.496 ops/us
* </pre>
*
* Much lower overhead was observed when preview APIs were used directly in JMH benchmarking code
* and exact method invocation was made as shown below <b>return (int)
* VectorUtil.NATIVE_DOT_PRODUCT.invokeExact(nativeBytesA, nativeBytesB);</b>
*
* <pre>
* Benchmark (size) Mode Cnt Score Error Units
* VectorUtilBenchmark.dot8s 768 thrpt 15 43.662 ± 0.818 ops/us
* </pre>
*/
@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public void dot8s() {
public int simpleDot8s() {
try {
VectorUtil.NATIVE_DOT_PRODUCT.invoke(nativeBytesA, nativeBytesB);
return (int) SIMPLE_NATIVE_DOT_PRODUCT.invokeExact(offHeapBytesA, offHeapBytesB);
} catch (Throwable e) {
throw new RuntimeException(e);
}
Expand Down
4 changes: 2 additions & 2 deletions lucene/core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ dependencies {

test {
build {
dependsOn ':lucene:native:build'
dependsOn ':lucene:misc:build'
}
systemProperty(
"java.library.path",
project(":lucene:native").layout.buildDirectory.get().asFile.absolutePath + "/libs/dotProduct/shared"
project(":lucene:misc").layout.buildDirectory.get().asFile.absolutePath + "/libs/dotProduct/shared"
)
}
2 changes: 2 additions & 0 deletions lucene/core/src/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@

exports org.apache.lucene.util.quantization;
exports org.apache.lucene.codecs.hnsw;
exports org.apache.lucene.internal.vectorization to
org.apache.lucene.benchmark.jmh;

provides org.apache.lucene.analysis.TokenizerFactory with
org.apache.lucene.analysis.standard.StandardTokenizerFactory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ public float getScoreCorrectionConstant(int targetOrd) throws IOException {
}
slice.seek(((long) targetOrd * byteSize) + numBytes);
slice.readFloats(scoreCorrectionConstant, 0, 1);
lastOrd = targetOrd;
return scoreCorrectionConstant[0];
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ private static Optional<Module> lookupVectorModule() {
// add all possible callers here as FQCN:
private static final Set<String> VALID_CALLERS =
Set.of(
"org.apache.lucene.benchmark.jmh.VectorUtilBenchmark",
"org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil",
"org.apache.lucene.util.VectorUtil",
"org.apache.lucene.codecs.lucene101.Lucene101PostingsReader",
Expand Down
13 changes: 13 additions & 0 deletions lucene/core/src/java/org/apache/lucene/util/Constants.java
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,19 @@ private static boolean is64Bit() {
// "False")
public static final boolean NATIVE_DOT_PRODUCT_ENABLED = OS_ARCH.equalsIgnoreCase("aarch64");

public static final boolean TEST_NATIVE_DOT_PRODUCT;

static {
String v = System.getProperty("test.native.dotProduct", "false");
v = v.trim();
if (v.isEmpty() == false) {
TEST_NATIVE_DOT_PRODUCT = Boolean.parseBoolean(v);
} else {
throw new IllegalArgumentException(
"Boolean value expected for property - test.native.dotProduct");
}
}

/** true iff we know FMA has faster throughput than separate mul/add. */
public static final boolean HAS_FAST_SCALAR_FMA = hasFastScalarFMA();

Expand Down
34 changes: 0 additions & 34 deletions lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,6 @@

package org.apache.lucene.util;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
import org.apache.lucene.internal.vectorization.VectorizationProvider;

Expand Down Expand Up @@ -55,39 +52,8 @@ public final class VectorUtil {
public static final VectorUtilSupport IMPL =
VectorizationProvider.getInstance().getVectorUtilSupport();

// TODO: Harden this implementation and may be find a new home for this
private static MethodHandle nativeDotProduct() {
try {
final var PanamaVectorUtilSupport =
"org.apache.lucene.internal.vectorization.PanamaVectorUtilSupport";
if (!IMPL.getClass().getName().equals(PanamaVectorUtilSupport)) {
return null;
}
MethodHandles.Lookup lookup = MethodHandles.lookup();
final var MemorySegment = "java.lang.foreign.MemorySegment";
final var methodType =
MethodType.methodType(
int.class, lookup.findClass(MemorySegment), lookup.findClass(MemorySegment));
return lookup.findStatic(IMPL.getClass(), "nativeDotProduct", methodType);
} catch (Exception e) {
throw new RuntimeException(e);
}
}

// For use in JMH benchmark
public static final MethodHandle NATIVE_DOT_PRODUCT = nativeDotProduct();

private VectorUtil() {}

/*
Used in o.a.l.benchmark.jmh.VectorUtilBenchmark to create test vectors
in off-heap MemorySegments IF VectorUtilSupport instance supports
Panama APIs.
*/
public static Class<?> getVectorUtilSupportClass() {
return IMPL.getClass();
}

/**
* Returns the vector dot product of the two vectors.
*
Expand Down
Loading

0 comments on commit 78b7c1f

Please sign in to comment.