Skip to content

Commit

Permalink
Add query context
Browse files Browse the repository at this point in the history
  • Loading branch information
tsmacdonald committed Apr 23, 2024
1 parent 05a21d5 commit ceb7532
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 33 deletions.
81 changes: 74 additions & 7 deletions java/com/metabase/macaw/AstWalker.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import clojure.lang.IFn;
import clojure.lang.Keyword;

import java.util.ArrayDeque;
import java.util.Deque;
import java.util.EnumMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -171,12 +173,9 @@
import net.sf.jsqlparser.statement.update.Update;
import net.sf.jsqlparser.statement.upsert.Upsert;

import static com.metabase.macaw.AstWalker.CallbackKey.ALL_COLUMNS;
import static com.metabase.macaw.AstWalker.CallbackKey.ALL_TABLE_COLUMNS;
import static com.metabase.macaw.AstWalker.CallbackKey.COLUMN;
import static com.metabase.macaw.AstWalker.CallbackKey.MUTATION_COMMAND;
import static com.metabase.macaw.AstWalker.CallbackKey.TABLE;
import static com.metabase.macaw.AstWalker.CallbackKey.TABLE_ALIAS;
import static com.metabase.macaw.AstWalker.CallbackKey.*;

import static com.metabase.macaw.AstWalker.QueryContext.*;

