diff --git a/go/vt/graph/graph.go b/go/vt/graph/graph.go index 54668027008..cc5f837d6f7 100644 --- a/go/vt/graph/graph.go +++ b/go/vt/graph/graph.go @@ -18,19 +18,28 @@ package graph import ( "fmt" + "maps" "slices" "strings" ) +const ( + white int = iota + grey + black +) + // Graph is a generic graph implementation. type Graph[C comparable] struct { - edges map[C][]C + edges map[C][]C + orderedVertices []C } // NewGraph creates a new graph for the given comparable type. func NewGraph[C comparable]() *Graph[C] { return &Graph[C]{ - edges: map[C][]C{}, + edges: map[C][]C{}, + orderedVertices: []C{}, } } @@ -41,6 +50,7 @@ func (gr *Graph[C]) AddVertex(vertex C) { return } gr.edges[vertex] = []C{} + gr.orderedVertices = append(gr.orderedVertices, vertex) } // AddEdge adds an edge to the given Graph. @@ -85,8 +95,8 @@ func (gr *Graph[C]) HasCycles() bool { color := map[C]int{} for vertex := range gr.edges { // If any vertex is still white, we initiate a new DFS. - if color[vertex] == 0 { - if gr.hasCyclesDfs(color, vertex) { + if color[vertex] == white { + if hasCycle, _ := gr.hasCyclesDfs(color, vertex); hasCycle { return true } } @@ -94,26 +104,56 @@ func (gr *Graph[C]) HasCycles() bool { return false } +// GetCycles returns all known cycles in the graph. +// It returns a map of vertices to the cycle they are part of. +// We are using a well-known DFS based colouring algorithm to check for cycles. +// Look at https://cp-algorithms.com/graph/finding-cycle.html for more details on the algorithm. +func (gr *Graph[C]) GetCycles() (vertices map[C][]C) { + // If the graph is empty, then we don't need to check anything. + if gr.Empty() { + return nil + } + vertices = make(map[C][]C) + // Initialize the coloring map. + // 0 represents white. + // 1 represents grey. + // 2 represents black. + color := map[C]int{} + for _, vertex := range gr.orderedVertices { + // If any vertex is still white, we initiate a new DFS. 
+ if color[vertex] == white { + // We clone the colors because we want full coverage for all vertices. + // Otherwise, the algorithm is optimal and stops more-or-less after the first cycle. + color := maps.Clone(color) + if hasCycle, cycle := gr.hasCyclesDfs(color, vertex); hasCycle { + vertices[vertex] = cycle + } + } + } + return vertices +} + // hasCyclesDfs is a utility function for checking for cycles in a graph. // It runs a dfs from the given vertex marking each vertex as grey. During the dfs, // if we encounter a grey vertex, we know we have a cycle. We mark the visited vertices black // on finishing the dfs. -func (gr *Graph[C]) hasCyclesDfs(color map[C]int, vertex C) bool { +func (gr *Graph[C]) hasCyclesDfs(color map[C]int, vertex C) (bool, []C) { // Mark the vertex grey. - color[vertex] = 1 + color[vertex] = grey + result := []C{vertex} // Go over all the edges. for _, end := range gr.edges[vertex] { // If we encounter a white vertex, we continue the dfs. - if color[end] == 0 { - if gr.hasCyclesDfs(color, end) { - return true + if color[end] == white { + if hasCycle, cycle := gr.hasCyclesDfs(color, end); hasCycle { + return true, append(result, cycle...) } - } else if color[end] == 1 { + } else if color[end] == grey { // We encountered a grey vertex, we have a cycle. 
- return true + return true, append(result, end) } } // Mark the vertex black before finishing - color[vertex] = 2 - return false + color[vertex] = black + return false, nil } diff --git a/go/vt/graph/graph_test.go b/go/vt/graph/graph_test.go index bc334c7d225..3231998039e 100644 --- a/go/vt/graph/graph_test.go +++ b/go/vt/graph/graph_test.go @@ -95,6 +95,7 @@ func TestStringGraph(t *testing.T) { wantedGraph string wantEmpty bool wantHasCycles bool + wantCycles map[string][]string }{ { name: "empty graph", @@ -137,6 +138,13 @@ E - F F - A`, wantEmpty: false, wantHasCycles: true, + wantCycles: map[string][]string{ + "A": {"A", "B", "E", "F", "A"}, + "B": {"B", "E", "F", "A", "B"}, + "D": {"D", "E", "F", "A", "B", "E"}, + "E": {"E", "F", "A", "B", "E"}, + "F": {"F", "A", "B", "E", "F"}, + }, }, } for _, tt := range testcases { @@ -148,6 +156,14 @@ F - A`, require.Equal(t, tt.wantedGraph, graph.PrintGraph()) require.Equal(t, tt.wantEmpty, graph.Empty()) require.Equal(t, tt.wantHasCycles, graph.HasCycles()) + if tt.wantCycles == nil { + tt.wantCycles = map[string][]string{} + } + actualCycles := graph.GetCycles() + if actualCycles == nil { + actualCycles = map[string][]string{} + } + require.Equal(t, tt.wantCycles, actualCycles) }) } } diff --git a/go/vt/schemadiff/diff_test.go b/go/vt/schemadiff/diff_test.go index fbe7238e3fd..3fe94e3b0b5 100644 --- a/go/vt/schemadiff/diff_test.go +++ b/go/vt/schemadiff/diff_test.go @@ -313,7 +313,7 @@ func TestDiffTables(t *testing.T) { for _, ts := range tt { t.Run(ts.name, func(t *testing.T) { var fromCreateTable *sqlparser.CreateTable - hints := &DiffHints{} + hints := EmptyDiffHints() if ts.hints != nil { hints = ts.hints } @@ -448,7 +448,7 @@ func TestDiffViews(t *testing.T) { name: "none", }, } - hints := &DiffHints{} + hints := EmptyDiffHints() env := NewTestEnv() for _, ts := range tt { t.Run(ts.name, func(t *testing.T) { @@ -545,6 +545,7 @@ func TestDiffSchemas(t *testing.T) { cdiffs []string expectError string tableRename 
int + fkStrategy int }{ { name: "identical tables", @@ -799,6 +800,45 @@ func TestDiffSchemas(t *testing.T) { "CREATE TABLE `t5` (\n\t`id` int,\n\t`i` int,\n\tPRIMARY KEY (`id`),\n\tKEY `f5` (`i`),\n\tCONSTRAINT `f5` FOREIGN KEY (`i`) REFERENCES `t7` (`id`)\n)", }, }, + { + name: "create tables with foreign keys, with invalid fk reference", + from: "create table t (id int primary key)", + to: ` + create table t (id int primary key); + create table t11 (id int primary key, i int, constraint f1101a foreign key (i) references t12 (id) on delete restrict); + create table t12 (id int primary key, i int, constraint f1201a foreign key (i) references t9 (id) on delete set null); + `, + expectError: "table `t12` foreign key references nonexistent table `t9`", + }, + { + name: "create tables with foreign keys, with invalid fk reference", + from: "create table t (id int primary key)", + to: ` + create table t (id int primary key); + create table t11 (id int primary key, i int, constraint f1101b foreign key (i) references t12 (id) on delete restrict); + create table t12 (id int primary key, i int, constraint f1201b foreign key (i) references t9 (id) on delete set null); + `, + expectError: "table `t12` foreign key references nonexistent table `t9`", + fkStrategy: ForeignKeyCheckStrategyIgnore, + }, + { + name: "create tables with foreign keys, with valid cycle", + from: "create table t (id int primary key)", + to: ` + create table t (id int primary key); + create table t11 (id int primary key, i int, constraint f1101c foreign key (i) references t12 (id) on delete restrict); + create table t12 (id int primary key, i int, constraint f1201c foreign key (i) references t11 (id) on delete set null); + `, + diffs: []string{ + "create table t11 (\n\tid int,\n\ti int,\n\tprimary key (id),\n\tkey f1101c (i),\n\tconstraint f1101c foreign key (i) references t12 (id) on delete restrict\n)", + "create table t12 (\n\tid int,\n\ti int,\n\tprimary key (id),\n\tkey f1201c (i),\n\tconstraint 
f1201c foreign key (i) references t11 (id) on delete set null\n)", + }, + cdiffs: []string{ + "CREATE TABLE `t11` (\n\t`id` int,\n\t`i` int,\n\tPRIMARY KEY (`id`),\n\tKEY `f1101c` (`i`),\n\tCONSTRAINT `f1101c` FOREIGN KEY (`i`) REFERENCES `t12` (`id`) ON DELETE RESTRICT\n)", + "CREATE TABLE `t12` (\n\t`id` int,\n\t`i` int,\n\tPRIMARY KEY (`id`),\n\tKEY `f1201c` (`i`),\n\tCONSTRAINT `f1201c` FOREIGN KEY (`i`) REFERENCES `t11` (`id`) ON DELETE SET NULL\n)", + }, + fkStrategy: ForeignKeyCheckStrategyIgnore, + }, { name: "drop tables with foreign keys, expect specific order", from: "create table t7(id int primary key); create table t5 (id int primary key, i int, constraint f5 foreign key (i) references t7(id)); create table t4 (id int primary key, i int, constraint f4 foreign key (i) references t7(id));", @@ -932,14 +972,15 @@ func TestDiffSchemas(t *testing.T) { for _, ts := range tt { t.Run(ts.name, func(t *testing.T) { hints := &DiffHints{ - TableRenameStrategy: ts.tableRename, + TableRenameStrategy: ts.tableRename, + ForeignKeyCheckStrategy: ts.fkStrategy, } diff, err := DiffSchemasSQL(env, ts.from, ts.to, hints) if ts.expectError != "" { require.Error(t, err) assert.Contains(t, err.Error(), ts.expectError) } else { - assert.NoError(t, err) + require.NoError(t, err) diffs, err := diff.OrderedDiffs(ctx) assert.NoError(t, err) @@ -1024,7 +1065,7 @@ func TestSchemaApplyError(t *testing.T) { to: "create table t(id int); create view v1 as select * from t; create view v2 as select * from t", }, } - hints := &DiffHints{} + hints := EmptyDiffHints() env := NewTestEnv() for _, ts := range tt { t.Run(ts.name, func(t *testing.T) { diff --git a/go/vt/schemadiff/errors.go b/go/vt/schemadiff/errors.go index 5268db76ff3..dc73acdb9a0 100644 --- a/go/vt/schemadiff/errors.go +++ b/go/vt/schemadiff/errors.go @@ -288,16 +288,31 @@ func (e *ForeignKeyDependencyUnresolvedError) Error() string { type ForeignKeyLoopError struct { Table string - Loop []string + Loop 
[]*ForeignKeyTableColumns } func (e *ForeignKeyLoopError) Error() string { tableIsInsideLoop := false - escaped := make([]string, len(e.Loop)) - for i, t := range e.Loop { - escaped[i] = sqlescape.EscapeID(t) - if t == e.Table { + loop := e.Loop + // The tables in the loop could be e.g.: + // t1->t2->a->b->c->a + // In such case, the loop is a->b->c->a. The last item is always the head & tail of the loop. + // We want to distinguish between the case where the table is inside the loop and the case where it's outside, + // so we remove the prefix of the loop that doesn't participate in the actual cycle. + if len(loop) > 0 { + last := loop[len(loop)-1] + for i := range loop { + if loop[i].Table == last.Table { + loop = loop[i:] + break + } + } + } + escaped := make([]string, len(loop)) + for i, fk := range loop { + escaped[i] = fk.Escaped() + if fk.Table == e.Table { tableIsInsideLoop = true } } diff --git a/go/vt/schemadiff/schema.go b/go/vt/schemadiff/schema.go index e3782fdbf0b..8081c6eaeea 100644 --- a/go/vt/schemadiff/schema.go +++ b/go/vt/schemadiff/schema.go @@ -18,10 +18,14 @@ package schemadiff import ( "errors" + "slices" "sort" "strings" + "golang.org/x/exp/maps" + "vitess.io/vitess/go/mysql/capabilities" + "vitess.io/vitess/go/vt/graph" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/semantics" ) @@ -37,7 +41,6 @@ type Schema struct { foreignKeyParents []*CreateTableEntity // subset of tables foreignKeyChildren []*CreateTableEntity // subset of tables - foreignKeyLoopMap map[string][]string // map of table name that either participate, or directly or indirectly reference foreign key loops env *Environment } @@ -52,7 +55,6 @@ func newEmptySchema(env *Environment) *Schema { foreignKeyParents: []*CreateTableEntity{}, foreignKeyChildren: []*CreateTableEntity{}, - foreignKeyLoopMap: map[string][]string{}, env: env, } @@ -72,7 +74,7 @@ func NewSchemaFromEntities(env *Environment, entities []Entity) (*Schema, error) return nil, 
&UnsupportedEntityError{Entity: c.Name(), Statement: c.Create().CanonicalStatementString()} } } - err := schema.normalize() + err := schema.normalize(EmptyDiffHints()) return schema, err } @@ -135,42 +137,6 @@ func getForeignKeyParentTableNames(createTable *sqlparser.CreateTable) (names [] return names } -// findForeignKeyLoop is a stateful recursive function that determines whether a given table participates in a foreign -// key loop or derives from one. It returns a list of table names that form a loop, or nil if no loop is found. -// The function updates and checks the stateful map s.foreignKeyLoopMap to avoid re-analyzing the same table twice. -func (s *Schema) findForeignKeyLoop(tableName string, seen []string) (loop []string) { - if loop := s.foreignKeyLoopMap[tableName]; loop != nil { - return loop - } - t := s.Table(tableName) - if t == nil { - return nil - } - seen = append(seen, tableName) - for i, seenTable := range seen { - if i == len(seen)-1 { - // as we've just appended the table name to the end of the slice, we should skip it. - break - } - if seenTable == tableName { - // This table alreay appears in `seen`. - // We only return the suffix of `seen` that starts (and now ends) with this table. - return seen[i:] - } - } - for _, referencedTableName := range getForeignKeyParentTableNames(t.CreateTable) { - if loop := s.findForeignKeyLoop(referencedTableName, seen); loop != nil { - // Found loop. Update cache. - // It's possible for one table to participate in more than one foreign key loop, but - // we suffice with one loop, since we already only ever report one foreign key error - // per table. 
- s.foreignKeyLoopMap[tableName] = loop - return loop - } - } - return nil -} - // getViewDependentTableNames analyzes a CREATE VIEW definition and extracts all tables/views read by this view func getViewDependentTableNames(createView *sqlparser.CreateView) (names []string) { _ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) { @@ -191,7 +157,7 @@ func getViewDependentTableNames(createView *sqlparser.CreateView) (names []strin // normalize is called as part of Schema creation process. The user may only get a hold of normalized schema. // It validates some cross-entity constraints, and orders entity based on dependencies (e.g. tables, views that read from tables, 2nd level views, etc.) -func (s *Schema) normalize() error { +func (s *Schema) normalize(hints *DiffHints) error { var errs error s.named = make(map[string]Entity, len(s.tables)+len(s.views)) @@ -284,8 +250,10 @@ func (s *Schema) normalize() error { } referencedEntity, ok := s.named[referencedTableName] if !ok { - errs = errors.Join(errs, addEntityFkError(t, &ForeignKeyNonexistentReferencedTableError{Table: name, ReferencedTable: referencedTableName})) - continue + if hints.ForeignKeyCheckStrategy == ForeignKeyCheckStrategyStrict { + errs = errors.Join(errs, addEntityFkError(t, &ForeignKeyNonexistentReferencedTableError{Table: name, ReferencedTable: referencedTableName})) + continue + } } if _, ok := referencedEntity.(*CreateViewEntity); ok { errs = errors.Join(errs, addEntityFkError(t, &ForeignKeyReferencesViewError{Table: name, ReferencedView: referencedTableName})) @@ -310,6 +278,76 @@ func (s *Schema) normalize() error { s.foreignKeyParents = append(s.foreignKeyParents, t) } } + if len(dependencyLevels) != len(s.tables) { + // We have leftover tables. 
This can happen if there are foreign key loops + for _, t := range s.tables { + if _, ok := dependencyLevels[t.Name()]; ok { + // known table + continue + } + // Table is part of a loop or references a loop + s.sorted = append(s.sorted, t) + dependencyLevels[t.Name()] = iterationLevel // all in same level + } + + // Now, let's see if the loop is valid or invalid. For example: + // users.avatar_id -> avatars.id + // avatars.creator_id -> users.id + // is a valid loop, because even though the two tables reference each other, the loop ends in different columns. + type tableCol struct { + tableName sqlparser.TableName + colNames sqlparser.Columns + } + var tableColHash = func(tc tableCol) string { + res := sqlparser.String(tc.tableName) + for _, colName := range tc.colNames { + res += "|" + sqlparser.String(colName) + } + return res + } + var decodeTableColHash = func(hash string) *ForeignKeyTableColumns { + tokens := strings.Split(hash, "|") + return &ForeignKeyTableColumns{tokens[0], tokens[1:]} + } + g := graph.NewGraph[string]() + for _, table := range s.tables { + for _, cfk := range table.TableSpec.Constraints { + check, ok := cfk.Details.(*sqlparser.ForeignKeyDefinition) + if !ok { + // Not a foreign key + continue + } + + parentVertex := tableCol{ + tableName: check.ReferenceDefinition.ReferencedTable, + colNames: check.ReferenceDefinition.ReferencedColumns, + } + childVertex := tableCol{ + tableName: table.Table, + colNames: check.Source, + } + g.AddEdge(tableColHash(parentVertex), tableColHash(childVertex)) + } + } + cycles := g.GetCycles() // map of vertex (a "table|column(s)" hash) to the cycle found starting from it + // golang maps have undefined iteration order. For consistent output, we sort the keys. 
+ vertices := maps.Keys(cycles) + slices.Sort(vertices) + for _, vertex := range vertices { + cycle := cycles[vertex] + if len(cycle) == 0 { + continue + } + cycleTables := make([]*ForeignKeyTableColumns, len(cycle)) + for i := range cycle { + // Decode each tablename|colname(s) hash back into its table name and column names + cycleTables[i] = decodeTableColHash(cycle[i]) + } + tableName := cycleTables[0].Table + errs = errors.Join(errs, addEntityFkError(s.named[tableName], &ForeignKeyLoopError{Table: tableName, Loop: cycleTables})) + } + } + // We now iterate all views. We iterate "dependency levels": // - first we want all views that only depend on tables. These are 1st level views. // - then we only want views that depend on 1st level views or on tables. These are 2nd level views. @@ -347,14 +385,6 @@ func (s *Schema) normalize() error { } if len(s.sorted) != len(s.tables)+len(s.views) { - - for _, t := range s.tables { - if _, ok := dependencyLevels[t.Name()]; !ok { - if loop := s.findForeignKeyLoop(t.Name(), nil); loop != nil { - errs = errors.Join(errs, addEntityFkError(t, &ForeignKeyLoopError{Table: t.Name(), Loop: loop})) - } - } - } // We have leftover tables or views. This can happen if the schema definition is invalid: // - a table's foreign key references a nonexistent table // - two or more tables have circular FK dependency @@ -724,7 +754,7 @@ func (s *Schema) copy() *Schema { // apply attempts to apply given list of diffs to this object. // These diffs are CREATE/DROP/ALTER TABLE/VIEW. 
-func (s *Schema) apply(diffs []EntityDiff) error { +func (s *Schema) apply(diffs []EntityDiff, hints *DiffHints) error { for _, diff := range diffs { switch diff := diff.(type) { case *CreateTableEntityDiff: @@ -834,7 +864,7 @@ func (s *Schema) apply(diffs []EntityDiff) error { return &UnsupportedApplyOperationError{Statement: diff.CanonicalStatementString()} } } - if err := s.normalize(); err != nil { + if err := s.normalize(hints); err != nil { return err } return nil @@ -845,7 +875,7 @@ func (s *Schema) apply(diffs []EntityDiff) error { // The operation does not modify this object. Instead, if successful, a new (modified) Schema is returned. func (s *Schema) Apply(diffs []EntityDiff) (*Schema, error) { dup := s.copy() - if err := dup.apply(diffs); err != nil { + if err := dup.apply(diffs, EmptyDiffHints()); err != nil { return nil, err } return dup, nil @@ -861,7 +891,7 @@ func (s *Schema) SchemaDiff(other *Schema, hints *DiffHints) (*SchemaDiff, error if err != nil { return nil, err } - schemaDiff := NewSchemaDiff(s) + schemaDiff := NewSchemaDiff(s, hints) schemaDiff.loadDiffs(diffs) // Utility function to see whether the given diff has dependencies on diffs that operate on any of the given named entities, diff --git a/go/vt/schemadiff/schema_diff.go b/go/vt/schemadiff/schema_diff.go index d2f5e012220..3fbc1e6c9d3 100644 --- a/go/vt/schemadiff/schema_diff.go +++ b/go/vt/schemadiff/schema_diff.go @@ -165,6 +165,7 @@ func permDiff(ctx context.Context, a []EntityDiff, callback func([]EntityDiff) ( // Operations on SchemaDiff are not concurrency-safe. 
type SchemaDiff struct { schema *Schema + hints *DiffHints diffs []EntityDiff diffMap map[string]EntityDiff // key is diff's CanonicalStatementString() @@ -173,9 +174,10 @@ type SchemaDiff struct { r *mathutil.EquivalenceRelation // internal structure to help determine diffs } -func NewSchemaDiff(schema *Schema) *SchemaDiff { +func NewSchemaDiff(schema *Schema, hints *DiffHints) *SchemaDiff { return &SchemaDiff{ schema: schema, + hints: hints, dependencies: make(map[string]*DiffDependency), diffMap: make(map[string]EntityDiff), r: mathutil.NewEquivalenceRelation(), @@ -318,7 +320,7 @@ func (d *SchemaDiff) OrderedDiffs(ctx context.Context) ([]EntityDiff, error) { // We want to apply the changes one by one, and validate the schema after each change for i := range permutatedDiffs { // apply inline - if err := permutationSchema.apply(permutatedDiffs[i : i+1]); err != nil { + if err := permutationSchema.apply(permutatedDiffs[i:i+1], d.hints); err != nil { // permutation is invalid return false // continue searching } @@ -341,6 +343,18 @@ func (d *SchemaDiff) OrderedDiffs(ctx context.Context) ([]EntityDiff, error) { // Done taking care of this equivalence class. } + if d.hints.ForeignKeyCheckStrategy != ForeignKeyCheckStrategyStrict { + // We may have allowed invalid foreign key dependencies along the way. But we must then validate the final schema + // to ensure that all foreign keys are valid. 
+ hints := *d.hints + hints.ForeignKeyCheckStrategy = ForeignKeyCheckStrategyStrict + if err := lastGoodSchema.normalize(&hints); err != nil { + return nil, &ImpossibleApplyDiffOrderError{ + UnorderedDiffs: d.UnorderedDiffs(), + ConflictingDiffs: d.UnorderedDiffs(), + } + } + } return orderedDiffs, nil } diff --git a/go/vt/schemadiff/schema_diff_test.go b/go/vt/schemadiff/schema_diff_test.go index 4fbc31a6492..5aff4a0b408 100644 --- a/go/vt/schemadiff/schema_diff_test.go +++ b/go/vt/schemadiff/schema_diff_test.go @@ -272,6 +272,9 @@ func TestSchemaDiff(t *testing.T) { entityOrder []string // names of tables/views in expected diff order mysqlServerVersion string instantCapability InstantDDLCapability + fkStrategy int + expectError string + expectOrderedError string }{ { name: "no change", @@ -624,6 +627,33 @@ func TestSchemaDiff(t *testing.T) { sequential: true, instantCapability: InstantDDLCapabilityIrrelevant, }, + { + name: "create two tables valid fk cycle", + toQueries: append( + createQueries, + "create table t11 (id int primary key, i int, constraint f1101 foreign key (i) references t12 (id) on delete restrict);", + "create table t12 (id int primary key, i int, constraint f1201 foreign key (i) references t11 (id) on delete set null);", + ), + expectDiffs: 2, + expectDeps: 2, + sequential: true, + fkStrategy: ForeignKeyCheckStrategyStrict, + expectOrderedError: "no valid applicable order for diffs", + }, + { + name: "create two tables valid fk cycle, fk ignore", + toQueries: append( + createQueries, + "create table t12 (id int primary key, i int, constraint f1201 foreign key (i) references t11 (id) on delete set null);", + "create table t11 (id int primary key, i int, constraint f1101 foreign key (i) references t12 (id) on delete restrict);", + ), + expectDiffs: 2, + expectDeps: 2, + entityOrder: []string{"t11", "t12"}, // Note that the tables were reordered lexicographically + sequential: true, + instantCapability: InstantDDLCapabilityIrrelevant, + 
fkStrategy: ForeignKeyCheckStrategyIgnore, + }, { name: "add FK", toQueries: []string{ @@ -650,6 +680,50 @@ func TestSchemaDiff(t *testing.T) { entityOrder: []string{"tp", "t2"}, instantCapability: InstantDDLCapabilityImpossible, }, + { + name: "add two valid fk cycle references", + toQueries: []string{ + "create table t1 (id int primary key, info int not null, i int, constraint f1 foreign key (i) references t2 (id) on delete restrict);", + "create table t2 (id int primary key, ts timestamp, i int, constraint f2 foreign key (i) references t1 (id) on delete set null);", + "create view v1 as select id from t1", + }, + expectDiffs: 2, + expectDeps: 2, + sequential: false, + fkStrategy: ForeignKeyCheckStrategyStrict, + entityOrder: []string{"t1", "t2"}, + instantCapability: InstantDDLCapabilityImpossible, + }, + { + name: "add a table and a valid fk cycle references", + toQueries: []string{ + "create table t0 (id int primary key, info int not null, i int, constraint f1 foreign key (i) references t2 (id) on delete restrict);", + "create table t1 (id int primary key, info int not null);", + "create table t2 (id int primary key, ts timestamp, i int, constraint f2 foreign key (i) references t0 (id) on delete set null);", + "create view v1 as select id from t1", + }, + expectDiffs: 2, + expectDeps: 2, + sequential: true, + fkStrategy: ForeignKeyCheckStrategyStrict, + entityOrder: []string{"t0", "t2"}, + instantCapability: InstantDDLCapabilityImpossible, + }, + { + name: "add a table and a valid fk cycle references, lelxicographically desc", + toQueries: []string{ + "create table t1 (id int primary key, info int not null);", + "create table t2 (id int primary key, ts timestamp, i int, constraint f2 foreign key (i) references t9 (id) on delete set null);", + "create table t9 (id int primary key, info int not null, i int, constraint f1 foreign key (i) references t2 (id) on delete restrict);", + "create view v1 as select id from t1", + }, + expectDiffs: 2, + expectDeps: 2, + 
sequential: true, + fkStrategy: ForeignKeyCheckStrategyStrict, + entityOrder: []string{"t9", "t2"}, + instantCapability: InstantDDLCapabilityImpossible, + }, { name: "add FK, unrelated alter", toQueries: []string{ @@ -934,7 +1008,13 @@ func TestSchemaDiff(t *testing.T) { require.NoError(t, err) require.NotNil(t, toSchema) - schemaDiff, err := fromSchema.SchemaDiff(toSchema, baseHints) + hints := *baseHints + hints.ForeignKeyCheckStrategy = tc.fkStrategy + schemaDiff, err := fromSchema.SchemaDiff(toSchema, &hints) + if tc.expectError != "" { + assert.ErrorContains(t, err, tc.expectError) + return + } require.NoError(t, err) allDiffs := schemaDiff.UnorderedDiffs() @@ -953,6 +1033,10 @@ func TestSchemaDiff(t *testing.T) { assert.Equal(t, tc.sequential, schemaDiff.HasSequentialExecutionDependencies()) orderedDiffs, err := schemaDiff.OrderedDiffs(ctx) + if tc.expectOrderedError != "" { + assert.ErrorContains(t, err, tc.expectOrderedError) + return + } if tc.conflictingDiffs > 0 { assert.Error(t, err) impossibleOrderErr, ok := err.(*ImpossibleApplyDiffOrderError) diff --git a/go/vt/schemadiff/schema_test.go b/go/vt/schemadiff/schema_test.go index a979e521216..19a1b95e186 100644 --- a/go/vt/schemadiff/schema_test.go +++ b/go/vt/schemadiff/schema_test.go @@ -310,9 +310,8 @@ func TestTableForeignKeyOrdering(t *testing.T) { func TestInvalidSchema(t *testing.T) { tt := []struct { - schema string - expectErr error - expectLoopTables int + schema string + expectErr error }{ { schema: "create table t11 (id int primary key, i int, key ix(i), constraint f11 foreign key (i) references t11(id) on delete restrict)", @@ -346,55 +345,77 @@ func TestInvalidSchema(t *testing.T) { }, { // t12<->t11 - schema: "create table t11 (id int primary key, i int, constraint f11 foreign key (i) references t12 (id) on delete restrict); create table t12 (id int primary key, i int, constraint f12 foreign key (i) references t11 (id) on delete restrict)", + schema: ` + create table t11 (id int primary 
key, i int, constraint f1103 foreign key (i) references t12 (id) on delete restrict); + create table t12 (id int primary key, i int, constraint f1203 foreign key (i) references t11 (id) on delete restrict) + `, + }, + { + // t12<->t11 + schema: ` + create table t11 (id int primary key, i int, constraint f1101 foreign key (i) references t12 (i) on delete restrict); + create table t12 (id int primary key, i int, constraint f1201 foreign key (i) references t11 (i) on delete set null) + `, expectErr: errors.Join( - &ForeignKeyLoopError{Table: "t11", Loop: []string{"t11", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t12", Loop: []string{"t11", "t12", "t11"}}, + &ForeignKeyLoopError{Table: "t11", Loop: []*ForeignKeyTableColumns{{"t11", []string{"i"}}, {"t12", []string{"i"}}, {"t11", []string{"i"}}}}, + &ForeignKeyLoopError{Table: "t12", Loop: []*ForeignKeyTableColumns{{"t12", []string{"i"}}, {"t11", []string{"i"}}, {"t12", []string{"i"}}}}, ), - expectLoopTables: 2, }, { // t10, t12<->t11 - schema: "create table t10(id int primary key); create table t11 (id int primary key, i int, constraint f11 foreign key (i) references t12 (id) on delete restrict); create table t12 (id int primary key, i int, constraint f12 foreign key (i) references t11 (id) on delete restrict)", - expectErr: errors.Join( - &ForeignKeyLoopError{Table: "t11", Loop: []string{"t11", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t12", Loop: []string{"t11", "t12", "t11"}}, - ), - expectLoopTables: 2, + schema: ` + create table t10(id int primary key); + create table t11 (id int primary key, i int, constraint f1102 foreign key (i) references t12 (id) on delete restrict); + create table t12 (id int primary key, i int, constraint f1202 foreign key (i) references t11 (id) on delete restrict) + `, }, { // t10, t12<->t11<-t13 - schema: "create table t10(id int primary key); create table t11 (id int primary key, i int, constraint f11 foreign key (i) references t12 (id) on delete restrict); create table t12 
(id int primary key, i int, constraint f12 foreign key (i) references t11 (id) on delete restrict); create table t13 (id int primary key, i int, constraint f13 foreign key (i) references t11 (id) on delete restrict)", - expectErr: errors.Join( - &ForeignKeyLoopError{Table: "t11", Loop: []string{"t11", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t12", Loop: []string{"t11", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t13", Loop: []string{"t11", "t12", "t11"}}, - ), - expectLoopTables: 3, + schema: ` + create table t10(id int primary key); + create table t11 (id int primary key, i int, constraint f1104 foreign key (i) references t12 (id) on delete restrict); + create table t12 (id int primary key, i int, constraint f1204 foreign key (i) references t11 (id) on delete restrict); + create table t13 (id int primary key, i int, constraint f13 foreign key (i) references t11 (id) on delete restrict)`, }, { // t10 // ^ // | //t12<->t11<-t13 - schema: "create table t10(id int primary key); create table t11 (id int primary key, i int, i10 int, constraint f11 foreign key (i) references t12 (id) on delete restrict, constraint f1110 foreign key (i10) references t10 (id) on delete restrict); create table t12 (id int primary key, i int, constraint f12 foreign key (i) references t11 (id) on delete restrict); create table t13 (id int primary key, i int, constraint f13 foreign key (i) references t11 (id) on delete restrict)", + schema: ` + create table t10(id int primary key); + create table t11 (id int primary key, i int, i10 int, constraint f111205 foreign key (i) references t12 (id) on delete restrict, constraint f111005 foreign key (i10) references t10 (id) on delete restrict); + create table t12 (id int primary key, i int, constraint f1205 foreign key (id) references t11 (i) on delete restrict); + create table t13 (id int primary key, i int, constraint f1305 foreign key (i) references t11 (id) on delete restrict) + `, expectErr: errors.Join( - &ForeignKeyLoopError{Table: 
"t11", Loop: []string{"t11", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t12", Loop: []string{"t11", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t13", Loop: []string{"t11", "t12", "t11"}}, + &ForeignKeyLoopError{Table: "t11", Loop: []*ForeignKeyTableColumns{{"t11", []string{"i"}}, {"t12", []string{"id"}}, {"t11", []string{"i"}}}}, + &ForeignKeyLoopError{Table: "t12", Loop: []*ForeignKeyTableColumns{{"t12", []string{"id"}}, {"t11", []string{"i"}}, {"t12", []string{"id"}}}}, ), - expectLoopTables: 3, }, { // t10, t12<->t11<-t13<-t14 - schema: "create table t10(id int primary key); create table t11 (id int primary key, i int, i10 int, constraint f11 foreign key (i) references t12 (id) on delete restrict, constraint f1110 foreign key (i10) references t10 (id) on delete restrict); create table t12 (id int primary key, i int, constraint f12 foreign key (i) references t11 (id) on delete restrict); create table t13 (id int primary key, i int, constraint f13 foreign key (i) references t11 (id) on delete restrict); create table t14 (id int primary key, i int, constraint f14 foreign key (i) references t13 (id) on delete restrict)", + schema: ` + create table t10(id int primary key); + create table t11 (id int primary key, i int, i10 int, constraint f1106 foreign key (i) references t12 (id) on delete restrict, constraint f111006 foreign key (i10) references t10 (id) on delete restrict); + create table t12 (id int primary key, i int, constraint f1206 foreign key (i) references t11 (id) on delete restrict); + create table t13 (id int primary key, i int, constraint f1306 foreign key (i) references t11 (id) on delete restrict); + create table t14 (id int primary key, i int, constraint f1406 foreign key (i) references t13 (id) on delete restrict) + `, + }, + { + // t10, t12<-t11<-t13<-t12 + schema: ` + create table t10(id int primary key); + create table t11 (id int primary key, i int, key i_idx (i), i10 int, constraint f1107 foreign key (i) references t12 (id), constraint 
f111007 foreign key (i10) references t10 (id)); + create table t12 (id int primary key, i int, key i_idx (i), constraint f1207 foreign key (id) references t13 (i)); + create table t13 (id int primary key, i int, key i_idx (i), constraint f1307 foreign key (i) references t11 (i)); + `, expectErr: errors.Join( - &ForeignKeyLoopError{Table: "t11", Loop: []string{"t11", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t12", Loop: []string{"t11", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t13", Loop: []string{"t11", "t12", "t11"}}, - &ForeignKeyLoopError{Table: "t14", Loop: []string{"t11", "t12", "t11"}}, + &ForeignKeyLoopError{Table: "t11", Loop: []*ForeignKeyTableColumns{{"t11", []string{"i"}}, {"t13", []string{"i"}}, {"t12", []string{"id"}}, {"t11", []string{"i"}}}}, + &ForeignKeyLoopError{Table: "t12", Loop: []*ForeignKeyTableColumns{{"t12", []string{"id"}}, {"t11", []string{"i"}}, {"t13", []string{"i"}}, {"t12", []string{"id"}}}}, + &ForeignKeyLoopError{Table: "t13", Loop: []*ForeignKeyTableColumns{{"t13", []string{"i"}}, {"t12", []string{"id"}}, {"t11", []string{"i"}}, {"t13", []string{"i"}}}}, ), - expectLoopTables: 4, }, { schema: "create table t11 (id int primary key, i int, key ix(i), constraint f11 foreign key (i) references t11(id2) on delete restrict)", @@ -468,14 +489,13 @@ func TestInvalidSchema(t *testing.T) { for _, ts := range tt { t.Run(ts.schema, func(t *testing.T) { - s, err := NewSchemaFromSQL(NewTestEnv(), ts.schema) + _, err := NewSchemaFromSQL(NewTestEnv(), ts.schema) if ts.expectErr == nil { assert.NoError(t, err) } else { assert.Error(t, err) assert.EqualError(t, err, ts.expectErr.Error()) } - assert.Equal(t, ts.expectLoopTables, len(s.foreignKeyLoopMap)) }) } } @@ -492,7 +512,7 @@ func TestInvalidTableForeignKeyReference(t *testing.T) { // Even though there's errors, we still expect the schema to have been created. assert.NotNil(t, s) // Even though t11 caused an error, we still expect the schema to have parsed all tables. 
- assert.Equal(t, 3, len(s.Entities())) + assert.Equalf(t, 3, len(s.Entities()), "found: %+v", s.EntityNames()) t11 := s.Table("t11") assert.NotNil(t, t11) // validate t11 table definition is complete, even though it was invalid. @@ -506,10 +526,20 @@ func TestInvalidTableForeignKeyReference(t *testing.T) { "create table t12 (id int primary key, i int, constraint f13 foreign key (i) references t13(id) on delete restrict)", } _, err := NewSchemaFromQueries(NewTestEnv(), fkQueries) + assert.NoError(t, err) + } + { + fkQueries := []string{ + "create table t13 (id int primary key, i int, constraint f11 foreign key (i) references t11(i) on delete restrict)", + "create table t11 (id int primary key, i int, constraint f12 foreign key (i) references t12(i) on delete restrict)", + "create table t12 (id int primary key, i int, constraint f13 foreign key (i) references t13(i) on delete restrict)", + } + _, err := NewSchemaFromQueries(NewTestEnv(), fkQueries) assert.Error(t, err) - assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t11", Loop: []string{"t11", "t12", "t13", "t11"}}).Error()) - assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t12", Loop: []string{"t11", "t12", "t13", "t11"}}).Error()) - assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t13", Loop: []string{"t11", "t12", "t13", "t11"}}).Error()) + + assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t11", Loop: []*ForeignKeyTableColumns{{"t11", []string{"i"}}, {"t13", []string{"i"}}, {"t12", []string{"i"}}, {"t11", []string{"i"}}}}).Error()) + assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t12", Loop: []*ForeignKeyTableColumns{{"t12", []string{"i"}}, {"t11", []string{"i"}}, {"t13", []string{"i"}}, {"t12", []string{"i"}}}}).Error()) + assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t13", Loop: []*ForeignKeyTableColumns{{"t13", []string{"i"}}, {"t12", []string{"i"}}, {"t11", []string{"i"}}, {"t13", []string{"i"}}}}).Error()) } { fkQueries := []string{ @@ 
-520,8 +550,6 @@ func TestInvalidTableForeignKeyReference(t *testing.T) { _, err := NewSchemaFromQueries(NewTestEnv(), fkQueries) assert.Error(t, err) assert.ErrorContains(t, err, (&ForeignKeyNonexistentReferencedTableError{Table: "t11", ReferencedTable: "t0"}).Error()) - assert.ErrorContains(t, err, (&ForeignKeyDependencyUnresolvedError{Table: "t12"}).Error()) - assert.ErrorContains(t, err, (&ForeignKeyDependencyUnresolvedError{Table: "t13"}).Error()) } { fkQueries := []string{ @@ -532,8 +560,6 @@ func TestInvalidTableForeignKeyReference(t *testing.T) { _, err := NewSchemaFromQueries(NewTestEnv(), fkQueries) assert.Error(t, err) assert.ErrorContains(t, err, (&ForeignKeyNonexistentReferencedTableError{Table: "t11", ReferencedTable: "t0"}).Error()) - assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t12", Loop: []string{"t12", "t13", "t12"}}).Error()) - assert.ErrorContains(t, err, (&ForeignKeyLoopError{Table: "t13", Loop: []string{"t12", "t13", "t12"}}).Error()) } } @@ -943,7 +969,7 @@ func TestMassiveSchema(t *testing.T) { }) t.Run("evaluating diff", func(t *testing.T) { - schemaDiff, err := schema0.SchemaDiff(schema1, &DiffHints{}) + schemaDiff, err := schema0.SchemaDiff(schema1, EmptyDiffHints()) require.NoError(t, err) diffs := schemaDiff.UnorderedDiffs() require.NotEmpty(t, diffs) diff --git a/go/vt/schemadiff/types.go b/go/vt/schemadiff/types.go index a4edb09ec9b..b42408376b8 100644 --- a/go/vt/schemadiff/types.go +++ b/go/vt/schemadiff/types.go @@ -17,6 +17,9 @@ limitations under the License. 
package schemadiff import ( + "strings" + + "vitess.io/vitess/go/sqlescape" "vitess.io/vitess/go/vt/sqlparser" ) @@ -124,6 +127,11 @@ const ( EnumReorderStrategyReject ) +const ( + ForeignKeyCheckStrategyStrict int = iota + ForeignKeyCheckStrategyIgnore +) + // DiffHints is an assortment of rules for diffing entities type DiffHints struct { StrictIndexOrdering bool @@ -137,6 +145,11 @@ type DiffHints struct { TableQualifierHint int AlterTableAlgorithmStrategy int EnumReorderStrategy int + ForeignKeyCheckStrategy int +} + +func EmptyDiffHints() *DiffHints { + return &DiffHints{} } const ( @@ -144,3 +157,21 @@ const ( ApplyDiffsInOrder = "ApplyDiffsInOrder" ApplyDiffsSequential = "ApplyDiffsSequential" ) + +type ForeignKeyTableColumns struct { + Table string + Columns []string +} + +func (f ForeignKeyTableColumns) Escaped() string { + var b strings.Builder + b.WriteString(sqlescape.EscapeID(f.Table)) + b.WriteString(" (") + escapedColumns := make([]string, len(f.Columns)) + for i, column := range f.Columns { + escapedColumns[i] = sqlescape.EscapeID(column) + } + b.WriteString(strings.Join(escapedColumns, ", ")) + b.WriteString(")") + return b.String() +} diff --git a/go/vt/schemadiff/view_test.go b/go/vt/schemadiff/view_test.go index e5be9055970..2aade1dc3e8 100644 --- a/go/vt/schemadiff/view_test.go +++ b/go/vt/schemadiff/view_test.go @@ -145,7 +145,7 @@ func TestCreateViewDiff(t *testing.T) { cdiff: "ALTER ALGORITHM = TEMPTABLE VIEW `v1` AS SELECT `a` FROM `t`", }, } - hints := &DiffHints{} + hints := EmptyDiffHints() env := NewTestEnv() for _, ts := range tt { t.Run(ts.name, func(t *testing.T) {