From b264e514e257ea1348e3ebf46fad799b2f338f6f Mon Sep 17 00:00:00 2001 From: Ben Yu Date: Wed, 6 Nov 2024 22:39:55 -0800 Subject: [PATCH] Tokenize the template string once to avoid quadratic runtime for crafted sql template --- .../java/com/google/mu/safesql/SafeSql.java | 52 +++++++-------- .../com/google/mu/safesql/SafeSqlTest.java | 64 ------------------- 2 files changed, 23 insertions(+), 93 deletions(-) diff --git a/mug-guava/src/main/java/com/google/mu/safesql/SafeSql.java b/mug-guava/src/main/java/com/google/mu/safesql/SafeSql.java index ba8fbf5f71..73826e4947 100644 --- a/mug-guava/src/main/java/com/google/mu/safesql/SafeSql.java +++ b/mug-guava/src/main/java/com/google/mu/safesql/SafeSql.java @@ -45,15 +45,13 @@ import java.util.Iterator; import java.util.List; import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import java.util.logging.Logger; import java.util.stream.Collector; import java.util.stream.Collectors; import java.util.stream.Stream; -import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.errorprone.annotations.CompileTimeConstant; import com.google.errorprone.annotations.MustBeClosed; import com.google.mu.annotations.TemplateFormatMethod; @@ -254,15 +252,15 @@ public final class SafeSql { /** An empty SQL */ public static final SafeSql EMPTY = new SafeSql(""); + private static final Substring.RepeatingPattern TOKENS = + Stream.of(word(), first(breakingWhitespace().negate()::matches)) + .collect(firstOccurrence()) + .repeatedly(); private static final SafeSql FALSE = new SafeSql("(1 = 0)"); private static final SafeSql TRUE = new SafeSql("(1 = 1)"); private static final StringFormat.Template PARAM = template("{param}"); private static final StringFormat PLACEHOLDER_ELEMENT_NAME = new StringFormat("{placeholder}[{index}]"); - private static final Substring.RepeatingPattern TOKENS = - Stream.of(word(), first(breakingWhitespace().negate()::matches)) - .collect(firstOccurrence()) - .repeatedly(); private final String sql; private final List paramValues; @@ -389,7 +387,11 @@ public static SafeSql when( *

The returned template is immutable and thread safe. */ public static Template template(@CompileTimeConstant String template) { - ConcurrentMap placeholderSurroundings = new ConcurrentHashMap<>(); + ImmutableList allTokens = TOKENS.match(template).collect(toImmutableList()); + ImmutableMap charIndexToTokenIndex = + BiStream.zip(allTokens.stream(), indexesFrom(0)) + .mapKeys(Substring.Match::index) + .collect(ImmutableMap::toImmutableMap); return StringFormat.template(template, (fragments, placeholders) -> { Deque texts = new ArrayDeque<>(fragments); Builder builder = new Builder(); @@ -470,9 +472,19 @@ private boolean appendBeforeQuotedPlaceholder( private boolean lookaround( String leftPattern, Substring.Match placeholder, String rightPattern) { - return placeholderSurroundings.computeIfAbsent( - leftPattern + "{" + placeholder.index() + "}" + rightPattern, - k -> matchesPattern(leftPattern, placeholder, rightPattern)); + ImmutableList lookbehind = TOKENS.from(leftPattern).collect(toImmutableList()); + ImmutableList lookahead = TOKENS.from(rightPattern).collect(toImmutableList()); + ImmutableList leftTokens = allTokens.subList( + 0, charIndexToTokenIndex.get(placeholder.index())); + ImmutableList rightTokens = allTokens.subList( + charIndexToTokenIndex.get(placeholder.index() + placeholder.length() - 1) + 1, + allTokens.size()); + return BiStream.zip(lookbehind.reverse(), leftTokens.reverse()) // right-to-left + .filter((s, t) -> s.equalsIgnoreCase(t.toString())) + .count() == lookbehind.size() + && BiStream.zip(lookahead, rightTokens) + .filter((s, t) -> s.equalsIgnoreCase(t.toString())) + .count() == lookahead.size(); } } placeholders.forEachOrdered(new SqlWriter()::writePlaceholder); @@ -763,24 +775,6 @@ private static SafeSql subqueryOrParameter(CharSequence name, Object param) { .mapKeys(index -> PLACEHOLDER_ELEMENT_NAME.format(placeholder, index)); } - @VisibleForTesting - static boolean matchesPattern(String left, Substring.Match placeholder, String right) { - ImmutableList leftTokensToMatch = TOKENS.from(left).collect(toImmutableList()); - ImmutableList rightTokensToMatch = TOKENS.from(right).collect(toImmutableList()); - // Matches right side first because we can lazily scan the right side without copying - return BiStream.zip( - rightTokensToMatch.stream(), - TOKENS.match(placeholder.fullString(), placeholder.index() + placeholder.length()) - .map(Object::toString)) - .filter(String::equalsIgnoreCase) - .count() == rightTokensToMatch.size() - && BiStream.zip( - leftTokensToMatch.reverse(), - TOKENS.from(placeholder.before()).collect(toImmutableList()).reverse()) - .filter(String::equalsIgnoreCase) - .count() == leftTokensToMatch.size(); - } - private static String escapePercent(String s) { return first(c -> c == '\\' || c == '%').repeatedly().replaceAllFrom(s, c -> "\\" + c); } diff --git a/mug-guava/src/test/java/com/google/mu/safesql/SafeSqlTest.java b/mug-guava/src/test/java/com/google/mu/safesql/SafeSqlTest.java index 685ce09f1a..f5805d0f48 100644 --- a/mug-guava/src/test/java/com/google/mu/safesql/SafeSqlTest.java +++ b/mug-guava/src/test/java/com/google/mu/safesql/SafeSqlTest.java @@ -15,7 +15,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.testing.EqualsTester; import com.google.common.testing.NullPointerTester; -import com.google.mu.util.Substring; import com.google.testing.junit.testparameterinjector.TestParameterInjector; @RunWith(TestParameterInjector.class) @@ -841,69 +840,6 @@ public void nonNegative_minValue_throws() { assertThat(thrown).hasMessageThat().contains("negative number disallowed"); } - @Test - public void matchesPattern_matchesWithoutWhitespaces() { - Substring.Match placeholder = Substring.spanningInOrder("{", "}").in("IN({ids})").get(); - assertThat(SafeSql.matchesPattern("IN(", placeholder, ")")).isTrue(); - assertThat(SafeSql.matchesPattern("IN (", placeholder, ")")).isTrue(); - } - - @Test - public void matchesPattern_matchesWithWhitespaces() { - Substring.Match placeholder = Substring.spanningInOrder("{", "}").in("id IN (\n {ids} \n)").get(); - assertThat(SafeSql.matchesPattern("IN(", placeholder, ")")).isTrue(); - } - - @Test - public void matchesPattern_matchesIgnoreCase() { - Substring.Match placeholder = - Substring.spanningInOrder("{", "}").in("id not in (\n {ids} ) or {...}").get(); - assertThat(SafeSql.matchesPattern("IN(", placeholder, ")")).isTrue(); - assertThat(SafeSql.matchesPattern("In (", placeholder, " ) ")).isTrue(); - } - - @Test - public void matchesPattern_leftDoesNotMatch() { - Substring.Match placeholder = - Substring.spanningInOrder("{", "}").in("min({ids})").get(); - assertThat(SafeSql.matchesPattern("in(", placeholder, ")")).isFalse(); - } - - @Test - public void matchesPattern_rightDoesNotMatch() { - Substring.Match placeholder = - Substring.spanningInOrder("{", "}").in("in({ids},)").get(); - assertThat(SafeSql.matchesPattern("in(", placeholder, ")")).isFalse(); - } - - @Test - public void matchesPattern_leftSideNone() { - Substring.Match placeholder = - Substring.spanningInOrder("{", "}").in("{ids}) or true").get(); - assertThat(SafeSql.matchesPattern("in(", placeholder, ")")).isFalse(); - } - - @Test - public void matchesPattern_rightSideNone() { - Substring.Match placeholder = - Substring.spanningInOrder("{", "}").in("in ({ids}").get(); - assertThat(SafeSql.matchesPattern("in(", placeholder, ")")).isFalse(); - } - - @Test - public void matchesPattern_neitherSide() { - Substring.Match placeholder = - Substring.spanningInOrder("{", "}").in("{ids}").get(); - assertThat(SafeSql.matchesPattern("in(", placeholder, ")")).isFalse(); - } - - @Test - public void matchesPattern_matchesWithQuotes() { - Substring.Match placeholder = - Substring.spanningInOrder("{", "}").in("in ('{ids}')").get(); - assertThat(SafeSql.matchesPattern("in('", placeholder, "')")).isTrue(); - } - @Test public void testEquals() { new EqualsTester()