Skip to content

Commit

Permalink
Tokenize the template string once to avoid quadratic runtime for craf…
Browse files Browse the repository at this point in the history
…ted sql template
  • Loading branch information
fluentfuture committed Nov 7, 2024
1 parent d10c886 commit b264e51
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 93 deletions.
52 changes: 23 additions & 29 deletions mug-guava/src/main/java/com/google/mu/safesql/SafeSql.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,13 @@
import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.logging.Logger;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.errorprone.annotations.CompileTimeConstant;
import com.google.errorprone.annotations.MustBeClosed;
import com.google.mu.annotations.TemplateFormatMethod;
Expand Down Expand Up @@ -254,15 +252,15 @@ public final class SafeSql {

/** An empty SQL */
public static final SafeSql EMPTY = new SafeSql("");
private static final Substring.RepeatingPattern TOKENS =
Stream.of(word(), first(breakingWhitespace().negate()::matches))
.collect(firstOccurrence())
.repeatedly();
private static final SafeSql FALSE = new SafeSql("(1 = 0)");
private static final SafeSql TRUE = new SafeSql("(1 = 1)");
private static final StringFormat.Template<SafeSql> PARAM = template("{param}");
private static final StringFormat PLACEHOLDER_ELEMENT_NAME =
new StringFormat("{placeholder}[{index}]");
private static final Substring.RepeatingPattern TOKENS =
Stream.of(word(), first(breakingWhitespace().negate()::matches))
.collect(firstOccurrence())
.repeatedly();

private final String sql;
private final List<?> paramValues;
Expand Down Expand Up @@ -389,7 +387,11 @@ public static SafeSql when(
* <p>The returned template is immutable and thread safe.
*/
public static Template<SafeSql> template(@CompileTimeConstant String template) {
ConcurrentMap<String, Boolean> placeholderSurroundings = new ConcurrentHashMap<>();
ImmutableList<Substring.Match> allTokens = TOKENS.match(template).collect(toImmutableList());
ImmutableMap<Integer, Integer> charIndexToTokenIndex =
BiStream.zip(allTokens.stream(), indexesFrom(0))
.mapKeys(Substring.Match::index)
.collect(ImmutableMap::toImmutableMap);
return StringFormat.template(template, (fragments, placeholders) -> {
Deque<String> texts = new ArrayDeque<>(fragments);
Builder builder = new Builder();
Expand Down Expand Up @@ -470,9 +472,19 @@ private boolean appendBeforeQuotedPlaceholder(

private boolean lookaround(
String leftPattern, Substring.Match placeholder, String rightPattern) {
return placeholderSurroundings.computeIfAbsent(
leftPattern + "{" + placeholder.index() + "}" + rightPattern,
k -> matchesPattern(leftPattern, placeholder, rightPattern));
ImmutableList<String> lookbehind = TOKENS.from(leftPattern).collect(toImmutableList());
ImmutableList<String> lookahead = TOKENS.from(rightPattern).collect(toImmutableList());
ImmutableList<Substring.Match> leftTokens = allTokens.subList(
0, charIndexToTokenIndex.get(placeholder.index()));
ImmutableList<Substring.Match> rightTokens = allTokens.subList(
charIndexToTokenIndex.get(placeholder.index() + placeholder.length() - 1) + 1,
allTokens.size());
return BiStream.zip(lookbehind.reverse(), leftTokens.reverse()) // right-to-left
.filter((s, t) -> s.equalsIgnoreCase(t.toString()))
.count() == lookbehind.size()
&& BiStream.zip(lookahead, rightTokens)
.filter((s, t) -> s.equalsIgnoreCase(t.toString()))
.count() == lookahead.size();
}
}
placeholders.forEachOrdered(new SqlWriter()::writePlaceholder);
Expand Down Expand Up @@ -763,24 +775,6 @@ private static SafeSql subqueryOrParameter(CharSequence name, Object param) {
.mapKeys(index -> PLACEHOLDER_ELEMENT_NAME.format(placeholder, index));
}

@VisibleForTesting
static boolean matchesPattern(String left, Substring.Match placeholder, String right) {
ImmutableList<String> leftTokensToMatch = TOKENS.from(left).collect(toImmutableList());
ImmutableList<String> rightTokensToMatch = TOKENS.from(right).collect(toImmutableList());
// Matches right side first because we can lazily scan the right side without copying
return BiStream.zip(
rightTokensToMatch.stream(),
TOKENS.match(placeholder.fullString(), placeholder.index() + placeholder.length())
.map(Object::toString))
.filter(String::equalsIgnoreCase)
.count() == rightTokensToMatch.size()
&& BiStream.zip(
leftTokensToMatch.reverse(),
TOKENS.from(placeholder.before()).collect(toImmutableList()).reverse())
.filter(String::equalsIgnoreCase)
.count() == leftTokensToMatch.size();
}

private static String escapePercent(String s) {
return first(c -> c == '\\' || c == '%').repeatedly().replaceAllFrom(s, c -> "\\" + c);
}
Expand Down
64 changes: 0 additions & 64 deletions mug-guava/src/test/java/com/google/mu/safesql/SafeSqlTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import com.google.common.collect.ImmutableList;
import com.google.common.testing.EqualsTester;
import com.google.common.testing.NullPointerTester;
import com.google.mu.util.Substring;
import com.google.testing.junit.testparameterinjector.TestParameterInjector;

@RunWith(TestParameterInjector.class)
Expand Down Expand Up @@ -841,69 +840,6 @@ public void nonNegative_minValue_throws() {
assertThat(thrown).hasMessageThat().contains("negative number disallowed");
}

@Test
public void matchesPattern_matchesWithoutWhitespaces() {
Substring.Match placeholder = Substring.spanningInOrder("{", "}").in("IN({ids})").get();
assertThat(SafeSql.matchesPattern("IN(", placeholder, ")")).isTrue();
assertThat(SafeSql.matchesPattern("IN (", placeholder, ")")).isTrue();
}

@Test
public void matchesPattern_matchesWithWhitespaces() {
Substring.Match placeholder = Substring.spanningInOrder("{", "}").in("id IN (\n {ids} \n)").get();
assertThat(SafeSql.matchesPattern("IN(", placeholder, ")")).isTrue();
}

@Test
public void matchesPattern_matchesIgnoreCase() {
Substring.Match placeholder =
Substring.spanningInOrder("{", "}").in("id not in (\n {ids} ) or {...}").get();
assertThat(SafeSql.matchesPattern("IN(", placeholder, ")")).isTrue();
assertThat(SafeSql.matchesPattern("In (", placeholder, " ) ")).isTrue();
}

@Test
public void matchesPattern_leftDoesNotMatch() {
Substring.Match placeholder =
Substring.spanningInOrder("{", "}").in("min({ids})").get();
assertThat(SafeSql.matchesPattern("in(", placeholder, ")")).isFalse();
}

@Test
public void matchesPattern_rightDoesNotMatch() {
Substring.Match placeholder =
Substring.spanningInOrder("{", "}").in("in({ids},)").get();
assertThat(SafeSql.matchesPattern("in(", placeholder, ")")).isFalse();
}

@Test
public void matchesPattern_leftSideNone() {
Substring.Match placeholder =
Substring.spanningInOrder("{", "}").in("{ids}) or true").get();
assertThat(SafeSql.matchesPattern("in(", placeholder, ")")).isFalse();
}

@Test
public void matchesPattern_rightSideNone() {
Substring.Match placeholder =
Substring.spanningInOrder("{", "}").in("in ({ids}").get();
assertThat(SafeSql.matchesPattern("in(", placeholder, ")")).isFalse();
}

@Test
public void matchesPattern_neitherSide() {
Substring.Match placeholder =
Substring.spanningInOrder("{", "}").in("{ids}").get();
assertThat(SafeSql.matchesPattern("in(", placeholder, ")")).isFalse();
}

@Test
public void matchesPattern_matchesWithQuotes() {
Substring.Match placeholder =
Substring.spanningInOrder("{", "}").in("in ('{ids}')").get();
assertThat(SafeSql.matchesPattern("in('", placeholder, "')")).isTrue();
}

@Test
public void testEquals() {
new EqualsTester()
Expand Down

0 comments on commit b264e51

Please sign in to comment.