Skip to content

Commit

Permalink
schemadiff: support valid foreign key cycles (#15431)
Browse files Browse the repository at this point in the history
Signed-off-by: Shlomi Noach <[email protected]>
  • Loading branch information
shlomi-noach authored Mar 11, 2024
1 parent 4c70c7e commit 46975b2
Show file tree
Hide file tree
Showing 10 changed files with 420 additions and 123 deletions.
66 changes: 53 additions & 13 deletions go/vt/graph/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,28 @@ package graph

import (
"fmt"
"maps"
"slices"
"strings"
)

const (
white int = iota
grey
black
)

// Graph is a generic graph implementation.
type Graph[C comparable] struct {
edges map[C][]C
edges map[C][]C
orderedVertices []C
}

// NewGraph creates a new graph for the given comparable type.
func NewGraph[C comparable]() *Graph[C] {
return &Graph[C]{
edges: map[C][]C{},
edges: map[C][]C{},
orderedVertices: []C{},
}
}

Expand All @@ -41,6 +50,7 @@ func (gr *Graph[C]) AddVertex(vertex C) {
return
}
gr.edges[vertex] = []C{}
gr.orderedVertices = append(gr.orderedVertices, vertex)
}

// AddEdge adds an edge to the given Graph.
Expand Down Expand Up @@ -85,35 +95,65 @@ func (gr *Graph[C]) HasCycles() bool {
color := map[C]int{}
for vertex := range gr.edges {
// If any vertex is still white, we initiate a new DFS.
if color[vertex] == 0 {
if gr.hasCyclesDfs(color, vertex) {
if color[vertex] == white {
if hasCycle, _ := gr.hasCyclesDfs(color, vertex); hasCycle {
return true
}
}
}
return false
}

// GetCycles returns all known cycles in the graph.
// It returns a map of vertices to the cycle they are part of.
// We are using a well-known DFS based colouring algorithm to check for cycles.
// Look at https://cp-algorithms.com/graph/finding-cycle.html for more details on the algorithm.
func (gr *Graph[C]) GetCycles() (vertices map[C][]C) {
// If the graph is empty, then we don't need to check anything.
if gr.Empty() {
return nil
}
vertices = make(map[C][]C)
// Initialize the coloring map.
// 0 represents white.
// 1 represents grey.
// 2 represents black.
color := map[C]int{}
for _, vertex := range gr.orderedVertices {
// If any vertex is still white, we initiate a new DFS.
if color[vertex] == white {
// We clone the colors because we wnt full coverage for all vertices.
// Otherwise, the algorithm is optimal and stop more-or-less after the first cycle.
color := maps.Clone(color)
if hasCycle, cycle := gr.hasCyclesDfs(color, vertex); hasCycle {
vertices[vertex] = cycle
}
}
}
return vertices
}

// hasCyclesDfs is a utility function for checking for cycles in a graph.
// It runs a dfs from the given vertex marking each vertex as grey. During the dfs,
// if we encounter a grey vertex, we know we have a cycle. We mark the visited vertices black
// on finishing the dfs.
func (gr *Graph[C]) hasCyclesDfs(color map[C]int, vertex C) bool {
func (gr *Graph[C]) hasCyclesDfs(color map[C]int, vertex C) (bool, []C) {
// Mark the vertex grey.
color[vertex] = 1
color[vertex] = grey
result := []C{vertex}
// Go over all the edges.
for _, end := range gr.edges[vertex] {
// If we encounter a white vertex, we continue the dfs.
if color[end] == 0 {
if gr.hasCyclesDfs(color, end) {
return true
if color[end] == white {
if hasCycle, cycle := gr.hasCyclesDfs(color, end); hasCycle {
return true, append(result, cycle...)
}
} else if color[end] == 1 {
} else if color[end] == grey {
// We encountered a grey vertex, we have a cycle.
return true
return true, append(result, end)
}
}
// Mark the vertex black before finishing
color[vertex] = 2
return false
color[vertex] = black
return false, nil
}
16 changes: 16 additions & 0 deletions go/vt/graph/graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ func TestStringGraph(t *testing.T) {
wantedGraph string
wantEmpty bool
wantHasCycles bool
wantCycles map[string][]string
}{
{
name: "empty graph",
Expand Down Expand Up @@ -137,6 +138,13 @@ E - F
F - A`,
wantEmpty: false,
wantHasCycles: true,
wantCycles: map[string][]string{
"A": {"A", "B", "E", "F", "A"},
"B": {"B", "E", "F", "A", "B"},
"D": {"D", "E", "F", "A", "B", "E"},
"E": {"E", "F", "A", "B", "E"},
"F": {"F", "A", "B", "E", "F"},
},
},
}
for _, tt := range testcases {
Expand All @@ -148,6 +156,14 @@ F - A`,
require.Equal(t, tt.wantedGraph, graph.PrintGraph())
require.Equal(t, tt.wantEmpty, graph.Empty())
require.Equal(t, tt.wantHasCycles, graph.HasCycles())
if tt.wantCycles == nil {
tt.wantCycles = map[string][]string{}
}
actualCycles := graph.GetCycles()
if actualCycles == nil {
actualCycles = map[string][]string{}
}
require.Equal(t, tt.wantCycles, actualCycles)
})
}
}
51 changes: 46 additions & 5 deletions go/vt/schemadiff/diff_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ func TestDiffTables(t *testing.T) {
for _, ts := range tt {
t.Run(ts.name, func(t *testing.T) {
var fromCreateTable *sqlparser.CreateTable
hints := &DiffHints{}
hints := EmptyDiffHints()
if ts.hints != nil {
hints = ts.hints
}
Expand Down Expand Up @@ -448,7 +448,7 @@ func TestDiffViews(t *testing.T) {
name: "none",
},
}
hints := &DiffHints{}
hints := EmptyDiffHints()
env := NewTestEnv()
for _, ts := range tt {
t.Run(ts.name, func(t *testing.T) {
Expand Down Expand Up @@ -545,6 +545,7 @@ func TestDiffSchemas(t *testing.T) {
cdiffs []string
expectError string
tableRename int
fkStrategy int
}{
{
name: "identical tables",
Expand Down Expand Up @@ -799,6 +800,45 @@ func TestDiffSchemas(t *testing.T) {
"CREATE TABLE `t5` (\n\t`id` int,\n\t`i` int,\n\tPRIMARY KEY (`id`),\n\tKEY `f5` (`i`),\n\tCONSTRAINT `f5` FOREIGN KEY (`i`) REFERENCES `t7` (`id`)\n)",
},
},
{
name: "create tables with foreign keys, with invalid fk reference",
from: "create table t (id int primary key)",
to: `
create table t (id int primary key);
create table t11 (id int primary key, i int, constraint f1101a foreign key (i) references t12 (id) on delete restrict);
create table t12 (id int primary key, i int, constraint f1201a foreign key (i) references t9 (id) on delete set null);
`,
expectError: "table `t12` foreign key references nonexistent table `t9`",
},
{
name: "create tables with foreign keys, with invalid fk reference",
from: "create table t (id int primary key)",
to: `
create table t (id int primary key);
create table t11 (id int primary key, i int, constraint f1101b foreign key (i) references t12 (id) on delete restrict);
create table t12 (id int primary key, i int, constraint f1201b foreign key (i) references t9 (id) on delete set null);
`,
expectError: "table `t12` foreign key references nonexistent table `t9`",
fkStrategy: ForeignKeyCheckStrategyIgnore,
},
{
name: "create tables with foreign keys, with valid cycle",
from: "create table t (id int primary key)",
to: `
create table t (id int primary key);
create table t11 (id int primary key, i int, constraint f1101c foreign key (i) references t12 (id) on delete restrict);
create table t12 (id int primary key, i int, constraint f1201c foreign key (i) references t11 (id) on delete set null);
`,
diffs: []string{
"create table t11 (\n\tid int,\n\ti int,\n\tprimary key (id),\n\tkey f1101c (i),\n\tconstraint f1101c foreign key (i) references t12 (id) on delete restrict\n)",
"create table t12 (\n\tid int,\n\ti int,\n\tprimary key (id),\n\tkey f1201c (i),\n\tconstraint f1201c foreign key (i) references t11 (id) on delete set null\n)",
},
cdiffs: []string{
"CREATE TABLE `t11` (\n\t`id` int,\n\t`i` int,\n\tPRIMARY KEY (`id`),\n\tKEY `f1101c` (`i`),\n\tCONSTRAINT `f1101c` FOREIGN KEY (`i`) REFERENCES `t12` (`id`) ON DELETE RESTRICT\n)",
"CREATE TABLE `t12` (\n\t`id` int,\n\t`i` int,\n\tPRIMARY KEY (`id`),\n\tKEY `f1201c` (`i`),\n\tCONSTRAINT `f1201c` FOREIGN KEY (`i`) REFERENCES `t11` (`id`) ON DELETE SET NULL\n)",
},
fkStrategy: ForeignKeyCheckStrategyIgnore,
},
{
name: "drop tables with foreign keys, expect specific order",
from: "create table t7(id int primary key); create table t5 (id int primary key, i int, constraint f5 foreign key (i) references t7(id)); create table t4 (id int primary key, i int, constraint f4 foreign key (i) references t7(id));",
Expand Down Expand Up @@ -932,14 +972,15 @@ func TestDiffSchemas(t *testing.T) {
for _, ts := range tt {
t.Run(ts.name, func(t *testing.T) {
hints := &DiffHints{
TableRenameStrategy: ts.tableRename,
TableRenameStrategy: ts.tableRename,
ForeignKeyCheckStrategy: ts.fkStrategy,
}
diff, err := DiffSchemasSQL(env, ts.from, ts.to, hints)
if ts.expectError != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), ts.expectError)
} else {
assert.NoError(t, err)
require.NoError(t, err)

diffs, err := diff.OrderedDiffs(ctx)
assert.NoError(t, err)
Expand Down Expand Up @@ -1024,7 +1065,7 @@ func TestSchemaApplyError(t *testing.T) {
to: "create table t(id int); create view v1 as select * from t; create view v2 as select * from t",
},
}
hints := &DiffHints{}
hints := EmptyDiffHints()
env := NewTestEnv()
for _, ts := range tt {
t.Run(ts.name, func(t *testing.T) {
Expand Down
25 changes: 20 additions & 5 deletions go/vt/schemadiff/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,16 +288,31 @@ func (e *ForeignKeyDependencyUnresolvedError) Error() string {

type ForeignKeyLoopError struct {
Table string
Loop []string
Loop []*ForeignKeyTableColumns
}

func (e *ForeignKeyLoopError) Error() string {
tableIsInsideLoop := false

escaped := make([]string, len(e.Loop))
for i, t := range e.Loop {
escaped[i] = sqlescape.EscapeID(t)
if t == e.Table {
loop := e.Loop
// The tables in the loop could be e.g.:
// t1->t2->a->b->c->a
// In such case, the loop is a->b->c->a. The last item is always the head & tail of the loop.
// We want to distinguish between the case where the table is inside the loop and the case where it's outside,
// so we remove the prefix of the loop that doesn't participate in the actual cycle.
if len(loop) > 0 {
last := loop[len(loop)-1]
for i := range loop {
if loop[i].Table == last.Table {
loop = loop[i:]
break
}
}
}
escaped := make([]string, len(loop))
for i, fk := range loop {
escaped[i] = fk.Escaped()
if fk.Table == e.Table {
tableIsInsideLoop = true
}
}
Expand Down
Loading

0 comments on commit 46975b2

Please sign in to comment.