Skip to content

Commit

Permalink
Move C code to native module and integrate Java code under java21
Browse files Browse the repository at this point in the history
  • Loading branch information
Ankur Goel committed Oct 13, 2024
1 parent 8a9ee42 commit 2dce12f
Show file tree
Hide file tree
Showing 20 changed files with 501 additions and 165 deletions.
10 changes: 2 additions & 8 deletions gradle/java/javac.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,7 @@ allprojects { project ->

// Use 'release' flag instead of 'source' and 'target'
tasks.withType(JavaCompile) {
options.compilerArgs += ["--release", rootProject.minJavaVersion.toString(), "--enable-preview"]
}

tasks.withType(Test) {
jvmArgs += "--enable-preview"
options.compilerArgs += ["--release", rootProject.minJavaVersion.toString()]
}

// Configure warnings.
Expand Down Expand Up @@ -76,19 +72,17 @@ allprojects { project ->
"-Xdoclint:-accessibility"
]

if (project.path == ":lucene:benchmark-jmh" ) {
if (project.path == ":lucene:benchmark-jmh") {
// JMH benchmarks use JMH preprocessor and incubating modules.
} else {
// proc:none was added because of LOG4J2-1925 / JDK-8186647
options.compilerArgs += [
"-proc:none"
]

/**
if (propertyOrDefault("javac.failOnWarnings", true).toBoolean()) {
options.compilerArgs += "-Werror"
}
*/
}
}
}
Expand Down
2 changes: 0 additions & 2 deletions gradle/testing/defaults-tests.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,6 @@ allprojects {
":lucene:test-framework"
] ? 'ALL-UNNAMED' : 'org.apache.lucene.core')

jvmArgs '-Djava.library.path=' + file("${buildDir}/libs/dotProduct/shared").absolutePath

def loggingConfigFile = layout.projectDirectory.file("${resources}/logging.properties")
def tempDir = layout.projectDirectory.dir(testsTmpDir.toString())
jvmArgumentProviders.add(
Expand Down
8 changes: 7 additions & 1 deletion gradle/testing/randomization/policies/tests.policy
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ grant {
// Needed for DirectIODirectory to retrieve block size
permission java.lang.RuntimePermission "getFileStoreAttributes";

// Needed to load native library containing optimized dot product implementation
permission java.lang.RuntimePermission "loadLibrary.dotProduct";

// TestLockFactoriesMultiJVM opens a random port on 127.0.0.1 (port 0 = ephemeral port range):
permission java.net.SocketPermission "127.0.0.1:0", "accept,listen,resolve";
// Replicator tests connect to ephemeral ports
Expand Down Expand Up @@ -104,7 +107,10 @@ grant codeBase "file:${gradle.worker.jar}" {
};

grant {
permission java.security.AllPermission;
// Allow reading gradle worker JAR.
permission java.io.FilePermission "${gradle.worker.jar}", "read";
// Allow reading from classpath JARs (resources).
permission java.io.FilePermission "${gradle.user.home}${/}-", "read";
};

// Grant permissions to certain test-related JARs (https://github.com/apache/lucene/pull/13146)
Expand Down
1 change: 1 addition & 0 deletions lucene/benchmark-jmh/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ tasks.matching { it.name == "forbiddenApisMain" }.configureEach {
])
}


