Skip to content

Commit

Permalink
Add support for CTEs. (#37)
Browse files Browse the repository at this point in the history
* Add support for CTEs. Also fix a minor issue with interpreting the columns in an UPDATE statement that uses table aliases.

* Fix test failure
  • Loading branch information
ncabatoff authored May 18, 2024
1 parent 8ff20b7 commit 87efbe7
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 41 deletions.
151 changes: 111 additions & 40 deletions pkg/vet/vet.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ func validateTableColumns(ctx VetContext, tables []TableUsed, cols []ColumnUsed)
}
}
if !found {
if len(tables) == 1 {
if len(usedTables) == 1 {
// to make error message more useful, if only one table is
// referenced in the query, it's safe to assume the user only
// wants to use columns from that table.
Expand Down Expand Up @@ -541,6 +541,11 @@ func getUsedColumnsFromSortClause(sortList []*pg_query.Node) []ColumnUsed {
func validateSelectStmt(ctx VetContext, stmt *pg_query.SelectStmt) (queryParams []QueryParam, targetCols []schema.Column, err error) {
usedCols := []ColumnUsed{}

if stmt.GetWithClause() != nil {
if err := parseCTE(ctx, stmt.GetWithClause()); err != nil {
return nil, nil, err
}
}
postponed := PostponedNodes{}
for _, fromClause := range stmt.FromClause {
re := &ParseResult{}
Expand Down Expand Up @@ -648,11 +653,22 @@ func validateSelectStmt(ctx VetContext, stmt *pg_query.SelectStmt) (queryParams
return queryParams, targetCols, validateTableColumns(ctx, ctx.UsedTables, usedCols)
}

func validateUpdateStmt(ctx VetContext, stmt *pg_query.UpdateStmt) ([]QueryParam, error) {
func validateUpdateStmt(ctx VetContext, stmt *pg_query.UpdateStmt) ([]QueryParam, []ColumnUsed, error) {
if stmt.GetWithClause() != nil {
if err := parseCTE(ctx, stmt.GetWithClause()); err != nil {
return nil, nil, err
}
}
tableName := stmt.Relation.Relname
if err := validateTable(ctx, tableName, true); err != nil {
return nil, err
return nil, nil, err
}

var tableAlias string
if stmt.Relation.Alias != nil {
tableAlias = stmt.Relation.Alias.Aliasname
}

usedTables := []TableUsed{{Name: tableName}}
usedTables = append(usedTables, getUsedTablesFromSelectStmt(stmt.FromClause)...)

Expand Down Expand Up @@ -682,7 +698,7 @@ func validateUpdateStmt(ctx VetContext, stmt *pg_query.UpdateStmt) ([]QueryParam
re := &ParseResult{}
err := parseWhereClause(ctx, stmt.WhereClause, re)
if err != nil {
return nil, err
return nil, nil, err
}
usedCols = append(usedCols, re.Columns...)
AddQueryParams(&queryParams, re.Params)
Expand All @@ -692,13 +708,25 @@ func validateUpdateStmt(ctx VetContext, stmt *pg_query.UpdateStmt) ([]QueryParam
usedCols = append(usedCols, getUsedColumnsFromReturningList(stmt.ReturningList)...)
}

return queryParams, validateTableColumns(ctx, usedTables, usedCols)
if len(usedCols) > 0 {
usedTables = append(usedTables, TableUsed{Name: tableName, Alias: tableAlias})
if err := validateTableColumns(ctx, usedTables, usedCols); err != nil {
return nil, nil, err
}
}

return queryParams, usedCols, nil
}

func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam, error) {
func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam, []ColumnUsed, error) {
if stmt.GetWithClause() != nil {
if err := parseCTE(ctx, stmt.GetWithClause()); err != nil {
return nil, nil, err
}
}
tableName := stmt.Relation.Relname
if err := validateTable(ctx, tableName, true); err != nil {
return nil, err
return nil, nil, err
}
usedTables := []TableUsed{{Name: tableName}}

Expand Down Expand Up @@ -736,7 +764,7 @@ func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam
re := &ParseResult{}
err := parseExpression(ctx, node, re)
if err != nil {
return nil, fmt.Errorf("invalid value list: %w", err)
return nil, nil, fmt.Errorf("invalid value list: %w", err)
}
if len(re.Columns) > 0 {
usedCols = append(usedCols, re.Columns...)
Expand All @@ -758,7 +786,7 @@ func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam
re := &ParseResult{}
err := parseFromClause(ctx, fromClause, re)
if err != nil {
return nil, err
return nil, nil, err
}
if len(re.Columns) > 0 {
usedCols = append(usedCols, re.Columns...)
Expand All @@ -772,7 +800,7 @@ func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam
re := &ParseResult{}
err := parseWhereClause(ctx, selectStmt.WhereClause, re)
if err != nil {
return nil, err
return nil, nil, err
}
if len(re.Columns) > 0 {
usedCols = append(usedCols, re.Columns...)
Expand All @@ -796,12 +824,12 @@ func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam
case target.GetSubLink() != nil:
tv := target.GetSubLink().Subselect
if tv.GetSelectStmt() == nil {
return nil, fmt.Errorf(
return nil, nil, fmt.Errorf(
"unsupported subquery type in value list: %s", reflect.TypeOf(tv))
}
qparams, _, err := validateSelectStmt(ctx, tv.GetSelectStmt())
if err != nil {
return nil, fmt.Errorf("invalid SELECT query in value list: %w", err)
return nil, nil, fmt.Errorf("invalid SELECT query in value list: %w", err)
}
if len(qparams) > 0 {
AddQueryParams(&queryParams, qparams)
Expand All @@ -815,17 +843,22 @@ func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam
}

if err := validateTableColumns(ctx, usedTables, usedCols); err != nil {
return nil, err
return nil, nil, err
}

if err := validateInsertValues(ctx, targetCols, values); err != nil {
return nil, err
return nil, nil, err
}

return queryParams, nil
return queryParams, usedCols, nil
}

func validateDeleteStmt(ctx VetContext, stmt *pg_query.DeleteStmt) ([]QueryParam, error) {
func validateDeleteStmt(ctx VetContext, stmt *pg_query.DeleteStmt) ([]QueryParam, []ColumnUsed, error) {
if stmt.GetWithClause() != nil {
if err := parseCTE(ctx, stmt.GetWithClause()); err != nil {
return nil, nil, err
}
}
tableName := stmt.Relation.Relname
var tableAlias string

Expand All @@ -834,7 +867,7 @@ func validateDeleteStmt(ctx VetContext, stmt *pg_query.DeleteStmt) ([]QueryParam
}

if err := validateTable(ctx, tableName, true); err != nil {
return nil, err
return nil, nil, err
}

usedCols := []ColumnUsed{}
Expand All @@ -845,25 +878,25 @@ func validateDeleteStmt(ctx VetContext, stmt *pg_query.DeleteStmt) ([]QueryParam
re := &ParseResult{}
err := parseWhereClause(ctx, stmt.WhereClause, re)
if err != nil {
return nil, err
return nil, nil, err
}
if len(re.Columns) > 0 {
usedCols = append(usedCols, re.Columns...)
} else {
return nil, fmt.Errorf("no columns in DELETE's WHERE clause")
return nil, nil, fmt.Errorf("no columns in DELETE's WHERE clause")
}
if len(re.Params) > 0 {
queryParams = re.Params
}
} else {
return nil, fmt.Errorf("no WHERE clause for DELETE")
return nil, nil, fmt.Errorf("no WHERE clause for DELETE")
}

for _, using := range stmt.UsingClause {
re := &ParseResult{}
err := parseUsingClause(ctx, using, re)
if err != nil {
return nil, err
return nil, nil, err
}
usedTables = append(usedTables, re.Tables...)
}
Expand All @@ -876,11 +909,39 @@ func validateDeleteStmt(ctx VetContext, stmt *pg_query.DeleteStmt) ([]QueryParam
if len(usedCols) > 0 {
usedTables = append(usedTables, TableUsed{Name: tableName, Alias: tableAlias})
if err := validateTableColumns(ctx, usedTables, usedCols); err != nil {
return nil, err
return nil, nil, err
}
}

return queryParams, nil
return queryParams, usedCols, nil
}

// parseCTE validates each common table expression in the given WITH clause
// and records it in ctx.InnerSchema.Tables under the CTE's name as a
// read-only table.
//
// The recorded table's columns are the columns returned by validateSqlQuery
// for the CTE's query; only the column name is stored for each of them. The
// first validation error stops processing and is returned unchanged.
//
// NOTE(review): entries are written into ctx.InnerSchema.Tables and are not
// removed here — confirm whether that map is shared across validated queries.
func parseCTE(ctx VetContext, with *pg_query.WithClause) error {
	for _, node := range with.Ctes {
		expr := node.GetCommonTableExpr()

		_, usedCols, err := validateSqlQuery(ctx, expr.GetCtequery())
		if err != nil {
			return err
		}

		table := schema.Table{
			Name:     expr.Ctename,
			ReadOnly: true,
		}
		// Columns stays nil when the CTE query reported no column slice at all.
		if usedCols != nil {
			table.Columns = make(map[string]schema.Column)
			for _, c := range usedCols {
				table.Columns[c.Column] = schema.Column{Name: c.Column}
			}
		}

		ctx.InnerSchema.Tables[expr.Ctename] = table
	}
	return nil
}

func ValidateSqlQuery(ctx VetContext, queryStr string) ([]QueryParam, error) {
Expand All @@ -893,26 +954,36 @@ func ValidateSqlQuery(ctx VetContext, queryStr string) ([]QueryParam, error) {
return nil, fmt.Errorf("query contained more than one statement")
}

var raw *pg_query.RawStmt = tree.Stmts[0]
params, _, err := validateSqlQuery(ctx, tree.Stmts[0].Stmt)
return params, err
}

func validateSqlQuery(ctx VetContext, node *pg_query.Node) ([]QueryParam, []ColumnUsed, error) {

switch {
case raw.Stmt.GetSelectStmt() != nil:
qparams, _, err := validateSelectStmt(ctx, raw.Stmt.GetSelectStmt())
return qparams, err
case raw.Stmt.GetUpdateStmt() != nil:
return validateUpdateStmt(ctx, raw.Stmt.GetUpdateStmt())
case raw.Stmt.GetInsertStmt() != nil:
return validateInsertStmt(ctx, raw.Stmt.GetInsertStmt())
case raw.Stmt.GetDeleteStmt() != nil:
return validateDeleteStmt(ctx, raw.Stmt.GetDeleteStmt())
case raw.Stmt.GetDropStmt() != nil:
case raw.Stmt.GetTruncateStmt() != nil:
case raw.Stmt.GetAlterTableStmt() != nil:
case raw.Stmt.GetCreateSchemaStmt() != nil:
case raw.Stmt.GetVariableSetStmt() != nil:
case node.GetSelectStmt() != nil:
qparams, targetCols, err := validateSelectStmt(ctx, node.GetSelectStmt())
var cused []ColumnUsed
for _, tcol := range targetCols {
cused = append(cused, ColumnUsed{Column: tcol.Name})
}
return qparams, cused, err
case node.GetUpdateStmt() != nil:
return validateUpdateStmt(ctx, node.GetUpdateStmt())
case node.GetInsertStmt() != nil:
return validateInsertStmt(ctx, node.GetInsertStmt())
case node.GetDeleteStmt() != nil:
return validateDeleteStmt(ctx, node.GetDeleteStmt())
case node.GetDropStmt() != nil:
case node.GetTruncateStmt() != nil:
case node.GetAlterTableStmt() != nil:
case node.GetCreateSchemaStmt() != nil:
case node.GetVariableSetStmt() != nil:

// TODO: check for invalid pg variables
default:
return nil, fmt.Errorf("unsupported statement: %v", reflect.TypeOf(raw.Stmt))
return nil, nil, fmt.Errorf("unsupported statement: %v", reflect.TypeOf(node))
}

return nil, nil
return nil, nil, nil
}
27 changes: 26 additions & 1 deletion pkg/vet/vet_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,7 @@ func TestSelect(t *testing.T) {
`SELECT id, f.id, coalesce(bzz.created_at,0)
FROM foo as f
LEFT JOIN LATERAL (
SELECT *, created_at, b.created_at, coalesce(baz_count,0), coalesce(baz_count,0) AS b_created_at
SELECT *, created_at, b.created_at, coalesce(baz_count,0), coalesce(baz_count,0) AS b_created_at
FROM baz b
) bzz ON true
WHERE value IS NULL`,
Expand All @@ -447,6 +447,17 @@ func TestSelect(t *testing.T) {
WHERE f.id = b.id) bzz ON true
WHERE value IS NULL`,
},
{
"select CTE",
`WITH cte1 AS (SELECT id FROM foo)
SELECT id FROM cte1`,
},
{
"select 2 CTEs",
`WITH cte1 AS (SELECT id FROM foo),
cte2 AS (SELECT value FROM foo)
SELECT c1.id, c2.value FROM cte1 c1, cte2 c2`,
},
}

for _, tcase := range testCases {
Expand Down Expand Up @@ -494,6 +505,15 @@ func TestUpdate(t *testing.T) {
"update with returning",
`UPDATE foo SET id=1 RETURNING value`,
},
{
"update alias with returning",
`UPDATE foo f SET id=1 RETURNING f.value`,
},
{
"update CTE",
`WITH cte1 AS (SELECT id FROM foo)
UPDATE foo SET value='bar' FROM cte1 WHERE foo.id = cte1.id`,
},
}

for _, tcase := range testCases {
Expand Down Expand Up @@ -590,6 +610,11 @@ func TestDelete(t *testing.T) {
"delete using with aliases",
`DELETE FROM foo AS f USING bar b WHERE f.id = b.id`,
},
{
"delete CTE",
`WITH cte1 AS (SELECT id FROM foo)
DELETE FROM foo USING cte1 WHERE foo.id = cte1.id`,
},
}

for _, tcase := range testCases {
Expand Down

0 comments on commit 87efbe7

Please sign in to comment.