Skip to content

Commit

Permalink
Catch common mutation-commands
Browse files Browse the repository at this point in the history
  • Loading branch information
tsmacdonald committed Apr 8, 2024
1 parent 6a6ca98 commit 4a66e55
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 32 deletions.
57 changes: 36 additions & 21 deletions java/com/metabase/macaw/AstWalker.java
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
import net.sf.jsqlparser.statement.ShowColumnsStatement;
import net.sf.jsqlparser.statement.ShowStatement;
import net.sf.jsqlparser.statement.StatementVisitor;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.Statements;
import net.sf.jsqlparser.statement.UnsupportedStatement;
import net.sf.jsqlparser.statement.UseStatement;
Expand Down Expand Up @@ -173,6 +174,7 @@
import static com.metabase.macaw.AstWalker.CallbackKey.ALL_COLUMNS;
import static com.metabase.macaw.AstWalker.CallbackKey.ALL_TABLE_COLUMNS;
import static com.metabase.macaw.AstWalker.CallbackKey.COLUMN;
import static com.metabase.macaw.AstWalker.CallbackKey.MUTATION_COMMAND;
import static com.metabase.macaw.AstWalker.CallbackKey.TABLE;

/**
Expand Down Expand Up @@ -217,6 +219,7 @@ public enum CallbackKey {
ALL_COLUMNS,
ALL_TABLE_COLUMNS,
COLUMN,
MUTATION_COMMAND,
TABLE;

public String toString() {
Expand Down Expand Up @@ -247,16 +250,24 @@ public void invokeCallback(CallbackKey key, Object visitedItem) {
IFn callback = this.callbacks.get(key);
if (callback != null) {
//noinspection unchecked
acc = (Acc) callback.invoke(acc, visitedItem);
this.acc = (Acc) callback.invoke(acc, visitedItem);
}
}

/**
* Fold the given `expression`, using the callbacks to update the accumulator as appropriate.
* Fold the given `expressionOrStatement`, using the callbacks to update the accumulator as appropriate.
*/
public Acc fold(Expression expression) {
expression.accept(this);
return acc;
public Acc fold(Object expressionOrStatement) {
if (expressionOrStatement instanceof Expression) {
((Expression)expressionOrStatement).accept(this);
}
else if (expressionOrStatement instanceof Statement) {
((Statement)expressionOrStatement).accept(this);
}
else {
throw new IllegalArgumentException("`expressionOrStatement` is neither an Expression nor a Statement");
}
return this.acc;
}

/**
Expand Down Expand Up @@ -819,6 +830,7 @@ public void visit(MySQLGroupConcat groupConcat) {

@Override
public void visit(Delete delete) {
invokeCallback(MUTATION_COMMAND, "delete");
visit(delete.getTable());

if (delete.getUsingList() != null) {
Expand All @@ -836,6 +848,7 @@ public void visit(Delete delete) {

@Override
public void visit(Update update) {
invokeCallback(MUTATION_COMMAND, "update");
visit(update.getTable());
if (update.getWithItemsList() != null) {
for (WithItem withItem : update.getWithItemsList()) {
Expand Down Expand Up @@ -874,6 +887,7 @@ public void visit(Update update) {

@Override
public void visit(Insert insert) {
invokeCallback(MUTATION_COMMAND, "insert");
visit(insert.getTable());
if (insert.getWithItemsList() != null) {
for (WithItem withItem : insert.getWithItemsList()) {
Expand All @@ -891,26 +905,29 @@ public void visit(Analyze analyze) {

@Override
public void visit(Drop drop) {
invokeCallback(MUTATION_COMMAND, "drop");
visit(drop.getName());
}

@Override
public void visit(Truncate truncate) {
invokeCallback(MUTATION_COMMAND, "truncate");
visit(truncate.getTable());
}

@Override
public void visit(CreateIndex createIndex) {
throw new UnsupportedOperationException(NOT_SUPPORTED_YET);
invokeCallback(MUTATION_COMMAND, "create-index");
}

@Override
public void visit(CreateSchema aThis) {
throw new UnsupportedOperationException(NOT_SUPPORTED_YET);
invokeCallback(MUTATION_COMMAND, "create-schema");
}

@Override
public void visit(CreateTable create) {
invokeCallback(MUTATION_COMMAND, "create-table");
visit(create.getTable());
if (create.getSelect() != null) {
create.getSelect().accept((SelectVisitor) this);
Expand All @@ -919,12 +936,12 @@ public void visit(CreateTable create) {

@Override
public void visit(CreateView createView) {
throw new UnsupportedOperationException(NOT_SUPPORTED_YET);
invokeCallback(MUTATION_COMMAND, "create-view");
}

@Override
public void visit(Alter alter) {
throw new UnsupportedOperationException(NOT_SUPPORTED_YET);
invokeCallback(MUTATION_COMMAND, "alter-table");
}

@Override
Expand Down Expand Up @@ -1000,7 +1017,7 @@ public void visit(TableFunction tableFunction) {

@Override
public void visit(AlterView alterView) {
throw new UnsupportedOperationException(NOT_SUPPORTED_YET);
invokeCallback(MUTATION_COMMAND, "alter-view");
}

@Override
Expand Down Expand Up @@ -1136,8 +1153,7 @@ public void visit(DeclareStatement aThis) {

@Override
public void visit(Grant grant) {


invokeCallback(MUTATION_COMMAND, "grant");
}

@Override
Expand All @@ -1163,20 +1179,17 @@ public void visit(ArrayConstructor array) {

@Override
public void visit(CreateSequence createSequence) {
throw new UnsupportedOperationException(
"Reading from a CreateSequence is not supported");
invokeCallback(MUTATION_COMMAND, "create-sequence");
}

@Override
public void visit(AlterSequence alterSequence) {
throw new UnsupportedOperationException(
"Reading from an AlterSequence is not supported");
invokeCallback(MUTATION_COMMAND, "alter-sequence");
}

@Override
public void visit(CreateFunctionalStatement createFunctionalStatement) {
throw new UnsupportedOperationException(
"Reading from a CreateFunctionalStatement is not supported");
invokeCallback(MUTATION_COMMAND, "create-function");
}

@Override
Expand Down Expand Up @@ -1208,8 +1221,7 @@ public void visit(XMLSerializeExpr aThis) {

@Override
public void visit(CreateSynonym createSynonym) {
throw new UnsupportedOperationException(
"Reading from a CreateSynonym is not supported");
invokeCallback(MUTATION_COMMAND, "create-synonym");
}

@Override
Expand All @@ -1227,7 +1239,7 @@ public void visit(RollbackStatement rollbackStatement) {

@Override
public void visit(AlterSession alterSession) {

invokeCallback(MUTATION_COMMAND, "alter-session");
}

@Override
Expand Down Expand Up @@ -1268,6 +1280,7 @@ public void visit(OracleNamedFunctionParameter oracleNamedFunctionParameter) {

@Override
public void visit(RenameTableStatement renameTableStatement) {
invokeCallback(MUTATION_COMMAND, "rename-table");
for (Map.Entry<Table, Table> e : renameTableStatement.getTableNames()) {
e.getKey().accept(this);
e.getValue().accept(this);
Expand All @@ -1276,13 +1289,15 @@ public void visit(RenameTableStatement renameTableStatement) {

@Override
public void visit(PurgeStatement purgeStatement) {
invokeCallback(MUTATION_COMMAND, "purge");
if (purgeStatement.getPurgeObjectType() == PurgeObjectType.TABLE) {
((Table) purgeStatement.getObject()).accept(this);
}
}

@Override
public void visit(AlterSystemStatement alterSystemStatement) {
invokeCallback(MUTATION_COMMAND, "alter-system");
}

@Override
Expand Down
24 changes: 14 additions & 10 deletions src/macaw/core.clj
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,16 @@
[^Statement parsed-query]
(mw/fold-query parsed-query
{:column (conj-to :columns)
:mutation (conj-to :mutation-commands)
:wildcard (fn [results _all-columns]
(assoc results :has-wildcard? true))
:table (conj-to :tables)
:table-wildcard (conj-to :table-wildcards)}
{:columns #{}
:has-wildcard? false
:tables #{}
:table-wildcards #{}}))
{:columns #{}
:has-wildcard? false
:mutation-commands #{}
:tables #{}
:table-wildcards #{}}))

(defn- alias-mapping
[^Table table]
Expand All @@ -49,18 +51,20 @@
(filter (complement alias?) table-names)))

(defn query->components
"Given a parsed query (i.e., a [subclass of] `Statement`) return a map with the `:tables` and `:columns` found within it.
"Given a parsed query (i.e., a [subclass of] `Statement`) return a map with the elements found within it.
(Specifically, it returns their fully-qualified names as strings, where 'fully-qualified' means 'as referred to in
the query'; this function doesn't do additional inference work to find out a table's schema.)"
[^Statement parsed-query]
(let [{:keys [columns has-wildcard?
mutation-commands
tables table-wildcards]} (query->raw-components parsed-query)
aliases (into {} (map alias-mapping tables))]
{:columns (into #{} (map #(.getColumnName ^Column %) columns))
:has-wildcard? has-wildcard?
:tables (into #{} (remove-aliases aliases (map #(.getName ^Table %) tables)))
:table-wildcards (into #{} (map (partial resolve-table-name aliases) table-wildcards))}))
aliases (into {} (map alias-mapping tables))]
{:columns (into #{} (map #(.getColumnName ^Column %) columns))
:has-wildcard? has-wildcard?
:mutation-commands (into #{} mutation-commands)
:tables (into #{} (remove-aliases aliases (map #(.getName ^Table %) tables)))
:table-wildcards (into #{} (map (partial resolve-table-name aliases) table-wildcards))}))

(defn parsed-query
"Main entry point: takes a string query and returns a `Statement` object that can be handled by the other functions."
Expand Down
1 change: 1 addition & 0 deletions src/macaw/walk.clj
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"keyword->key map for the AST-folding callbacks."
;; TODO: Move this to a Malli schema to simplify the indirection
{:column AstWalker$CallbackKey/COLUMN
:mutation AstWalker$CallbackKey/MUTATION_COMMAND
:table AstWalker$CallbackKey/TABLE
:table-wildcard AstWalker$CallbackKey/ALL_TABLE_COLUMNS
:wildcard AstWalker$CallbackKey/ALL_COLUMNS})
Expand Down
50 changes: 49 additions & 1 deletion test/macaw/core_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
(def components (comp m/query->components m/parsed-query))
(def columns (comp :columns components))
(def has-wildcard? (comp :has-wildcard? components))
(def mutations (comp :mutation-commands components))
(def tables (comp :tables components))
(def table-wcs (comp :table-wildcards components))

Expand All @@ -23,7 +24,7 @@
(is (= #{"core_user"}
(tables "select * from (select distinct email from core_user) q;")))))

(deftest issue-14-tables-with-complex-aliases-test
(deftest tables-with-complex-aliases-issue-14-test
(testing "With an alias that is also a table name"
#_(is (= #{"user" "user2_final"}
(tables
Expand All @@ -41,6 +42,53 @@
(is (= #{"id" "user_id"}
(columns "select id from orders group by user_id")))))

(deftest mutation-test
(is (= #{"alter-sequence"}
(mutations "alter sequence serial restart with 42")))
(is (= #{"alter-session"}
(mutations "alter session set foo = 'bar'")))
(is (= #{"alter-system"}
(mutations "alter system reset all")))
(is (= #{"alter-table"}
(mutations "alter table orders add column email text")))
(is (= #{"alter-view"}
(mutations "alter view foo as select bar;")))
(is (= #{"create-function"} ; Postgres syntax
(mutations "create function multiply(integer, integer) returns integer as 'select $1 * $2;' language sql
immutable returns null on null input;")))
(is (= #{"create-function"} ; Conventional syntax
(mutations "create function multiply(a integer, b integer) returns integer language sql immutable returns
null on null input return a + b;")))
(is (= #{"create-index"}
(mutations "create index idx_user_id on orders(user_id);")))
(is (= #{"create-schema"}
(mutations "create schema perthshire")))
(is (= #{"create-sequence"}
(mutations "create sequence users_seq start with 42 increment by 17")))
(is (= #{"create-synonym"}
(mutations "create synonym folk for people")))
(is (= #{"create-table"}
(mutations "create table poets (name text, id integer)")))
(is (= #{"create-view"}
(mutations "create view folk as select * from people where id > 10")))
(is (= #{"delete"}
(mutations "delete from people")))
(is (= #{"drop"}
(mutations "drop table people")))
(is (= #{"grant"}
(mutations "grant select, update, insert on people to myself")))
(is (= #{"insert"}
(mutations "insert into people(name,source) values ('Robert Fergusson', 'Twitter'), ('Robert Burns',
'Facebook')")))
(is (= #{"purge"}
(mutations "purge table people")))
(is (= #{"rename-table"}
(mutations "rename table people to folk")))
(is (= #{"truncate"}
(mutations "truncate table people")))
(is (= #{"update"}
(mutations "update people set name = 'Robert Fergusson' where id = 23"))))

(deftest alias-inclusion-test
(testing "Aliases are not included"
(is (= #{"orders" "foo"}
Expand Down

0 comments on commit 4a66e55

Please sign in to comment.