// Skip certain infrastructure tasks that we can't use or don't care about.
tasks.matching { it.name in [
// Turn off JMH dependency checksums and licensing (it's GPL w/ classpath exception
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
*/
package org.apache.lucene.benchmark.jmh;

import java.lang.foreign.Arena;
import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.util.VectorUtil;
Expand Down Expand Up @@ -52,12 +52,11 @@ static void compressBytes(byte[] raw, byte[] compressed) {
private float[] floatsB;
private int expectedhalfByteDotProduct;

private MemorySegment nativeBytesA;
private Object nativeBytesA;
private Object nativeBytesB;

private MemorySegment nativeBytesB;

// @Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
@Param({"768"})
/** private Object nativeBytesA; private Object nativeBytesB; */
@Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
int size;

@Setup(Level.Iteration)
Expand Down Expand Up @@ -92,20 +91,76 @@ public void init() {
floatsA[i] = random.nextFloat();
floatsB[i] = random.nextFloat();
}

Arena offHeap = Arena.ofAuto();
nativeBytesA = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
nativeBytesB = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
for (int i = 0; i < size; ++i) {
nativeBytesA.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
nativeBytesA.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
// Java 21+ specific initialization
final int runtimeVersion = Runtime.version().feature();
if (runtimeVersion >= 21) {
// Reflection based code to eliminate the use of Preview classes in JMH benchmarks
try {
final Class<?> vectorUtilSupportClass = VectorUtil.getVectorUtilSupportClass();
final var className = "org.apache.lucene.internal.vectorization.PanamaVectorUtilSupport";
if (vectorUtilSupportClass.getName().equals(className) == false) {
nativeBytesA = null;
nativeBytesB = null;
} else {
MethodHandles.Lookup lookup = MethodHandles.lookup();
final var MemorySegment = "java.lang.foreign.MemorySegment";
final var methodType =
MethodType.methodType(lookup.findClass(MemorySegment), byte[].class);
MethodHandle nativeMemorySegment =
lookup.findStatic(vectorUtilSupportClass, "nativeMemorySegment", methodType);
byte[] a = new byte[size];
byte[] b = new byte[size];
for (int i = 0; i < size; ++i) {
a[i] = (byte) random.nextInt(128);
b[i] = (byte) random.nextInt(128);
}
nativeBytesA = nativeMemorySegment.invoke(a);
nativeBytesB = nativeMemorySegment.invoke(b);
}
} catch (Throwable e) {
throw new RuntimeException(e);
}
/*
Arena offHeap = Arena.ofAuto();
nativeBytesA = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
nativeBytesB = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
for (int i = 0; i < size; ++i) {
nativeBytesA.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
nativeBytesB.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
}*/
}
}

/**
* Calling NATIVE_DOT_PRODUCT.invoke(nativeBytesA, nativeBytesB) has high overhead (lower score).
* Both nativeBytesA and nativeBytesB are off-heap MemorySegments created by invoking the method
* PanamaVectorUtilSupport.nativeMemorySegment(byte[]), which allocates these segments and copies
* bytes from the supplied byte[] to off-heap memory. The benchmark output below shows
* significantly more overhead. <b>NOTE:</b> The return type of dot8s() was set to void for the
* benchmark run to avoid boxing/unboxing overhead.
*
* <pre>
* Benchmark (size) Mode Cnt Score Error Units
* VectorUtilBenchmark.dot8s 768 thrpt 15 36.406 ± 0.496 ops/us
* </pre>
*
* Much lower overhead was observed when preview APIs were used directly in the JMH benchmarking
* code and an exact method invocation was made, as shown below: <b>return (int)
* VectorUtil.NATIVE_DOT_PRODUCT.invokeExact(nativeBytesA, nativeBytesB);</b>
*
* <pre>
* Benchmark (size) Mode Cnt Score Error Units
* VectorUtilBenchmark.dot8s 768 thrpt 15 43.662 ± 0.818 ops/us
* </pre>
*/
@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int dot8s() {
return VectorUtil.dot8s(nativeBytesA, nativeBytesB, size);
public void dot8s() {
try {
VectorUtil.NATIVE_DOT_PRODUCT.invoke(nativeBytesA, nativeBytesB);
} catch (Throwable e) {
throw new RuntimeException(e);
}
}

@Benchmark
Expand Down
49 changes: 4 additions & 45 deletions lucene/core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -14,63 +14,22 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
plugins {
id "c"
}

apply plugin: 'java-library'
apply plugin: 'c'

description = 'Lucene core library'
model {
binaries {
all {
cCompiler.args "--shared", "-O3", "-march=native", "-funroll-loops"
}
}

toolChains {
gcc(Gcc) {
target("linux_aarch64") {
cCompiler.executable = System.getenv("CC")
}
}
clang(Clang) {
target("osx_aarch64"){
cCompiler.executable = System.getenv("CC")
}
}
}

components {
dotProduct(NativeLibrarySpec) {
sources {
c {
source {
srcDir 'src/c' // Path to your C source files
include "**/*.c"
}
exportedHeaders {
srcDir "src/c"
include "**/*.h"
}
}
}
}
}

}

test.dependsOn 'dotProductSharedLibrary'

dependencies {
moduleTestImplementation project(':lucene:codecs')
moduleTestImplementation project(':lucene:test-framework')
}

test {
build {
dependsOn ':lucene:native:build'
}
systemProperty(
"java.library.path",
file("${buildDir}/libs/dotProduct/shared").absolutePath
project(":lucene:native").layout.buildDirectory.get().asFile.absolutePath + "/libs/dotProduct/shared"
)
}
4 changes: 4 additions & 0 deletions lucene/core/src/java/org/apache/lucene/util/Constants.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,10 @@ private static boolean is64Bit() {
/** true iff we know VFMA has faster throughput than separate vmul/vadd. */
public static final boolean HAS_FAST_VECTOR_FMA = hasFastVectorFMA();

// TODO: also require Boolean.parseBoolean(getSysProp("lucene.useNativeDotProduct", "false"))
// in addition to the condition below.
public static final boolean NATIVE_DOT_PRODUCT_ENABLED = OS_ARCH.equalsIgnoreCase("aarch64");

/** true iff we know FMA has faster throughput than separate mul/add. */
public static final boolean HAS_FAST_SCALAR_FMA = hasFastScalarFMA();

Expand Down
98 changes: 36 additions & 62 deletions lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@

package org.apache.lucene.util;

import static java.lang.foreign.ValueLayout.JAVA_BYTE;
import static java.lang.foreign.ValueLayout.JAVA_INT;

import java.lang.foreign.*;
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
import org.apache.lucene.internal.vectorization.VectorizationProvider;

Expand Down Expand Up @@ -54,11 +52,41 @@ public final class VectorUtil {

private static final float EPSILON = 1e-4f;

private static final VectorUtilSupport IMPL =
public static final VectorUtilSupport IMPL =
VectorizationProvider.getInstance().getVectorUtilSupport();

// TODO: Harden this implementation and maybe find a new home for it.
/**
 * Looks up the static {@code nativeDotProduct(MemorySegment, MemorySegment)} method on the class
 * of the active {@link VectorUtilSupport} implementation.
 *
 * <p>{@code java.lang.foreign.MemorySegment} is resolved by name rather than referenced directly
 * (presumably so this class compiles without a direct dependency on that API - confirm).
 *
 * @return a handle taking two {@code MemorySegment} arguments and returning {@code int}, or
 *     {@code null} if the implementation class is not {@code PanamaVectorUtilSupport}
 * @throws RuntimeException if the {@code MemorySegment} class or the target method cannot be
 *     found or accessed
 */
private static MethodHandle nativeDotProduct() {
  final String panamaSupportName =
      "org.apache.lucene.internal.vectorization.PanamaVectorUtilSupport";
  final Class<?> implClass = IMPL.getClass();
  if (implClass.getName().equals(panamaSupportName) == false) {
    // No native dot product available for non-Panama implementations.
    return null;
  }
  try {
    final MethodHandles.Lookup lookup = MethodHandles.lookup();
    // Resolve once: both parameters of the target method share this type.
    final Class<?> memorySegmentClass = lookup.findClass("java.lang.foreign.MemorySegment");
    final MethodType methodType =
        MethodType.methodType(int.class, memorySegmentClass, memorySegmentClass);
    return lookup.findStatic(implClass, "nativeDotProduct", methodType);
  } catch (ReflectiveOperationException e) {
    // Covers ClassNotFoundException, NoSuchMethodException and IllegalAccessException.
    throw new RuntimeException(
        "Unable to resolve nativeDotProduct on " + implClass.getName(), e);
  }
}

/**
 * Handle to the static {@code nativeDotProduct} method of the active vector support
 * implementation, as produced by {@code nativeDotProduct()}. This is {@code null} when that
 * implementation is not {@code PanamaVectorUtilSupport}, so callers must check before invoking.
 */
public static final MethodHandle NATIVE_DOT_PRODUCT = nativeDotProduct();

private VectorUtil() {}

/**
 * Exposes the runtime class of the active {@link VectorUtilSupport} implementation.
 *
 * <p>Used in o.a.l.benchmark.jmh.VectorUtilBenchmark, which creates its test vectors in off-heap
 * MemorySegments only if the VectorUtilSupport instance supports the Panama APIs.
 *
 * @return the class of the {@code IMPL} instance
 */
public static Class<?> getVectorUtilSupportClass() {
  final Class<?> supportClass = IMPL.getClass();
  return supportClass;
}

/**
* Returns the vector dot product of the two vectors.
*
Expand Down Expand Up @@ -173,62 +201,6 @@ public static void add(float[] u, float[] v) {
}
}

/** Ankur: Hacky code start */
public static final AddressLayout POINTER =
ValueLayout.ADDRESS.withTargetLayout(MemoryLayout.sequenceLayout(JAVA_BYTE));

private static final Linker LINKER = Linker.nativeLinker();
private static final SymbolLookup SYMBOL_LOOKUP;

static {
System.loadLibrary("dotProduct");
SymbolLookup loaderLookup = SymbolLookup.loaderLookup();
SYMBOL_LOOKUP = name -> loaderLookup.find(name).or(() -> LINKER.defaultLookup().find(name));
}

static final FunctionDescriptor dot8sDesc =
FunctionDescriptor.of(JAVA_INT, POINTER, POINTER, JAVA_INT);

static final MethodHandle dot8sMH =
SYMBOL_LOOKUP.find("dot8s").map(addr -> LINKER.downcallHandle(addr, dot8sDesc)).orElse(null);

static final MethodHandle neonVdot8sMH =
SYMBOL_LOOKUP
.find("vdot8s_neon")
.map(addr -> LINKER.downcallHandle(addr, dot8sDesc))
.orElse(null);

static final MethodHandle sveVdot8sMH =
SYMBOL_LOOKUP
.find("vdot8s_sve")
.map(addr -> LINKER.downcallHandle(addr, dot8sDesc))
.orElse(null);

/* chosen C implementation */
static final MethodHandle dot8sImpl;

static {
if (sveVdot8sMH != null) {
dot8sImpl = sveVdot8sMH;
} else if (neonVdot8sMH != null) {
dot8sImpl = neonVdot8sMH;
} else if (dot8sMH != null) {
dot8sImpl = dot8sMH;
} else {
throw new RuntimeException("c code was not linked!");
}
}

public static int dot8s(MemorySegment vec1, MemorySegment vec2, int limit) {
try {
return (int) dot8sImpl.invokeExact(vec1, vec2, limit);
} catch (Throwable ex$) {
throw new AssertionError("should not reach here", ex$);
}
}

/** Ankur: Hacky code end * */

/**
* Dot product computed over signed bytes.
*
Expand Down Expand Up @@ -339,7 +311,9 @@ static int xorBitCountLong(byte[] a, byte[] b) {
public static float dotProductScore(byte[] a, byte[] b) {
// divide by 2 * 2^14 (maximum absolute value of product of 2 signed bytes) * len
float denom = (float) (a.length * (1 << 15));
return 0.5f + dotProduct(a, b) / denom;

int raw = dotProduct(a, b);
return 0.5f + raw / denom;
}

/**
Expand Down
Loading

0 comments on commit 2dce12f

Please sign in to comment.