Skip to content

Commit

Permalink
New JMH benchmark method - vdot8s that implement int8 dotProduct in C…
Browse files Browse the repository at this point in the history
… using Neon intrinsics
  • Loading branch information
Ankur Goel committed Jul 18, 2024
1 parent 22ca695 commit 356d194
Show file tree
Hide file tree
Showing 10 changed files with 193 additions and 8 deletions.
2 changes: 2 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.time.format.DateTimeFormatter
plugins {
id "base"
id "lucene.build-infra"
id "c"

alias(deps.plugins.dependencychecks)
alias(deps.plugins.spotless) apply false
Expand All @@ -34,6 +35,7 @@ plugins {
alias(deps.plugins.jacocolog) apply false
}


apply from: file('gradle/globals.gradle')

// General metadata.
Expand Down
10 changes: 8 additions & 2 deletions gradle/java/javac.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ allprojects { project ->

// Use 'release' flag instead of 'source' and 'target'
tasks.withType(JavaCompile) {
options.compilerArgs += ["--release", rootProject.minJavaVersion.toString()]
options.compilerArgs += ["--release", rootProject.minJavaVersion.toString(), "--enable-preview"]
}

tasks.withType(Test) {
jvmArgs += "--enable-preview"
}

// Configure warnings.
Expand Down Expand Up @@ -72,17 +76,19 @@ allprojects { project ->
"-Xdoclint:-accessibility"
]

if (project.path == ":lucene:benchmark-jmh") {
if (project.path == ":lucene:benchmark-jmh" ) {
// JMH benchmarks use JMH preprocessor and incubating modules.
} else {
// proc:none was added because of LOG4J2-1925 / JDK-8186647
options.compilerArgs += [
"-proc:none"
]

/**
if (propertyOrDefault("javac.failOnWarnings", true).toBoolean()) {
options.compilerArgs += "-Werror"
}
*/
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions gradle/testing/defaults-tests.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ allprojects {
":lucene:test-framework"
] ? 'ALL-UNNAMED' : 'org.apache.lucene.core')

jvmArgs '-Djava.library.path=' + file("${buildDir}/libs/dotProduct/shared").absolutePath

def loggingConfigFile = layout.projectDirectory.file("${resources}/logging.properties")
def tempDir = layout.projectDirectory.dir(testsTmpDir.toString())
jvmArgumentProviders.add(
Expand Down
5 changes: 1 addition & 4 deletions gradle/testing/randomization/policies/tests.policy
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,7 @@ grant codeBase "file:${gradle.worker.jar}" {
};

grant {
// Allow reading gradle worker JAR.
permission java.io.FilePermission "${gradle.worker.jar}", "read";
// Allow reading from classpath JARs (resources).
permission java.io.FilePermission "${gradle.user.home}${/}-", "read";
permission java.security.AllPermission;
};

// Grant permissions to certain test-related JARs (https://github.com/apache/lucene/pull/13146)
Expand Down
1 change: 0 additions & 1 deletion lucene/benchmark-jmh/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ tasks.matching { it.name == "forbiddenApisMain" }.configureEach {
])
}


// Skip certain infrastructure tasks that we can't use or don't care about.
tasks.matching { it.name in [
// Turn off JMH dependency checksums and licensing (it's GPL w/ classpath exception
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
*/
package org.apache.lucene.benchmark.jmh;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.util.VectorUtil;
Expand Down Expand Up @@ -49,7 +52,12 @@ static void compressBytes(byte[] raw, byte[] compressed) {
private float[] floatsB;
private int expectedhalfByteDotProduct;

@Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
private MemorySegment nativeBytesA;

private MemorySegment nativeBytesB;

// @Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
@Param({"768"})
int size;

@Setup(Level.Iteration)
Expand Down Expand Up @@ -84,6 +92,26 @@ public void init() {
floatsA[i] = random.nextFloat();
floatsB[i] = random.nextFloat();
}

// Off-heap copies of two random byte vectors for the native (FFI) dot-product benchmarks.
// The arena is automatic, so the segments stay valid for as long as the fields reference them.
Arena offHeap = Arena.ofAuto();
nativeBytesA = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
nativeBytesB = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
for (int i = 0; i < size; ++i) {
  nativeBytesA.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
  // Fill the second operand as well: writing nativeBytesA twice would leave nativeBytesB
  // zero-initialized and every native dot product would trivially be 0.
  nativeBytesB.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
}
}

/**
 * Benchmarks {@code VectorUtil.vdot8s} on the off-heap segments {@code nativeBytesA} and
 * {@code nativeBytesB}, passing {@code size} as the element count.
 *
 * @return the value returned by the native call, handed back to JMH
 */
@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int vdot8s() {
return VectorUtil.vdot8s(nativeBytesA, nativeBytesB, size);
}

/**
 * Benchmarks {@code VectorUtil.dot8s} on the off-heap segments {@code nativeBytesA} and
 * {@code nativeBytesB}, passing {@code size} as the element count.
 *
 * @return the value returned by the native call, handed back to JMH
 */
@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int dot8s() {
return VectorUtil.dot8s(nativeBytesA, nativeBytesB, size);
}

@Benchmark
Expand Down
47 changes: 47 additions & 0 deletions lucene/core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,59 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
plugins {
id "c"
}

apply plugin: 'java-library'
apply plugin: 'c'

description = 'Lucene core library'
// Native build model: compiles the C sources under src/c into the "dotProduct" native library.
model {
toolChains {
gcc(Gcc) {
// Only a linux_aarch64 target is configured here.
target("linux_aarch64"){
// NOTE(review): compiler location and executable name are hard-coded for one
// particular host setup; other machines will likely need different values.
path '/usr/bin/'
cCompiler.executable 'gcc10-cc'
cCompiler.withArguments { args ->
// -march=armv8.2-a+dotprod turns on the dot-product extension used by the
// NEON intrinsics in src/c; -O3 and -funroll-loops are optimization flags.
args << "--shared"
<< "-O3"
<< "-march=armv8.2-a+dotprod"
<< "-funroll-loops"
}
}
}
}

components {
// The component name determines the library name ("dotProduct") and the
// generated 'dotProductSharedLibrary' task that 'test' depends on.
dotProduct(NativeLibrarySpec) {
sources {
c {
source {
srcDir 'src/c' // Directory holding the C sources of the library
include "**/*.c"
}
exportedHeaders {
// Headers live next to the sources.
srcDir "src/c"
include "**/*.h"
}
}
}
}
}

}

test.dependsOn 'dotProductSharedLibrary'

dependencies {
moduleTestImplementation project(':lucene:codecs')
moduleTestImplementation project(':lucene:test-framework')
}

// Point the test JVM's java.library.path at the directory where the shared
// "dotProduct" library is built, so System.loadLibrary("dotProduct") can find it.
test {
systemProperty(
"java.library.path",
file("${buildDir}/libs/dotProduct/shared").absolutePath
)
}
33 changes: 33 additions & 0 deletions lucene/core/src/c/dotProduct.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
// dotProduct.c
#include <arm_neon.h>
#include <stdio.h>

// https://developer.arm.com/architectures/instruction-sets/intrinsics/
/*
 * Dot product of the first `limit` elements of vec1 and vec2, interpreting every
 * element as a signed 8-bit integer.
 *
 * Blocks of 16 elements are processed with the NEON dot-product instruction
 * (vdotq_s32); the remaining (limit % 16) elements are handled by a scalar loop.
 *
 * vec1, vec2: buffers holding at least `limit` bytes each.
 * limit:      number of elements to process; values <= 0 yield 0.
 * Returns the 32-bit signed sum of products.
 */
int vdot8s(char vec1[], char vec2[], int limit) {
    /*
     * Plain `char` is unsigned on AArch64 Linux, so reading the operands through
     * `char` in the scalar tail would treat negative int8 values as 128..255 and
     * disagree with the signed vector path. Always read them as signed bytes.
     */
    const signed char *a = (const signed char *) vec1;
    const signed char *b = (const signed char *) vec2;
    int result = 0;
    int32x4_t acc = vdupq_n_s32(0);
    int i = 0;

    /* `limit - i >= 16` is the overflow-free form of `i + 16 <= limit`. */
    for (; limit - i >= 16; i += 16) {
        /* Load 16 signed 8-bit values from each operand. */
        int8x16_t va8 = vld1q_s8(a + i);
        int8x16_t vb8 = vld1q_s8(b + i);
        /* Each 32-bit lane accumulates the dot product of 4 adjacent byte pairs. */
        acc = vdotq_s32(acc, va8, vb8);
    }
    /* REDUCE: add the four 32-bit lanes into a scalar. */
    result += vaddvq_s32(acc);

    /* Scalar tail. TODO: Use FMA */
    for (; i < limit; i++) {
        result += a[i] * b[i];
    }
    return result;
}

/*
 * Scalar reference dot product of the first `limit` elements of vec1 and vec2,
 * interpreting every element as a signed 8-bit integer.
 *
 * vec1, vec2: buffers holding at least `limit` bytes each.
 * limit:      number of elements to process; values <= 0 yield 0.
 * Returns the 32-bit signed sum of products.
 */
int dot8s(char vec1[], char vec2[], int limit) {
    /*
     * Plain `char` is unsigned on AArch64 Linux; read the operands as signed bytes
     * so negative values are not misinterpreted as 128..255.
     */
    const signed char *a = (const signed char *) vec1;
    const signed char *b = (const signed char *) vec2;
    int result = 0;
    for (int i = 0; i < limit; i++) {
        result += a[i] * b[i];
    }
    return result;
}
3 changes: 3 additions & 0 deletions lucene/core/src/c/dotProduct.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@

#ifndef LUCENE_DOT_PRODUCT_H
#define LUCENE_DOT_PRODUCT_H

/*
 * Dot product of the first `limit` byte elements of vec1 and vec2.
 * NEON implementation: 16 elements per step plus a scalar tail.
 */
int vdot8s(char vec1[], char vec2[], int limit);

/*
 * Dot product of the first `limit` byte elements of vec1 and vec2.
 * Plain scalar loop, used as the baseline for vdot8s.
 */
int dot8s(char vec1[], char vec2[], int limit);

#endif /* LUCENE_DOT_PRODUCT_H */
68 changes: 68 additions & 0 deletions lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@

package org.apache.lucene.util;

import static java.lang.foreign.ValueLayout.JAVA_BYTE;
import static java.lang.foreign.ValueLayout.JAVA_INT;

import java.lang.foreign.*;
import java.lang.invoke.MethodHandle;
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
import org.apache.lucene.internal.vectorization.VectorizationProvider;

Expand Down Expand Up @@ -168,6 +173,69 @@ public static void add(float[] u, float[] v) {
}
}

// Ankur: Hacky code start

/**
 * Address layout used for the pointer arguments of the native dot-product functions; its
 * target layout is a sequence of {@code JAVA_BYTE} elements.
 */
public static final AddressLayout POINTER =
ValueLayout.ADDRESS.withTargetLayout(MemoryLayout.sequenceLayout(JAVA_BYTE));

/** Linker for the native ABI of the current platform. */
private static final Linker LINKER = Linker.nativeLinker();

/**
 * Resolves native symbols: libraries loaded by this class' loader are searched first, then the
 * linker's default lookup.
 */
private static final SymbolLookup SYMBOL_LOOKUP;

static {
  try {
    System.loadLibrary("dotProduct");
  } catch (UnsatisfiedLinkError ignored) {
    // Deliberately ignored: an Error escaping this static initializer would make the whole class
    // unusable on every machine that lacks the optional native library. Without the library the
    // symbol lookups below simply come back empty, and the native entry points report
    // "unresolved symbol: ..." when (and only when) they are actually called.
  }
  SymbolLookup loaderLookup = SymbolLookup.loaderLookup();
  SYMBOL_LOOKUP = name -> loaderLookup.find(name).or(() -> LINKER.defaultLookup().find(name));
}

/** Native signature of {@code vdot8s}: two pointer arguments and an int, returning an int. */
static final FunctionDescriptor vdot8sDesc =
FunctionDescriptor.of(JAVA_INT, POINTER, POINTER, JAVA_INT);

/** Native signature of {@code dot8s}: two pointer arguments and an int, returning an int. */
static final FunctionDescriptor dot8sDesc =
FunctionDescriptor.of(JAVA_INT, POINTER, POINTER, JAVA_INT);

/** Downcall handle for the native {@code vdot8s}, or {@code null} if the symbol is not found. */
static final MethodHandle vdot8sMH =
SYMBOL_LOOKUP
.find("vdot8s")
.map(addr -> LINKER.downcallHandle(addr, vdot8sDesc))
.orElse(null);

/** Downcall handle for the native {@code dot8s}, or {@code null} if the symbol is not found. */
static final MethodHandle dot8sMH =
SYMBOL_LOOKUP.find("dot8s").map(addr -> LINKER.downcallHandle(addr, dot8sDesc)).orElse(null);

/**
 * Returns the downcall handle for the native {@code vdot8s}.
 *
 * @throws UnsatisfiedLinkError if the symbol was not resolved
 */
static final MethodHandle vdot8s$MH() {
return requireNonNull(vdot8sMH, "vdot8s");
}

/**
 * Returns the downcall handle for the native {@code dot8s}.
 *
 * @throws UnsatisfiedLinkError if the symbol was not resolved
 */
static final MethodHandle dot8s$MH() {
return requireNonNull(dot8sMH, "dot8s");
}

/**
 * Returns {@code obj} unchanged when it is non-null; otherwise reports the named native symbol
 * as unresolved.
 *
 * @param obj the resolved value to check
 * @param symbolName name of the native symbol, used in the error message
 * @return {@code obj}, never {@code null}
 * @throws UnsatisfiedLinkError if {@code obj} is {@code null}
 */
static <T> T requireNonNull(T obj, String symbolName) {
  if (obj != null) {
    return obj;
  }
  final String message = "unresolved symbol: " + symbolName;
  throw new UnsatisfiedLinkError(message);
}

/**
 * Invokes the native {@code vdot8s} function through its downcall handle.
 *
 * @param vec1 memory segment passed as the first pointer argument
 * @param vec2 memory segment passed as the second pointer argument
 * @param limit element count passed to the native function
 * @return the int returned by the native function
 * @throws UnsatisfiedLinkError if the native symbol was not resolved
 * @throws AssertionError wrapping anything thrown by the invocation itself
 */
public static int vdot8s(MemorySegment vec1, MemorySegment vec2, int limit) {
  final MethodHandle handle = vdot8s$MH();
  final int dotProduct;
  try {
    dotProduct = (int) handle.invokeExact(vec1, vec2, limit);
  } catch (Throwable cause) {
    throw new AssertionError("should not reach here", cause);
  }
  return dotProduct;
}

/**
 * Invokes the native {@code dot8s} function through its downcall handle.
 *
 * @param vec1 memory segment passed as the first pointer argument
 * @param vec2 memory segment passed as the second pointer argument
 * @param limit element count passed to the native function
 * @return the int returned by the native function
 * @throws UnsatisfiedLinkError if the native symbol was not resolved
 * @throws AssertionError wrapping anything thrown by the invocation itself
 */
public static int dot8s(MemorySegment vec1, MemorySegment vec2, int limit) {
  final MethodHandle handle = dot8s$MH();
  final int dotProduct;
  try {
    dotProduct = (int) handle.invokeExact(vec1, vec2, limit);
  } catch (Throwable cause) {
    throw new AssertionError("should not reach here", cause);
  }
  return dotProduct;
}

// Ankur: Hacky code end

/**
* Dot product computed over signed bytes.
*
Expand Down

0 comments on commit 356d194

Please sign in to comment.