Skip to content

Commit

Permalink
Enable 512-bit vector support (#20)
Browse files Browse the repository at this point in the history
* 512-bit experiment

* use preferred species

* nwidth_array initial commit

* nwidth_array, make classifier array static

* WIP comparing nwidth switch vs method handle dispatch of ::step

* WIP - nwitdh

* candidate PR code

* Fix step methods visibility, remove unused MethodHandle imports
  • Loading branch information
steveatgh authored Sep 12, 2023
1 parent dddb0d7 commit 8c4c689
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 77 deletions.
73 changes: 38 additions & 35 deletions src/main/java/org/simdjson/CharactersClassifier.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,42 @@
class CharactersClassifier {

private static final byte LOW_NIBBLE_MASK = 0x0f;
private static final ByteVector WHITESPACE_TABLE = ByteVector.fromArray(
ByteVector.SPECIES_256,
new byte[]{
' ', 100, 100, 100, 17, 100, 113, 2, 100, '\t', '\n', 112, 100, '\r', 100, 100,
' ', 100, 100, 100, 17, 100, 113, 2, 100, '\t', '\n', 112, 100, '\r', 100, 100
},
0
);
private static final ByteVector OP_TABLE = ByteVector.fromArray(
ByteVector.SPECIES_256,
new byte[]{
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ':', '{', ',', '}', 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ':', '{', ',', '}', 0, 0
},
0
);

private static final ByteVector WHITESPACE_TABLE =
ByteVector.fromArray(
StructuralIndexer.SPECIES,
repeat(new byte[]{' ', 100, 100, 100, 17, 100, 113, 2, 100, '\t', '\n', 112, 100, '\r', 100, 100}, StructuralIndexer.SPECIES.vectorByteSize() / 4),
0);

private static final ByteVector OP_TABLE =
ByteVector.fromArray(
StructuralIndexer.SPECIES,
repeat(new byte[]{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ':', '{', ',', '}', 0, 0}, StructuralIndexer.SPECIES.vectorByteSize() / 4),
0);

private static byte[] repeat(byte[] array, int n) {
byte[] result = new byte[n * array.length];
for (int dst = 0; dst < result.length; dst += array.length) {
System.arraycopy(array, 0, result, dst, array.length);
}
return result;
}

JsonCharacterBlock classify(ByteVector chunk0) {
VectorShuffle<Byte> chunk0Low = extractLowNibble(chunk0).toShuffle();
long whitespace = eq(chunk0, WHITESPACE_TABLE.rearrange(chunk0Low));
ByteVector curlified0 = curlify(chunk0);
long op = eq(curlified0, OP_TABLE.rearrange(chunk0Low));
return new JsonCharacterBlock(whitespace, op);
}

JsonCharacterBlock classify(ByteVector chunk0, ByteVector chunk1) {
VectorShuffle<Byte> chunk0Low = extractLowNibble(chunk0).toShuffle();
VectorShuffle<Byte> chunk1Low = extractLowNibble(chunk1).toShuffle();

long whitespace = eq(
chunk0,
WHITESPACE_TABLE.rearrange(chunk0Low),
chunk1,
WHITESPACE_TABLE.rearrange(chunk1Low)
);

long whitespace = eq(chunk0, WHITESPACE_TABLE.rearrange(chunk0Low), chunk1, WHITESPACE_TABLE.rearrange(chunk1Low));
ByteVector curlified0 = curlify(chunk0);
ByteVector curlified1 = curlify(chunk1);
long op = eq(
curlified0,
OP_TABLE.rearrange(chunk0Low),
curlified1,
OP_TABLE.rearrange(chunk1Low)
);

long op = eq(curlified0, OP_TABLE.rearrange(chunk0Low), curlified1, OP_TABLE.rearrange(chunk1Low));
return new JsonCharacterBlock(whitespace, op);
}

Expand All @@ -55,9 +54,13 @@ private ByteVector curlify(ByteVector vector) {
return vector.or((byte) 0x20);
}

private long eq(ByteVector chunk0, ByteVector mask0) {
return chunk0.eq(mask0).toLong();
}

private long eq(ByteVector chunk0, ByteVector mask0, ByteVector chunk1, ByteVector mask1) {
long rLo = chunk0.eq(mask0).toLong();
long rHi = chunk1.eq(mask1).toLong();
return rLo | (rHi << 32);
}
long r0 = chunk0.eq(mask0).toLong();
long r1 = chunk1.eq(mask1).toLong();
return r0 | (r1 << 32);
}
}
25 changes: 19 additions & 6 deletions src/main/java/org/simdjson/JsonStringScanner.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,17 @@ class JsonStringScanner {
private long prevEscaped = 0;

JsonStringScanner() {
VectorSpecies<Byte> species = ByteVector.SPECIES_256;
this.backslashMask = ByteVector.broadcast(species, (byte) '\\');
this.quoteMask = ByteVector.broadcast(species, (byte) '"');
this.backslashMask = ByteVector.broadcast(StructuralIndexer.SPECIES, (byte) '\\');
this.quoteMask = ByteVector.broadcast(StructuralIndexer.SPECIES, (byte) '"');
}

JsonStringBlock next(ByteVector chunk0) {
long backslash = eq(chunk0, backslashMask);
long escaped = findEscaped(backslash);
long quote = eq(chunk0, quoteMask) & ~escaped;
long inString = prefixXor(quote) ^ prevInString;
prevInString = inString >> 63;
return new JsonStringBlock(quote, inString);
}

JsonStringBlock next(ByteVector chunk0, ByteVector chunk1) {
Expand All @@ -29,10 +37,15 @@ JsonStringBlock next(ByteVector chunk0, ByteVector chunk1) {
return new JsonStringBlock(quote, inString);
}

private long eq(ByteVector chunk0, ByteVector mask) {
long r = chunk0.eq(mask).toLong();
return r;
}

private long eq(ByteVector chunk0, ByteVector chunk1, ByteVector mask) {
long rLo = chunk0.eq(mask).toLong();
long rHi = chunk1.eq(mask).toLong();
return rLo | (rHi << 32);
long r0 = chunk0.eq(mask).toLong();
long r1 = chunk1.eq(mask).toLong();
return r0 | (r1 << 32);
}

private long findEscaped(long backslash) {
Expand Down
49 changes: 42 additions & 7 deletions src/main/java/org/simdjson/StructuralIndexer.java
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
package org.simdjson;

import jdk.incubator.vector.ByteVector;
import jdk.incubator.vector.VectorSpecies;
import java.lang.invoke.MethodType;

import static jdk.incubator.vector.ByteVector.SPECIES_256;
import static jdk.incubator.vector.VectorOperators.UNSIGNED_LE;

class StructuralIndexer {

static final VectorSpecies<Byte> SPECIES;
static final int N_CHUNKS;

static {
SPECIES = ByteVector.SPECIES_PREFERRED;
N_CHUNKS = 64 / SPECIES.vectorByteSize();
if (SPECIES != ByteVector.SPECIES_256 && SPECIES != ByteVector.SPECIES_512) {
throw new IllegalArgumentException("Unsupported vector species: " + SPECIES);
}
}

private final JsonStringScanner stringScanner;
private final CharactersClassifier classifier;
private final BitIndexes bitIndexes;
Expand All @@ -22,29 +34,52 @@ class StructuralIndexer {
}

void step(byte[] buffer, int offset, int blockIndex) {
ByteVector chunk0 = ByteVector.fromArray(SPECIES_256, buffer, offset);
ByteVector chunk1 = ByteVector.fromArray(SPECIES_256, buffer, offset + 32);
switch (N_CHUNKS) {
case 1: step1(buffer, offset, blockIndex); break;
case 2: step2(buffer, offset, blockIndex); break;
default: throw new RuntimeException("Unsupported vector width: " + N_CHUNKS * 64);
}
}

private void step1(byte[] buffer, int offset, int blockIndex) {
ByteVector chunk0 = ByteVector.fromArray(ByteVector.SPECIES_512, buffer, offset);
JsonStringBlock strings = stringScanner.next(chunk0);
JsonCharacterBlock characters = classifier.classify(chunk0);
long unescaped = lteq(chunk0, (byte) 0x1F);
finishStep(characters, strings, unescaped, blockIndex);
}

private void step2(byte[] buffer, int offset, int blockIndex) {
ByteVector chunk0 = ByteVector.fromArray(ByteVector.SPECIES_256, buffer, offset);
ByteVector chunk1 = ByteVector.fromArray(ByteVector.SPECIES_256, buffer, offset + 32);
JsonStringBlock strings = stringScanner.next(chunk0, chunk1);
JsonCharacterBlock characters = classifier.classify(chunk0, chunk1);
long unescaped = lteq(chunk0, chunk1, (byte) 0x1F);
finishStep(characters, strings, unescaped, blockIndex);
}

private void finishStep(JsonCharacterBlock characters, JsonStringBlock strings, long unescaped, int blockIndex) {
long scalar = characters.scalar();
long nonQuoteScalar = scalar & ~strings.quote();
long followsNonQuoteScalar = nonQuoteScalar << 1 | prevScalar;
prevScalar = nonQuoteScalar >>> 63;
long unescaped = lteq(chunk0, chunk1, (byte) 0x1F);
// TODO: utf-8 validation
long potentialScalarStart = scalar & ~followsNonQuoteScalar;
long potentialStructuralStart = characters.op() | potentialScalarStart;
bitIndexes.write(blockIndex, prevStructurals);
prevStructurals = potentialStructuralStart & ~strings.stringTail();
unescapedCharsError |= strings.nonQuoteInsideString(unescaped);
}

private long lteq(ByteVector chunk0, byte scalar) {
long r = chunk0.compare(UNSIGNED_LE, scalar).toLong();
return r;
}

private long lteq(ByteVector chunk0, ByteVector chunk1, byte scalar) {
long rLo = chunk0.compare(UNSIGNED_LE, scalar).toLong();
long rHi = chunk1.compare(UNSIGNED_LE, scalar).toLong();
return rLo | (rHi << 32);
long r0 = chunk0.compare(UNSIGNED_LE, scalar).toLong();
long r1 = chunk1.compare(UNSIGNED_LE, scalar).toLong();
return r0 | (r1 << 32);
}

void finish(int blockIndex) {
Expand Down
5 changes: 2 additions & 3 deletions src/main/java/org/simdjson/TapeBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@
import static org.simdjson.Tape.START_OBJECT;
import static org.simdjson.Tape.STRING;
import static org.simdjson.Tape.TRUE_VALUE;
import static jdk.incubator.vector.ByteVector.SPECIES_256;

class TapeBuilder {

private static final byte SPACE = 0x20;
private static final byte BACKSLASH = '\\';
private static final byte QUOTE = '"';
private static final int BYTES_PROCESSED = 32;
private static final int BYTES_PROCESSED = StructuralIndexer.SPECIES.vectorByteSize();
private static final byte[] ESCAPE_MAP = new byte[]{
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x0.
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
Expand Down Expand Up @@ -198,7 +197,7 @@ private void visitString(byte[] buffer, int idx) {
int src = idx + 1;
int dst = stringBufferIdx + Integer.BYTES;
while (true) {
ByteVector srcVec = ByteVector.fromArray(SPECIES_256, buffer, src);
ByteVector srcVec = ByteVector.fromArray(StructuralIndexer.SPECIES, buffer, src);
srcVec.intoArray(stringBuffer, dst);
long backslashBits = srcVec.eq(BACKSLASH).toLong();
long quoteBits = srcVec.eq(QUOTE).toLong();
Expand Down
2 changes: 1 addition & 1 deletion src/test/java/org/simdjson/BenchmarkCorrectnessTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,6 @@ public void numberParserTest(String input, Double expected) {
private static byte[] loadTwitterJson() throws IOException {
try (InputStream is = BenchmarkCorrectnessTest.class.getResourceAsStream("/twitter.json")) {
return is.readAllBytes();
}
}
}
}
18 changes: 13 additions & 5 deletions src/test/java/org/simdjson/CharactersClassifierTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@

import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;
import static org.simdjson.StringUtils.chunk0;
import static org.simdjson.StringUtils.chunk1;
import static org.simdjson.StringUtils.chunk;

public class CharactersClassifierTest {

Expand All @@ -16,7 +15,7 @@ public void classifiesOperators() {
String str = "a{bc}1:2,3[efg]aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";

// when
JsonCharacterBlock block = classifier.classify(chunk0(str), chunk1(str));
JsonCharacterBlock block = classify(classifier, str);

// then
assertThat(block.op()).isEqualTo(0x4552);
Expand All @@ -39,7 +38,7 @@ public void classifiesControlCharactersAsOperators() {
}, UTF_8);

// when
JsonCharacterBlock block = classifier.classify(chunk0(str), chunk1(str));
JsonCharacterBlock block = classify(classifier, str);

// then
assertThat(block.op()).isEqualTo(0x28);
Expand All @@ -53,10 +52,19 @@ public void classifiesWhitespaces() {
String str = "a bc\t1\n2\r3efgaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";

// when
JsonCharacterBlock block = classifier.classify(chunk0(str), chunk1(str));
JsonCharacterBlock block = classify(classifier, str);

// then
assertThat(block.whitespace()).isEqualTo(0x152);
assertThat(block.op()).isEqualTo(0);
}

private JsonCharacterBlock classify(CharactersClassifier classifier, String str) {
return switch (StructuralIndexer.N_CHUNKS) {
case 1 -> classifier.classify(chunk(str, 0));
case 2 -> classifier.classify(chunk(str, 0), chunk(str, 1));
default -> throw new RuntimeException("Unsupported chunk count: " + StructuralIndexer.N_CHUNKS);
};
}

}
33 changes: 20 additions & 13 deletions src/test/java/org/simdjson/JsonStringScannerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import org.junit.jupiter.params.provider.ValueSource;

import static org.assertj.core.api.Assertions.assertThat;
import static org.simdjson.StringUtils.chunk0;
import static org.simdjson.StringUtils.chunk1;
import static org.simdjson.StringUtils.chunk;
import static org.simdjson.StringUtils.padWithSpaces;

public class JsonStringScannerTest {
Expand All @@ -18,7 +17,7 @@ public void testUnquotedString() {
String str = padWithSpaces("abc 123");

// when
JsonStringBlock block = stringScanner.next(chunk0(str), chunk1(str));
JsonStringBlock block = next(stringScanner, str);

// then
assertThat(block.quote()).isEqualTo(0);
Expand All @@ -31,7 +30,7 @@ public void testQuotedString() {
String str = padWithSpaces("\"abc 123\"");

// when
JsonStringBlock block = stringScanner.next(chunk0(str), chunk1(str));
JsonStringBlock block = next(stringScanner, str);

// then
assertThat(block.quote()).isEqualTo(0x101);
Expand All @@ -44,7 +43,7 @@ public void testStartingQuotes() {
String str = padWithSpaces("\"abc 123");

// when
JsonStringBlock block = stringScanner.next(chunk0(str), chunk1(str));
JsonStringBlock block = next(stringScanner, str);

// then
assertThat(block.quote()).isEqualTo(0x1);
Expand All @@ -58,8 +57,8 @@ public void testQuotedStringSpanningMultipleBlocks() {
String str1 = " c0 c1 c2 c3 c4 c5 c6 c7 c8 c9 d0 d1 d2 d3 d4 d5 d6 d7 d8 d\" def";

// when
JsonStringBlock firstBlock = stringScanner.next(chunk0(str0), chunk1(str0));
JsonStringBlock secondBlock = stringScanner.next(chunk0(str1), chunk1(str1));
JsonStringBlock firstBlock = next(stringScanner, str0);
JsonStringBlock secondBlock = next(stringScanner, str1);

// then
assertThat(firstBlock.quote()).isEqualTo(0x10);
Expand All @@ -77,7 +76,7 @@ public void testEscapedQuote(String str) {
String padded = padWithSpaces(str);

// when
JsonStringBlock block = stringScanner.next(chunk0(padded), chunk1(padded));
JsonStringBlock block = next(stringScanner, padded);

// then
assertThat(block.quote()).isEqualTo(0);
Expand All @@ -91,8 +90,8 @@ public void testEscapedQuoteSpanningMultipleBlocks() {
String str1 = padWithSpaces("\"def");

// when
JsonStringBlock firstBlock = stringScanner.next(chunk0(str0), chunk1(str0));
JsonStringBlock secondBlock = stringScanner.next(chunk0(str1), chunk1(str1));
JsonStringBlock firstBlock = next(stringScanner, str0);
JsonStringBlock secondBlock = next(stringScanner, str1);

// then
assertThat(firstBlock.quote()).isEqualTo(0);
Expand All @@ -110,7 +109,7 @@ public void testUnescapedQuote(String str) {
String padded = padWithSpaces(str);

// when
JsonStringBlock block = stringScanner.next(chunk0(padded), chunk1(padded));
JsonStringBlock block = next(stringScanner, padded);

// then
assertThat(block.quote()).isEqualTo(0x1L << str.indexOf('"'));
Expand All @@ -124,11 +123,19 @@ public void testUnescapedQuoteSpanningMultipleBlocks() {
String str1 = padWithSpaces("\\\"abc");

// when
JsonStringBlock firstBlock = stringScanner.next(chunk0(str0), chunk1(str0));
JsonStringBlock secondBlock = stringScanner.next(chunk0(str1), chunk1(str1));
JsonStringBlock firstBlock = next(stringScanner, str0);
JsonStringBlock secondBlock = next(stringScanner, str1);

// then
assertThat(firstBlock.quote()).isEqualTo(0);
assertThat(secondBlock.quote()).isEqualTo(0x2);
}

private JsonStringBlock next(JsonStringScanner scanner, String str) {
return switch (StructuralIndexer.N_CHUNKS) {
case 1 -> scanner.next(chunk(str, 0));
case 2 -> scanner.next(chunk(str, 0), chunk(str, 1));
default -> throw new RuntimeException("Unsupported chunk count: " + StructuralIndexer.N_CHUNKS);
};
}
}
Loading

0 comments on commit 8c4c689

Please sign in to comment.