Skip to content

Commit

Permalink
Optimize json path lookup by avoiding traversing the trie from the ro… (
Browse files Browse the repository at this point in the history
#116)

* Optimize json path lookup by avoiding traversing the trie from the root every time.
  • Loading branch information
donavdey authored Apr 5, 2024
1 parent 03d5744 commit 82c901b
Show file tree
Hide file tree
Showing 9 changed files with 139 additions and 266 deletions.
36 changes: 0 additions & 36 deletions src/main/java/dev/blaauwendraad/masker/json/JsonPathNode.java

This file was deleted.

30 changes: 9 additions & 21 deletions src/main/java/dev/blaauwendraad/masker/json/KeyContainsMasker.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
import dev.blaauwendraad.masker.json.util.AsciiJsonUtil;

import javax.annotation.CheckForNull;
import java.util.Collections;

/**
* Default implementation of the {@link JsonMasker}.
*/
final class KeyContainsMasker implements JsonMasker {
final class KeyContainsMasker implements JsonMasker {
/**
* Look-up trie containing the target keys.
*/
Expand Down Expand Up @@ -46,13 +45,8 @@ public byte[] mask(byte[] input) {

KeyMaskingConfig keyMaskingConfig = maskingConfig.isInAllowMode() ? maskingConfig.getDefaultConfig() : null;
if (maskingState.jsonPathEnabled()) {
// Check for "$" JSONPath key.
keyMaskingConfig = keyMatcher.getMaskConfigIfMatched(
maskingState.getMessage(),
-1,
-1,
Collections.emptyIterator()
);
maskingState.expandCurrentJsonPath(keyMatcher.getJsonPathRootNode());
keyMaskingConfig = keyMatcher.getMaskConfigIfMatched(maskingState.getMessage(), -1, -1, maskingState.getCurrentJsonPathNode());
}

stepOverWhitespaceCharacters(maskingState);
Expand Down Expand Up @@ -118,7 +112,7 @@ private void visitValue(MaskingState maskingState, @CheckForNull KeyMaskingConfi
* {@link KeyMaskingConfig}. Otherwise, the value is not masked
*/
private void visitArray(MaskingState maskingState, @CheckForNull KeyMaskingConfig keyMaskingConfig) {
maskingState.expandCurrentJsonPathWithArray();
maskingState.expandCurrentJsonPath(keyMatcher.traverseJsonPathSegment(maskingState.getMessage(), maskingState.getCurrentJsonPathNode(), -1, -1));
while (maskingState.next()) {
stepOverWhitespaceCharacters(maskingState);
// check if we're in an empty array
Expand Down Expand Up @@ -164,13 +158,9 @@ private void visitObject(MaskingState maskingState, @CheckForNull KeyMaskingConf

int afterClosingQuoteIndex = maskingState.currentIndex();
int keyLength = afterClosingQuoteIndex - openingQuoteIndex - 2; // minus the opening and closing quotes
maskingState.expandCurrentJsonPath(openingQuoteIndex + 1, keyLength);
KeyMaskingConfig keyMaskingConfig = keyMatcher.getMaskConfigIfMatched(
maskingState.getMessage(),
openingQuoteIndex + 1, // plus one for the opening quote
keyLength,
maskingState.getCurrentJsonPath()
);
maskingState.expandCurrentJsonPath(keyMatcher.traverseJsonPathSegment(maskingState.getMessage(), maskingState.getCurrentJsonPathNode(), openingQuoteIndex + 1, keyLength));
KeyMaskingConfig keyMaskingConfig = keyMatcher.getMaskConfigIfMatched(maskingState.getMessage(), openingQuoteIndex + 1, // plus one for the opening quote
keyLength, maskingState.getCurrentJsonPathNode());
stepOverWhitespaceCharacters(maskingState);
// step over the colon ':'
maskingState.next();
Expand All @@ -187,8 +177,7 @@ private void visitObject(MaskingState maskingState, @CheckForNull KeyMaskingConf
// we got was the default config, then it means that the key doesn't have a specific configuration and
// we should fall back to key specific config that the object is being masked with.
// E.g.: '{ "a": { "b": "value" } }' we want to use config of 'b' if any, but fallback to config of 'a'
if (parentKeyMaskingConfig != null && (keyMaskingConfig == null
|| keyMaskingConfig == maskingConfig.getDefaultConfig())) {
if (parentKeyMaskingConfig != null && (keyMaskingConfig == null || keyMaskingConfig == maskingConfig.getDefaultConfig())) {
keyMaskingConfig = parentKeyMaskingConfig;
}
visitValue(maskingState, keyMaskingConfig);
Expand Down Expand Up @@ -304,8 +293,7 @@ private static void stepOverWhitespaceCharacters(MaskingState maskingState) {
private static void stepOverNumericValue(MaskingState maskingState) {
// step over the first numeric character
maskingState.next();
while (maskingState.currentIndex() < maskingState.getMessage().length
&& AsciiJsonUtil.isNumericCharacter(maskingState.byteAtCurrentIndex())) {
while (maskingState.currentIndex() < maskingState.getMessage().length && AsciiJsonUtil.isNumericCharacter(maskingState.byteAtCurrentIndex())) {
maskingState.next();
}
}
Expand Down
90 changes: 39 additions & 51 deletions src/main/java/dev/blaauwendraad/masker/json/KeyMatcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import javax.annotation.CheckForNull;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;

/**
* This key matcher is built using a byte trie structure to optimize the look-ups for JSON keys in the target key set.
Expand Down Expand Up @@ -133,14 +132,14 @@ private void insert(String word, boolean negativeMatch) {
* @return the config if the key needs to be masked, {@code null} if key does not need to be masked
*/
@CheckForNull
public KeyMaskingConfig getMaskConfigIfMatched(byte[] bytes, int keyOffset, int keyLength, Iterator<? extends JsonPathNode> jsonPath) {
public KeyMaskingConfig getMaskConfigIfMatched(byte[] bytes, int keyOffset, int keyLength, @CheckForNull TrieNode currentJsonPathNode) {
// first search by key
if (maskingConfig.isInMaskMode()) {
// check JSONPath first, as it's more specific
TrieNode node = searchForJsonPathKeyNode(bytes, jsonPath);
TrieNode node = currentJsonPathNode;
// if found - mask with this config
// if not found - do not mask
if (node != null && !node.negativeMatch) {
if (node != null && node.endOfWord && !node.negativeMatch) {
return node.keyMaskingConfig;
} else if (keyLength != SKIP_KEY_LOOKUP) {
// also check regular key
Expand All @@ -152,11 +151,11 @@ public KeyMaskingConfig getMaskConfigIfMatched(byte[] bytes, int keyOffset, int
return null;
} else {
// check JSONPath first, as it's more specific
TrieNode node = searchForJsonPathKeyNode(bytes, jsonPath);
TrieNode node = currentJsonPathNode;
// if found and is not negativeMatch - do not mask
// if found and is negative match - mask, but with a specific config
// if not found - mask with default config
if (node != null) {
if (node != null && node.endOfWord) {
if (node.negativeMatch) {
return node.keyMaskingConfig;
}
Expand Down Expand Up @@ -199,73 +198,62 @@ private TrieNode searchNode(byte[] bytes, int offset, int length) {
}

@CheckForNull
private TrieNode searchForJsonPathKeyNode(byte[] bytes, Iterator<? extends JsonPathNode> jsonPath) {
TrieNode node = root;
node = node.children['$' + BYTE_OFFSET];
if (node == null) {
public TrieNode getJsonPathRootNode() {
return root.children['$' + BYTE_OFFSET];
}

/**
* Traverses the trie along the passed JSONPath segment starting from {@code begin} node.
* The passed segment is represented as a key {@code (keyOffset, keyLength)} reference in {@code bytes} array.
*
* @param bytes the message bytes.
* @param begin a TrieNode from which the traversal begins.
* @param keyOffset the offset in {@code bytes} of the segment.
* @param keyLength the length of the segment.
* @return a TrieNode of the last symbol of the segment. {@code null} if the segment is not in the trie.
*/
@CheckForNull
public TrieNode traverseJsonPathSegment(byte[] bytes, @CheckForNull final TrieNode begin, int keyOffset, int keyLength) {
if (begin == null) {
return null;
}
if (node.endOfWord) {
return node;
TrieNode current = begin.children['.' + BYTE_OFFSET];
if (current == null) {
return null;
}
while (jsonPath.hasNext()) {
node = node.children['.' + BYTE_OFFSET];
if (node == null) {
return null;
}
JsonPathNode jsonPathSegmentReference = jsonPath.next();
TrieNode wildcardLookAhead = node.children['*' + BYTE_OFFSET];
if (wildcardLookAhead != null && (wildcardLookAhead.endOfWord || wildcardLookAhead.children['.' + BYTE_OFFSET] != null)) {
node = wildcardLookAhead;
if (node.endOfWord) {
return node;
}
continue;
}
if (jsonPathSegmentReference instanceof JsonPathNode.Node jsonPathNode) {
int keyOffset = jsonPathNode.getOffset();
int keyLength = jsonPathNode.getLength();
for (int i = keyOffset; i < keyOffset + keyLength; i++) {
int b = bytes[i];
node = node.children[b + BYTE_OFFSET];
if (node == null) {
return null;
}
}
} else if (jsonPathSegmentReference instanceof JsonPathNode.Array) {
// only wildcard indexes are supported
TrieNode wildcardLookAhead = current.children['*' + BYTE_OFFSET];
if (wildcardLookAhead != null && (wildcardLookAhead.endOfWord || wildcardLookAhead.children['.' + BYTE_OFFSET] != null)) {
return wildcardLookAhead;
}
for (int i = keyOffset; i < keyOffset + keyLength; i++) {
int b = bytes[i];
current = current.children[b + BYTE_OFFSET];
if (current == null) {
return null;
} else {
throw new IllegalStateException("Unknown JSONPath segment reference type " + jsonPathSegmentReference.getClass());
}
}

if (!node.endOfWord) {
return null;
}

return node;
return current;
}

/**
* A node in the trie; it represents a single byte of a character (for an ASCII character, that is the whole character).
* An offset of 128 is added to a byte value to index into the children array, because bytes are signed and range
* from -128 to 127, so the offset maps them onto the array indexes 0 to 255.
*/
private static class TrieNode {
private final TrieNode[] children = new TrieNode[256];
static class TrieNode {
final TrieNode[] children = new TrieNode[256];
/**
* A marker indicating that a key ends at this node.
*/
private boolean endOfWord = false;
boolean endOfWord = false;
/**
* Masking configuration for the key that ends at this node.
*/
@CheckForNull
private KeyMaskingConfig keyMaskingConfig = null;
KeyMaskingConfig keyMaskingConfig = null;
/**
* Used to store the configuration, but indicate that json-masker is in ALLOW mode and the key is not allowed.
*/
private boolean negativeMatch = false;
boolean negativeMatch = false;
}
}
52 changes: 22 additions & 30 deletions src/main/java/dev/blaauwendraad/masker/json/MaskingState.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@

import dev.blaauwendraad.masker.json.util.Utf8Util;

import javax.annotation.CheckForNull;
import java.nio.charset.StandardCharsets;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.Iterator;
import java.util.Arrays;
import java.util.List;

/**
Expand All @@ -21,18 +19,17 @@ final class MaskingState implements ValueMaskerContext {
private int replacementOperationsTotalDifference = 0;

/**
* Current JSONPath is represented by a dequeue of segment references.
* Current JSONPath is represented by a stack of segment references.
* The stack is implemented as an array of trie nodes, each of which references the end of its segment.
*/
private final Deque<JsonPathNode> currentJsonPath;

private KeyMatcher.TrieNode[] currentJsonPath = null;
private int currentJsonPathIndex = -1;
private int currentValueStartIndex = -1;

public MaskingState(byte[] message, boolean trackJsonPath) {
this.message = message;
if (trackJsonPath) {
currentJsonPath = new ArrayDeque<>();
} else {
currentJsonPath = null;
currentJsonPath = new KeyMatcher.TrieNode[100];
}
}

Expand Down Expand Up @@ -148,22 +145,17 @@ boolean jsonPathEnabled() {
}

/**
* Expands current jsonpath with a new "key" segment.
* @param start the index of a new segment start in <code>message</code>
* @param offset the length of a new segment.
*/
void expandCurrentJsonPath(int start, int offset) {
if (currentJsonPath != null) {
currentJsonPath.push(new JsonPathNode.Node(start, offset));
}
}

/**
* Expands current jsonpath with a new array segment.
* Expands current jsonpath.
*
* @param trieNode a node in the trie where the new segment ends.
*/
void expandCurrentJsonPathWithArray() {
void expandCurrentJsonPath(@CheckForNull KeyMatcher.TrieNode trieNode) {
if (currentJsonPath != null) {
currentJsonPath.push(new JsonPathNode.Array());
currentJsonPath[++currentJsonPathIndex] = trieNode;
if (currentJsonPathIndex == currentJsonPath.length - 1) {
// resize
currentJsonPath = Arrays.copyOf(currentJsonPath, currentJsonPath.length*2);
}
}
}

Expand All @@ -172,18 +164,18 @@ void expandCurrentJsonPathWithArray() {
*/
void backtrackCurrentJsonPath() {
if (currentJsonPath != null) {
currentJsonPath.pop();
currentJsonPath[currentJsonPathIndex--] = null;
}
}

/**
* Returns the iterator over the JSONPath component references from head to tail
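* Returns the TrieNode that references the end of the latest segment in the current JSONPath, or {@code null} when
* JSONPath tracking is disabled, no segment has been added yet, or the latest segment is not present in the trie.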
*/
Iterator<JsonPathNode> getCurrentJsonPath() {
if (currentJsonPath != null) {
return currentJsonPath.descendingIterator();
public KeyMatcher.TrieNode getCurrentJsonPathNode() {
if (currentJsonPath != null && currentJsonPathIndex != -1) {
return currentJsonPath[currentJsonPathIndex];
} else {
return Collections.emptyIterator();
return null;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,9 @@ public Builder allowJsonPaths(Set<String> jsonPaths) {
if (targetKeyMode == TargetKeyMode.MASK) {
throw new IllegalArgumentException("Cannot allow keys when in MASK mode");
}
if (jsonPaths.contains("$")) {
throw new IllegalArgumentException("Root node JSONPath is not allowed in ALLOW mode");
}
targetKeyMode = TargetKeyMode.ALLOW;
for (String jsonPath : jsonPaths) {
JsonPath parsed = JSON_PATH_PARSER.parse(jsonPath);
Expand Down
Loading

0 comments on commit 82c901b

Please sign in to comment.