From 87efbe70165907d00db83ee469bf6b20695f07fc Mon Sep 17 00:00:00 2001 From: Nick Cabatoff Date: Sat, 18 May 2024 00:47:17 -0400 Subject: [PATCH] Add support for CTEs. (#37) * Add support for CTEs. Also fix a minor issue with interpreting the columns in an UPDATE statement that uses table aliases. * Fix test failure --- pkg/vet/vet.go | 151 ++++++++++++++++++++++++++++++++------------ pkg/vet/vet_test.go | 27 +++++++- 2 files changed, 137 insertions(+), 41 deletions(-) diff --git a/pkg/vet/vet.go b/pkg/vet/vet.go index c5b9111..2d9a67a 100644 --- a/pkg/vet/vet.go +++ b/pkg/vet/vet.go @@ -277,7 +277,7 @@ func validateTableColumns(ctx VetContext, tables []TableUsed, cols []ColumnUsed) } } if !found { - if len(tables) == 1 { + if len(usedTables) == 1 { // to make error message more useful, if only one table is // referenced in the query, it's safe to assume user only // want to use columns from that table. @@ -541,6 +541,11 @@ func getUsedColumnsFromSortClause(sortList []*pg_query.Node) []ColumnUsed { func validateSelectStmt(ctx VetContext, stmt *pg_query.SelectStmt) (queryParams []QueryParam, targetCols []schema.Column, err error) { usedCols := []ColumnUsed{} + if stmt.GetWithClause() != nil { + if err := parseCTE(ctx, stmt.GetWithClause()); err != nil { + return nil, nil, err + } + } postponed := PostponedNodes{} for _, fromClause := range stmt.FromClause { re := &ParseResult{} @@ -648,11 +653,22 @@ func validateSelectStmt(ctx VetContext, stmt *pg_query.SelectStmt) (queryParams return queryParams, targetCols, validateTableColumns(ctx, ctx.UsedTables, usedCols) } -func validateUpdateStmt(ctx VetContext, stmt *pg_query.UpdateStmt) ([]QueryParam, error) { +func validateUpdateStmt(ctx VetContext, stmt *pg_query.UpdateStmt) ([]QueryParam, []ColumnUsed, error) { + if stmt.GetWithClause() != nil { + if err := parseCTE(ctx, stmt.GetWithClause()); err != nil { + return nil, nil, err + } + } tableName := stmt.Relation.Relname if err := validateTable(ctx, tableName, true); 
err != nil { - return nil, err + return nil, nil, err } + + var tableAlias string + if stmt.Relation.Alias != nil { + tableAlias = stmt.Relation.Alias.Aliasname + } + usedTables := []TableUsed{{Name: tableName}} usedTables = append(usedTables, getUsedTablesFromSelectStmt(stmt.FromClause)...) @@ -682,7 +698,7 @@ func validateUpdateStmt(ctx VetContext, stmt *pg_query.UpdateStmt) ([]QueryParam re := &ParseResult{} err := parseWhereClause(ctx, stmt.WhereClause, re) if err != nil { - return nil, err + return nil, nil, err } usedCols = append(usedCols, re.Columns...) AddQueryParams(&queryParams, re.Params) @@ -692,13 +708,25 @@ func validateUpdateStmt(ctx VetContext, stmt *pg_query.UpdateStmt) ([]QueryParam usedCols = append(usedCols, getUsedColumnsFromReturningList(stmt.ReturningList)...) } - return queryParams, validateTableColumns(ctx, usedTables, usedCols) + if len(usedCols) > 0 { + usedTables = append(usedTables, TableUsed{Name: tableName, Alias: tableAlias}) + if err := validateTableColumns(ctx, usedTables, usedCols); err != nil { + return nil, nil, err + } + } + + return queryParams, usedCols, nil } -func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam, error) { +func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam, []ColumnUsed, error) { + if stmt.GetWithClause() != nil { + if err := parseCTE(ctx, stmt.GetWithClause()); err != nil { + return nil, nil, err + } + } tableName := stmt.Relation.Relname if err := validateTable(ctx, tableName, true); err != nil { - return nil, err + return nil, nil, err } usedTables := []TableUsed{{Name: tableName}} @@ -736,7 +764,7 @@ func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam re := &ParseResult{} err := parseExpression(ctx, node, re) if err != nil { - return nil, fmt.Errorf("invalid value list: %w", err) + return nil, nil, fmt.Errorf("invalid value list: %w", err) } if len(re.Columns) > 0 { usedCols = append(usedCols, re.Columns...) 
@@ -758,7 +786,7 @@ func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam re := &ParseResult{} err := parseFromClause(ctx, fromClause, re) if err != nil { - return nil, err + return nil, nil, err } if len(re.Columns) > 0 { usedCols = append(usedCols, re.Columns...) @@ -772,7 +800,7 @@ func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam re := &ParseResult{} err := parseWhereClause(ctx, selectStmt.WhereClause, re) if err != nil { - return nil, err + return nil, nil, err } if len(re.Columns) > 0 { usedCols = append(usedCols, re.Columns...) @@ -796,12 +824,12 @@ func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam case target.GetSubLink() != nil: tv := target.GetSubLink().Subselect if tv.GetSelectStmt() == nil { - return nil, fmt.Errorf( + return nil, nil, fmt.Errorf( "unsupported subquery type in value list: %s", reflect.TypeOf(tv)) } qparams, _, err := validateSelectStmt(ctx, tv.GetSelectStmt()) if err != nil { - return nil, fmt.Errorf("invalid SELECT query in value list: %w", err) + return nil, nil, fmt.Errorf("invalid SELECT query in value list: %w", err) } if len(qparams) > 0 { AddQueryParams(&queryParams, qparams) @@ -815,17 +843,22 @@ func validateInsertStmt(ctx VetContext, stmt *pg_query.InsertStmt) ([]QueryParam } if err := validateTableColumns(ctx, usedTables, usedCols); err != nil { - return nil, err + return nil, nil, err } if err := validateInsertValues(ctx, targetCols, values); err != nil { - return nil, err + return nil, nil, err } - return queryParams, nil + return queryParams, usedCols, nil } -func validateDeleteStmt(ctx VetContext, stmt *pg_query.DeleteStmt) ([]QueryParam, error) { +func validateDeleteStmt(ctx VetContext, stmt *pg_query.DeleteStmt) ([]QueryParam, []ColumnUsed, error) { + if stmt.GetWithClause() != nil { + if err := parseCTE(ctx, stmt.GetWithClause()); err != nil { + return nil, nil, err + } + } tableName := stmt.Relation.Relname var tableAlias string 
@@ -834,7 +867,7 @@ func validateDeleteStmt(ctx VetContext, stmt *pg_query.DeleteStmt) ([]QueryParam } if err := validateTable(ctx, tableName, true); err != nil { - return nil, err + return nil, nil, err } usedCols := []ColumnUsed{} @@ -845,25 +878,25 @@ func validateDeleteStmt(ctx VetContext, stmt *pg_query.DeleteStmt) ([]QueryParam re := &ParseResult{} err := parseWhereClause(ctx, stmt.WhereClause, re) if err != nil { - return nil, err + return nil, nil, err } if len(re.Columns) > 0 { usedCols = append(usedCols, re.Columns...) } else { - return nil, fmt.Errorf("no columns in DELETE's WHERE clause") + return nil, nil, fmt.Errorf("no columns in DELETE's WHERE clause") } if len(re.Params) > 0 { queryParams = re.Params } } else { - return nil, fmt.Errorf("no WHERE clause for DELETE") + return nil, nil, fmt.Errorf("no WHERE clause for DELETE") } for _, using := range stmt.UsingClause { re := &ParseResult{} err := parseUsingClause(ctx, using, re) if err != nil { - return nil, err + return nil, nil, err } usedTables = append(usedTables, re.Tables...) 
} @@ -876,11 +909,39 @@ func validateDeleteStmt(ctx VetContext, stmt *pg_query.DeleteStmt) ([]QueryParam if len(usedCols) > 0 { usedTables = append(usedTables, TableUsed{Name: tableName, Alias: tableAlias}) if err := validateTableColumns(ctx, usedTables, usedCols); err != nil { - return nil, err + return nil, nil, err } } - return queryParams, nil + return queryParams, usedCols, nil +} + +func parseCTE(ctx VetContext, with *pg_query.WithClause) error { + for _, cteNode := range with.Ctes { + cte := cteNode.GetCommonTableExpr() + query := cte.GetCtequery() + _, cols, err := validateSqlQuery(ctx, query) + if err != nil { + return err + } + + var columns map[string]schema.Column + if cols != nil { + columns = make(map[string]schema.Column) + for _, col := range cols { + columns[col.Column] = schema.Column{ + Name: col.Column, + } + } + } + + ctx.InnerSchema.Tables[cte.Ctename] = schema.Table{ + Name: cte.Ctename, + Columns: columns, + ReadOnly: true, + } + } + return nil } func ValidateSqlQuery(ctx VetContext, queryStr string) ([]QueryParam, error) { @@ -893,26 +954,36 @@ func ValidateSqlQuery(ctx VetContext, queryStr string) ([]QueryParam, error) { return nil, fmt.Errorf("query contained more than one statement") } - var raw *pg_query.RawStmt = tree.Stmts[0] + params, _, err := validateSqlQuery(ctx, tree.Stmts[0].Stmt) + return params, err +} + +func validateSqlQuery(ctx VetContext, node *pg_query.Node) ([]QueryParam, []ColumnUsed, error) { + switch { - case raw.Stmt.GetSelectStmt() != nil: - qparams, _, err := validateSelectStmt(ctx, raw.Stmt.GetSelectStmt()) - return qparams, err - case raw.Stmt.GetUpdateStmt() != nil: - return validateUpdateStmt(ctx, raw.Stmt.GetUpdateStmt()) - case raw.Stmt.GetInsertStmt() != nil: - return validateInsertStmt(ctx, raw.Stmt.GetInsertStmt()) - case raw.Stmt.GetDeleteStmt() != nil: - return validateDeleteStmt(ctx, raw.Stmt.GetDeleteStmt()) - case raw.Stmt.GetDropStmt() != nil: - case raw.Stmt.GetTruncateStmt() != nil: - case 
raw.Stmt.GetAlterTableStmt() != nil: - case raw.Stmt.GetCreateSchemaStmt() != nil: - case raw.Stmt.GetVariableSetStmt() != nil: + case node.GetSelectStmt() != nil: + qparams, targetCols, err := validateSelectStmt(ctx, node.GetSelectStmt()) + var cused []ColumnUsed + for _, tcol := range targetCols { + cused = append(cused, ColumnUsed{Column: tcol.Name}) + } + return qparams, cused, err + case node.GetUpdateStmt() != nil: + return validateUpdateStmt(ctx, node.GetUpdateStmt()) + case node.GetInsertStmt() != nil: + return validateInsertStmt(ctx, node.GetInsertStmt()) + case node.GetDeleteStmt() != nil: + return validateDeleteStmt(ctx, node.GetDeleteStmt()) + case node.GetDropStmt() != nil: + case node.GetTruncateStmt() != nil: + case node.GetAlterTableStmt() != nil: + case node.GetCreateSchemaStmt() != nil: + case node.GetVariableSetStmt() != nil: + // TODO: check for invalid pg variables default: - return nil, fmt.Errorf("unsupported statement: %v", reflect.TypeOf(raw.Stmt)) + return nil, nil, fmt.Errorf("unsupported statement: %v", reflect.TypeOf(node)) } - return nil, nil + return nil, nil, nil } diff --git a/pkg/vet/vet_test.go b/pkg/vet/vet_test.go index 5c33edb..843736b 100644 --- a/pkg/vet/vet_test.go +++ b/pkg/vet/vet_test.go @@ -432,7 +432,7 @@ func TestSelect(t *testing.T) { `SELECT id, f.id, coalesce(bzz.created_at,0) FROM foo as f LEFT JOIN LATERAL ( - SELECT *, created_at, b.created_at, coalesce(baz_count,0), coalesce(baz_count,0) AS b_created_at + SELECT *, created_at, b.created_at, coalesce(baz_count,0), coalesce(baz_count,0) AS b_created_at FROM baz b ) bzz ON true WHERE value IS NULL`, @@ -447,6 +447,17 @@ func TestSelect(t *testing.T) { WHERE f.id = b.id) bzz ON true WHERE value IS NULL`, }, + { + "select CTE", + `WITH cte1 AS (SELECT id FROM foo) + SELECT id FROM cte1`, + }, + { + "select 2 CTEs", + `WITH cte1 AS (SELECT id FROM foo), + cte2 AS (SELECT value FROM foo) + SELECT c1.id, c2.value FROM cte1 c1, cte2 c2`, + }, } for _, tcase := range 
testCases { @@ -494,6 +505,15 @@ func TestUpdate(t *testing.T) { "update with returning", `UPDATE foo SET id=1 RETURNING value`, }, + { + "update alias with returning", + `UPDATE foo f SET id=1 RETURNING f.value`, + }, + { + "update CTE", + `WITH cte1 AS (SELECT id FROM foo) + UPDATE foo SET value='bar' FROM cte1 WHERE foo.id = cte1.id`, + }, } for _, tcase := range testCases { @@ -590,6 +610,11 @@ func TestDelete(t *testing.T) { "delete using with aliases", `DELETE FROM foo AS f USING bar b WHERE f.id = b.id`, }, + { + "delete CTE", + `WITH cte1 AS (SELECT id FROM foo) + DELETE FROM foo USING cte1 WHERE foo.id = cte1.id`, + }, } for _, tcase := range testCases {