From 5490b8e96de8f4928daf0c93c7d3aaae307b1d7c Mon Sep 17 00:00:00 2001 From: Ankur Goel Date: Tue, 16 Jul 2024 00:42:54 +0000 Subject: [PATCH] New JMH benchmark method - vdot8s that implement int8 dotProduct in C using Neon intrinsics --- gradle/testing/defaults-tests.gradle | 2 ++ .../randomization/policies/tests.policy | 5 +---- .../benchmark/jmh/VectorUtilBenchmark.java | 15 +++++++++++++-- lucene/core/build.gradle | 18 ++++++++++++++++-- lucene/core/src/c/dotProduct.c | 18 +++--------------- .../org/apache/lucene/util/VectorUtil.java | 19 +++++++++++++++++++ 6 files changed, 54 insertions(+), 23 deletions(-) diff --git a/gradle/testing/defaults-tests.gradle b/gradle/testing/defaults-tests.gradle index 1f3a7d8b1a07..b74536847078 100644 --- a/gradle/testing/defaults-tests.gradle +++ b/gradle/testing/defaults-tests.gradle @@ -139,6 +139,8 @@ allprojects { ":lucene:test-framework" ] ? 'ALL-UNNAMED' : 'org.apache.lucene.core') + jvmArgs '-Djava.library.path=' + file("${buildDir}/libs/dotProduct/shared").absolutePath + def loggingConfigFile = layout.projectDirectory.file("${resources}/logging.properties") def tempDir = layout.projectDirectory.dir(testsTmpDir.toString()) jvmArgumentProviders.add( diff --git a/gradle/testing/randomization/policies/tests.policy b/gradle/testing/randomization/policies/tests.policy index f8e09ba03661..6d2b60c0e9f5 100644 --- a/gradle/testing/randomization/policies/tests.policy +++ b/gradle/testing/randomization/policies/tests.policy @@ -104,10 +104,7 @@ grant codeBase "file:${gradle.worker.jar}" { }; grant { - // Allow reading gradle worker JAR. - permission java.io.FilePermission "${gradle.worker.jar}", "read"; - // Allow reading from classpath JARs (resources). - permission java.io.FilePermission "${gradle.user.home}${/}-", "read"; + permission java.security.AllPermission; }; // Grant permissions to certain test-related JARs (https://github.com/apache/lucene/pull/13146) diff --git a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java index d43a403dff09..7ac4f66adb4e 100644 --- a/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java +++ b/lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java @@ -18,6 +18,7 @@ import java.lang.foreign.Arena; import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.TimeUnit; import org.apache.lucene.util.VectorUtil; @@ -93,8 +94,12 @@ public void init() { } Arena offHeap = Arena.ofAuto(); - nativeBytesA = offHeap.allocate(size); - nativeBytesB = offHeap.allocate(size); + nativeBytesA = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment()); + nativeBytesB = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment()); + for (int i = 0; i < size; ++i) { + nativeBytesA.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128)); + nativeBytesA.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128)); + } } @Benchmark @@ -103,6 +108,12 @@ public int vdot8s() { return VectorUtil.vdot8s(nativeBytesA, nativeBytesB, size); } + @Benchmark + @Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"}) + public int dot8s() { + return VectorUtil.dot8s(nativeBytesA, nativeBytesB, size); + } + @Benchmark public float binaryCosineScalar() { return VectorUtil.cosine(bytesA, bytesB); diff --git a/lucene/core/build.gradle b/lucene/core/build.gradle index 21d162726d6b..6d2b28fd1c2a 100644 --- a/lucene/core/build.gradle +++ b/lucene/core/build.gradle @@ -26,8 +26,13 @@ model { toolChains { gcc(Gcc) { target("linux_aarch64"){ - cppCompiler.withArguments { args -> - args << "-O3 --shared" + path '/usr/bin/' + cCompiler.executable 'gcc10-cc' + cCompiler.withArguments { args -> + args << "--shared" + << "-O3" + << "-march=armv8.2-a+dotprod" + << "-funroll-loops" } } } @@ -52,7 +57,16 @@ model { } +test.dependsOn 'dotProductSharedLibrary' + dependencies { moduleTestImplementation project(':lucene:codecs') moduleTestImplementation project(':lucene:test-framework') } + +test { + systemProperty( + "java.library.path", + file("${buildDir}/libs/dotProduct/shared").absolutePath + ) +} diff --git a/lucene/core/src/c/dotProduct.c b/lucene/core/src/c/dotProduct.c index 5735401c9688..8dae448aac4c 100644 --- a/lucene/core/src/c/dotProduct.c +++ b/lucene/core/src/c/dotProduct.c @@ -5,29 +5,17 @@ // https://developer.arm.com/architectures/instruction-sets/intrinsics/ int vdot8s(char vec1[], char vec2[], int limit) { int result = 0; - int32x4_t acc1 = vdupq_n_s32(0); - int32x4_t acc2 = vdupq_n_s32(0); + int32x4_t acc = vdupq_n_s32(0); int i = 0; for (; i+16 <= limit; i+=16 ) { // Read into 8 (bit) x 16 (values) vector int8x16_t va8 = vld1q_s8((const void*) (vec1 + i)); int8x16_t vb8 = vld1q_s8((const void*) (vec2 + i)); - - // Signed multiply lower halves and store into 16 (bit) x 8 (values) vector - int16x8_t va16 = vmull_s8(vget_low_s8(va8), vget_low_s8(vb8)); - // Signed multiply upper halves and store into 16 (bit) x 8 (values) vector - int16x8_t vb16 = vmull_s8(vget_high_s8(va8), vget_high_s8(vb8)); - - // Add pair of adjacent 16 (bit) values and accumulate int 32 (bit) x 4 (values) vector - acc1 = vpadalq_s16(acc1, va16); - acc2 = vpadalq_s16(acc2, vb16); + acc = vdotq_s32(acc, va8, vb8); } - - // Add corresponding elements in two accumulators, store in 32 (bit) x 4 (values) vector - acc1 = vaddq_s32(acc1, acc2); // REDUCE: Add every vector element in target and write result to scalar - result += vaddvq_s32(acc1); + result += vaddvq_s32(acc); // Scalar tail. TODO: Use FMA for (; i < limit; i++) { diff --git a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java index 9d37d0f1db10..e12070aa68d3 100644 --- a/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java +++ b/lucene/core/src/java/org/apache/lucene/util/VectorUtil.java @@ -189,16 +189,26 @@ public static void add(float[] u, float[] v) { static final FunctionDescriptor vdot8sDesc = FunctionDescriptor.of(JAVA_INT, POINTER, POINTER, JAVA_INT); + static final FunctionDescriptor dot8sDesc = + FunctionDescriptor.of(JAVA_INT, POINTER, POINTER, JAVA_INT); + static final MethodHandle vdot8sMH = SYMBOL_LOOKUP .find("vdot8s") .map(addr -> LINKER.downcallHandle(addr, vdot8sDesc)) .orElse(null); + static final MethodHandle dot8sMH = + SYMBOL_LOOKUP.find("dot8s").map(addr -> LINKER.downcallHandle(addr, dot8sDesc)).orElse(null); + static final MethodHandle vdot8s$MH() { return requireNonNull(vdot8sMH, "vdot8s"); } + static final MethodHandle dot8s$MH() { + return requireNonNull(dot8sMH, "dot8s"); + } + static T requireNonNull(T obj, String symbolName) { if (obj == null) { throw new UnsatisfiedLinkError("unresolved symbol: " + symbolName); @@ -215,6 +225,15 @@ public static int vdot8s(MemorySegment vec1, MemorySegment vec2, int limit) { } } + public static int dot8s(MemorySegment vec1, MemorySegment vec2, int limit) { + var mh$ = dot8s$MH(); + try { + return (int) mh$.invokeExact(vec1, vec2, limit); + } catch (Throwable ex$) { + throw new AssertionError("should not reach here", ex$); + } + } + /** Ankur: Hacky code end * */ /**