diff --git a/go/vt/schemadiff/diff.go b/go/vt/schemadiff/diff.go index d65b9631515..88eb6a7b326 100644 --- a/go/vt/schemadiff/diff.go +++ b/go/vt/schemadiff/diff.go @@ -174,3 +174,25 @@ func DiffSchemas(env *Environment, schema1 *Schema, schema2 *Schema, hints *Diff } return schema1.SchemaDiff(schema2, hints) } + +// EntityDiffByStatement is a helper function that returns a simplified and incomplete EntityDiff based on the given SQL statement. +// It is useful for testing purposes as a quick mean to wrap a statement with a diff. +func EntityDiffByStatement(statement sqlparser.Statement) EntityDiff { + switch stmt := statement.(type) { + case *sqlparser.CreateTable: + return &CreateTableEntityDiff{createTable: stmt} + case *sqlparser.RenameTable: + return &RenameTableEntityDiff{renameTable: stmt} + case *sqlparser.AlterTable: + return &AlterTableEntityDiff{alterTable: stmt} + case *sqlparser.DropTable: + return &DropTableEntityDiff{dropTable: stmt} + case *sqlparser.CreateView: + return &CreateViewEntityDiff{createView: stmt} + case *sqlparser.AlterView: + return &AlterViewEntityDiff{alterView: stmt} + case *sqlparser.DropView: + return &DropViewEntityDiff{dropView: stmt} + } + return nil +} diff --git a/go/vt/schemadiff/diff_test.go b/go/vt/schemadiff/diff_test.go index 9a74e6c0a32..7d0aa60e69c 100644 --- a/go/vt/schemadiff/diff_test.go +++ b/go/vt/schemadiff/diff_test.go @@ -1158,3 +1158,43 @@ func TestSchemaApplyError(t *testing.T) { }) } } + +func TestEntityDiffByStatement(t *testing.T) { + env := NewTestEnv() + + { + queries := []string{ + "create table t1(id int primary key)", + "alter table t1 add column i int", + "rename table t1 to t2", + "drop table t1", + "create view v1 as select * from t1", + "alter view v1 as select * from t2", + "drop view v1", + } + for _, query := range queries { + t.Run(query, func(t *testing.T) { + stmt, err := env.Parser().ParseStrictDDL(query) + require.NoError(t, err) + entityDiff := EntityDiffByStatement(stmt) + require.NotNil(t, entityDiff) + require.NotNil(t, entityDiff.Statement()) + require.Equal(t, stmt, entityDiff.Statement()) + }) + } + } + { + queries := []string{ + "drop database d1", + "optimize table t1", + } + for _, query := range queries { + t.Run(query, func(t *testing.T) { + stmt, err := env.Parser().ParseStrictDDL(query) + require.NoError(t, err) + entityDiff := EntityDiffByStatement(stmt) + require.Nil(t, entityDiff) + }) + } + } +}