/**
* Walks the AST, using JSqlParser's `visit()` methods. Each `visit()` method additionally calls an applicable callback
Expand Down Expand Up @@ -229,10 +228,30 @@ public String toString() {
}
}

public enum QueryContext {
DELETE,
ELSE,
FROM,
GROUP_BY,
HAVING,
IF,
INSERT,
JOIN,
SELECT,
SUB_SELECT,
UPDATE,
WHERE;

public String toString() {
return name().toUpperCase();
}
}

private static final String NOT_SUPPORTED_YET = "Not supported yet.";

private Acc acc;
private final EnumMap<CallbackKey, IFn> callbacks;
private final Deque<String> contextStack;

/**
* Construct a new walker with the given `callbacks`. The `callbacks` should be a (Clojure) map of CallbackKeys to
Expand All @@ -243,6 +262,7 @@ public String toString() {
public AstWalker(Map<CallbackKey, IFn> rawCallbacks, Acc val) {
this.acc = val;
this.callbacks = new EnumMap<>(rawCallbacks);
this.contextStack = new ArrayDeque<String>();
}

/**
Expand All @@ -252,10 +272,19 @@ public void invokeCallback(CallbackKey key, Object visitedItem) {
IFn callback = this.callbacks.get(key);
if (callback != null) {
//noinspection unchecked
this.acc = (Acc) callback.invoke(acc, visitedItem);
this.acc = (Acc) callback.invoke(acc, visitedItem, this.contextStack.toArray());
}
}

private void pushContext(QueryContext c) {
this.contextStack.push(c.toString());
}

// This is pure sugar, but it's nice to be symmetrical with pushContext
private void popContext() {
this.contextStack.pop();
}

/**
* Fold the given `statement`, using the callbacks to update the accumulator as appropriate.
*/
Expand All @@ -274,6 +303,7 @@ public Expression walk(Expression expression) {

@Override
public void visit(Select select) {
// No pushContext(SELECT) since it's handled by the ParenthesedSelect and PlainSelect methods
List<WithItem> withItemsList = select.getWithItemsList();
if (withItemsList != null && !withItemsList.isEmpty()) {
for (WithItem withItem : withItemsList) {
Expand All @@ -294,7 +324,9 @@ public void visit(TrimFunction trimFunction) {
trimFunction.getExpression().accept(this);
}
if (trimFunction.getFromExpression() != null) {
pushContext(FROM);
trimFunction.getFromExpression().accept(this);
popContext(); // FROM
}
}

Expand All @@ -311,17 +343,20 @@ public void visit(WithItem withItem) {

@Override
public void visit(ParenthesedSelect selectBody) {
pushContext(SUB_SELECT);
List<WithItem> withItemsList = selectBody.getWithItemsList();
if (withItemsList != null && !withItemsList.isEmpty()) {
for (WithItem withItem : withItemsList) {
withItem.accept((SelectVisitor) this);
}
}
selectBody.getSelect().accept((SelectVisitor) this);
popContext(); // SUB_SELECT
}

@Override
public void visit(PlainSelect plainSelect) {
pushContext(SELECT);
List<WithItem> withItemsList = plainSelect.getWithItemsList();
if (withItemsList != null && !withItemsList.isEmpty()) {
for (WithItem withItem : withItemsList) {
Expand All @@ -335,25 +370,33 @@ public void visit(PlainSelect plainSelect) {
}

if (plainSelect.getFromItem() != null) {
pushContext(FROM);
plainSelect.getFromItem().accept(this);
popContext(); // FROM
}

visitJoins(plainSelect.getJoins());
if (plainSelect.getWhere() != null) {
pushContext(WHERE);
plainSelect.getWhere().accept(this);
popContext(); // WHERE
}

if (plainSelect.getHaving() != null) {
pushContext(HAVING);
plainSelect.getHaving().accept(this);
popContext(); // HAVING
}

if (plainSelect.getOracleHierarchical() != null) {
plainSelect.getOracleHierarchical().accept(this);
}

if (plainSelect.getGroupBy() != null) {
// contextStack handled in visit()
plainSelect.getGroupBy().accept(this);
}
popContext(); // SELECT
}

@Override
Expand Down Expand Up @@ -818,6 +861,7 @@ public void visit(MySQLGroupConcat groupConcat) {

@Override
public void visit(Delete delete) {
pushContext(DELETE);
invokeCallback(MUTATION_COMMAND, "delete");
visit(delete.getTable());

Expand All @@ -830,12 +874,16 @@ public void visit(Delete delete) {
visitJoins(delete.getJoins());

if (delete.getWhere() != null) {
pushContext(WHERE);
delete.getWhere().accept(this);
popContext(); // WHERE
}
popContext(); // DELETE
}

@Override
public void visit(Update update) {
pushContext(UPDATE);
invokeCallback(MUTATION_COMMAND, "update");
visit(update.getTable());
if (update.getWithItemsList() != null) {
Expand All @@ -856,7 +904,9 @@ public void visit(Update update) {
}

if (update.getFromItem() != null) {
pushContext(FROM);
update.getFromItem().accept(this);
popContext(); // FROM
}

if (update.getJoins() != null) {
Expand All @@ -869,12 +919,16 @@ public void visit(Update update) {
}

if (update.getWhere() != null) {
pushContext(WHERE);
update.getWhere().accept(this);
popContext(); // WHERE
}
popContext(); // UPDATE
}

@Override
public void visit(Insert insert) {
pushContext(INSERT);
invokeCallback(MUTATION_COMMAND, "insert");
visit(insert.getTable());
if (insert.getWithItemsList() != null) {
Expand All @@ -885,6 +939,7 @@ public void visit(Insert insert) {
if (insert.getSelect() != null) {
visit(insert.getSelect());
}
popContext(); // INSERT
}

public void visit(Analyze analyze) {
Expand Down Expand Up @@ -989,7 +1044,9 @@ public void visit(Merge merge) {
}

if (merge.getFromItem() != null) {
pushContext(FROM);
merge.getFromItem().accept(this);
popContext(); // FROM
}
}

Expand Down Expand Up @@ -1053,10 +1110,12 @@ public void visit(ParenthesedFromItem parenthesis) {

@Override
public void visit(GroupByElement element) {
pushContext(GROUP_BY);
element.getGroupByExpressionList().accept(this);
for (ExpressionList exprList : element.getGroupingSets()) {
exprList.accept(this);
}
popContext(); // GROUP_BY
}

/**
Expand All @@ -1065,16 +1124,20 @@ public void visit(GroupByElement element) {
* @param parenthesis join sql block
*/
private void visitJoins(List<Join> parenthesis) {
pushContext(JOIN);
if (parenthesis == null) {
return;
}
for (Join join : parenthesis) {
pushContext(FROM);
join.getFromItem().accept(this);
popContext(); // FROM
join.getRightItem().accept(this);
for (Expression expression : join.getOnExpressions()) {
expression.accept(this);
}
}
popContext(); // JOIN
}

@Override
Expand Down Expand Up @@ -1256,9 +1319,13 @@ public void visit(ConnectByRootOperator connectByRootOperator) {
}

public void visit(IfElseStatement ifElseStatement) {
pushContext(IF);
ifElseStatement.getIfStatement().accept(this);
popContext(); // IF
if (ifElseStatement.getElseStatement() != null) {
pushContext(ELSE);
ifElseStatement.getElseStatement().accept(this);
popContext(); // ELSE
}
}

Expand Down
28 changes: 17 additions & 11 deletions src/macaw/core.clj
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,23 @@
(set! *warn-on-reflection* true)

(defn- conj-to
[key-name]
(fn item-conjer [results item]
(update results key-name conj item)))
([key-name]
(conj-to key-name identity))
([key-name xf]
(fn item-conjer [results component context]
(update results key-name conj {:component (xf component)
:context (vec context)}))))

(defn- query->raw-components
[^Statement parsed-query]
(mw/fold-query parsed-query
{:column (conj-to :columns)
:mutation (conj-to :mutation-commands)
:wildcard (fn [results _all-columns]
(assoc results :has-wildcard? true))
:wildcard (conj-to :has-wildcard? (constantly true))
:table (conj-to :tables)
:table-wildcard (conj-to :table-wildcards)}
{:columns #{}
:has-wildcard? false
:has-wildcard? #{}
:mutation-commands #{}
:tables #{}
:table-wildcards #{}}))
Expand All @@ -45,6 +47,10 @@
(or (alias->name table-name)
table-name)))

(defn- update-components
[f components]
(map #(update % :component f) components))

(defn query->components
"Given a parsed query (i.e., a [subclass of] `Statement`) return a map with the elements found within it.
Expand All @@ -54,12 +60,12 @@
(let [{:keys [columns has-wildcard?
mutation-commands
tables table-wildcards]} (query->raw-components parsed-query)
aliases (into {} (map alias-mapping tables))]
{:columns (into #{} (map #(.getColumnName ^Column %) columns))
:has-wildcard? has-wildcard?
aliases (into {} (map (comp alias-mapping :component) tables))]
{:columns (into #{} (update-components #(.getColumnName ^Column %) columns))
:has-wildcard? (into #{} has-wildcard?)
:mutation-commands (into #{} mutation-commands)
:tables (into #{} (map #(.getName ^Table %) tables))
:table-wildcards (into #{} (map (partial resolve-table-name aliases) table-wildcards))}))
:tables (into #{} (update-components #(.getName ^Table %) tables))
:table-wildcards (into #{} (update-components (partial resolve-table-name aliases) table-wildcards))}))

(defn parsed-query
"Main entry point: takes a string query and returns a `Statement` object that can be handled by the other functions."
Expand Down
8 changes: 4 additions & 4 deletions src/macaw/rewrite.clj
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"Emit a SQL string for an updated AST, preserving the comments and whitespace from the original SQL."
[updated-ast sql]
(let [replace-name (fn [->s]
(fn [acc ^ASTNodeAccess visitable]
(fn [acc ^ASTNodeAccess visitable _ctx]
(let [node (.getASTNode visitable)]
;; not sure why sometimes we get a phantom visitable without an underlying node
(if (nil? node)
Expand All @@ -67,7 +67,7 @@
[]))))

(defn- rename-table
[table-renames ^Table table]
[table-renames ^Table table _ctx]
(when-let [name' (get table-renames (.getName table))]
(.setName table name')))

Expand All @@ -78,6 +78,6 @@
(mw/walk-query
{:table (partial rename-table table-renames)
:table-alias (partial rename-table table-renames)
:column (fn [^Column column] (when-let [name' (get column-renames (.getColumnName column))]
(.setColumnName column name')))})
:column (fn [^Column column _ctx] (when-let [name' (get column-renames (.getColumnName column))]
(.setColumnName column name')))})
(update-query sql)))
12 changes: 7 additions & 5 deletions src/macaw/walk.clj
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,23 @@
(defn- preserve
"Lift a side effecting callback so that it preserves the accumulator."
[f]
(fn [acc v]
(f v)
(fn [acc & args]
(apply f args)
acc))

;; work around ast walker repeatedly visiting the same expressions (bug ?!)
(defn- deduplicate-visits [f]
(let [seen (volatile! #{})]
(fn [acc visitable]
(fn [& [acc visitable & _ :as args]]
(if (contains? @seen visitable)
acc
(do (vswap! seen conj visitable)
(f acc visitable))))))
(apply f args))))))

(defn- update-keys-vals [m key-f val-f]
(into {} (map (fn [[k v]] [(key-f k) (val-f v)])) m))
(into {} (map (fn [[k v]]
[(key-f k) (val-f v)]))
m))

(defn walk-query
"Walk over the query's AST, using the callbacks for their side-effects, for example to mutate the AST itself."
Expand Down
Loading

0 comments on commit ceb7532

Please sign in to comment.