Skip to content

Commit

Permalink
New JMH benchmark method - vdot8s that implement int8 dotProduct in C…
Browse files Browse the repository at this point in the history
… using Neon intrinsics
  • Loading branch information
Ankur Goel committed Jul 18, 2024
1 parent dea5f28 commit 5490b8e
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 23 deletions.
2 changes: 2 additions & 0 deletions gradle/testing/defaults-tests.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ allprojects {
":lucene:test-framework"
] ? 'ALL-UNNAMED' : 'org.apache.lucene.core')

jvmArgs '-Djava.library.path=' + file("${buildDir}/libs/dotProduct/shared").absolutePath

def loggingConfigFile = layout.projectDirectory.file("${resources}/logging.properties")
def tempDir = layout.projectDirectory.dir(testsTmpDir.toString())
jvmArgumentProviders.add(
Expand Down
5 changes: 1 addition & 4 deletions gradle/testing/randomization/policies/tests.policy
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,7 @@ grant codeBase "file:${gradle.worker.jar}" {
};

grant {
// Allow reading gradle worker JAR.
permission java.io.FilePermission "${gradle.worker.jar}", "read";
// Allow reading from classpath JARs (resources).
permission java.io.FilePermission "${gradle.user.home}${/}-", "read";
permission java.security.AllPermission;
};

// Grant permissions to certain test-related JARs (https://github.com/apache/lucene/pull/13146)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.util.VectorUtil;
Expand Down Expand Up @@ -93,8 +94,12 @@ public void init() {
}

Arena offHeap = Arena.ofAuto();
nativeBytesA = offHeap.allocate(size);
nativeBytesB = offHeap.allocate(size);
nativeBytesA = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
nativeBytesB = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
for (int i = 0; i < size; ++i) {
nativeBytesA.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
nativeBytesA.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
}
}

@Benchmark
Expand All @@ -103,6 +108,12 @@ public int vdot8s() {
return VectorUtil.vdot8s(nativeBytesA, nativeBytesB, size);
}

@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int dot8s() {
return VectorUtil.dot8s(nativeBytesA, nativeBytesB, size);
}

@Benchmark
public float binaryCosineScalar() {
return VectorUtil.cosine(bytesA, bytesB);
Expand Down
18 changes: 16 additions & 2 deletions lucene/core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,13 @@ model {
toolChains {
gcc(Gcc) {
target("linux_aarch64"){
cppCompiler.withArguments { args ->
args << "-O3 --shared"
path '/usr/bin/'
cCompiler.executable 'gcc10-cc'
cCompiler.withArguments { args ->
args << "--shared"
<< "-O3"
<< "-march=armv8.2-a+dotprod"
<< "-funroll-loops"
}
}
}
Expand All @@ -52,7 +57,16 @@ model {

}

test.dependsOn 'dotProductSharedLibrary'

dependencies {
moduleTestImplementation project(':lucene:codecs')
moduleTestImplementation project(':lucene:test-framework')
}

test {
systemProperty(
"java.library.path",
file("${buildDir}/libs/dotProduct/shared").absolutePath
)
}
18 changes: 3 additions & 15 deletions lucene/core/src/c/dotProduct.c
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,17 @@
// https://developer.arm.com/architectures/instruction-sets/intrinsics/
int vdot8s(char vec1[], char vec2[], int limit) {
int result = 0;
int32x4_t acc1 = vdupq_n_s32(0);
int32x4_t acc2 = vdupq_n_s32(0);
int32x4_t acc = vdupq_n_s32(0);
int i = 0;

for (; i+16 <= limit; i+=16 ) {
// Read into 8 (bit) x 16 (values) vector
int8x16_t va8 = vld1q_s8((const void*) (vec1 + i));
int8x16_t vb8 = vld1q_s8((const void*) (vec2 + i));

// Signed multiply lower halves and store into 16 (bit) x 8 (values) vector
int16x8_t va16 = vmull_s8(vget_low_s8(va8), vget_low_s8(vb8));
// Signed multiply upper halves and store into 16 (bit) x 8 (values) vector
int16x8_t vb16 = vmull_s8(vget_high_s8(va8), vget_high_s8(vb8));

// Add pair of adjacent 16 (bit) values and accumulate int 32 (bit) x 4 (values) vector
acc1 = vpadalq_s16(acc1, va16);
acc2 = vpadalq_s16(acc2, vb16);
acc = vdotq_s32(acc, va8, vb8);
}

// Add corresponding elements in two accumulators, store in 32 (bit) x 4 (values) vector
acc1 = vaddq_s32(acc1, acc2);
// REDUCE: Add every vector element in target and write result to scalar
result += vaddvq_s32(acc1);
result += vaddvq_s32(acc);

// Scalar tail. TODO: Use FMA
for (; i < limit; i++) {
Expand Down
19 changes: 19 additions & 0 deletions lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -189,16 +189,26 @@ public static void add(float[] u, float[] v) {
static final FunctionDescriptor vdot8sDesc =
FunctionDescriptor.of(JAVA_INT, POINTER, POINTER, JAVA_INT);

static final FunctionDescriptor dot8sDesc =
FunctionDescriptor.of(JAVA_INT, POINTER, POINTER, JAVA_INT);

static final MethodHandle vdot8sMH =
SYMBOL_LOOKUP
.find("vdot8s")
.map(addr -> LINKER.downcallHandle(addr, vdot8sDesc))
.orElse(null);

static final MethodHandle dot8sMH =
SYMBOL_LOOKUP.find("dot8s").map(addr -> LINKER.downcallHandle(addr, dot8sDesc)).orElse(null);

static final MethodHandle vdot8s$MH() {
return requireNonNull(vdot8sMH, "vdot8s");
}

static final MethodHandle dot8s$MH() {
return requireNonNull(dot8sMH, "dot8s");
}

static <T> T requireNonNull(T obj, String symbolName) {
if (obj == null) {
throw new UnsatisfiedLinkError("unresolved symbol: " + symbolName);
Expand All @@ -215,6 +225,15 @@ public static int vdot8s(MemorySegment vec1, MemorySegment vec2, int limit) {
}
}

public static int dot8s(MemorySegment vec1, MemorySegment vec2, int limit) {
var mh$ = dot8s$MH();
try {
return (int) mh$.invokeExact(vec1, vec2, limit);
} catch (Throwable ex$) {
throw new AssertionError("should not reach here", ex$);
}
}

/** Ankur: Hacky code end * */

/**
Expand Down

0 comments on commit 5490b8e

Please sign in to comment.