diff --git a/go/test/endtoend/onlineddl/vrepl_suite/onlineddl_vrepl_suite_test.go b/go/test/endtoend/onlineddl/vrepl_suite/onlineddl_vrepl_suite_test.go
index 57397ec64dd..c82b7f13a0d 100644
--- a/go/test/endtoend/onlineddl/vrepl_suite/onlineddl_vrepl_suite_test.go
+++ b/go/test/endtoend/onlineddl/vrepl_suite/onlineddl_vrepl_suite_test.go
@@ -65,6 +65,7 @@ const (
 	testFilterEnvVar = "ONLINEDDL_SUITE_TEST_FILTER"
 )
 
+// Use $ONLINEDDL_SUITE_TEST_FILTER environment variable to filter tests by name.
 func TestMain(m *testing.M) {
 	defer cluster.PanicHandler(nil)
 	flag.Parse()
diff --git a/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-different-pk-new-pk-column/expect_failure b/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-different-pk-new-pk-column/expect_failure
index ae3584915dd..5e227f16a3c 100644
--- a/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-different-pk-new-pk-column/expect_failure
+++ b/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-different-pk-new-pk-column/expect_failure
@@ -1 +1 @@
-Found no possible
+found no possible
diff --git a/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-drop-pk/expect_failure b/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-drop-pk/expect_failure
index ae3584915dd..5e227f16a3c 100644
--- a/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-drop-pk/expect_failure
+++ b/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-drop-pk/expect_failure
@@ -1 +1 @@
-Found no possible
+found no possible
diff --git a/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-float-unique-key/create.sql b/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-float-unique-key/create.sql
index abd7fbd4266..3712a673838 100644
--- a/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-float-unique-key/create.sql
+++ b/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-float-unique-key/create.sql
@@ -1,6 +1,6 @@
 drop table if exists onlineddl_test;
 create table onlineddl_test (
-  f float,
+  f float not null,
   i int not null,
   ts timestamp default current_timestamp,
   dt datetime,
diff --git a/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-float-unique-key/expect_failure b/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-float-unique-key/expect_failure
index ae3584915dd..5e227f16a3c 100644
--- a/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-float-unique-key/expect_failure
+++ b/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-float-unique-key/expect_failure
@@ -1 +1 @@
-Found no possible
+found no possible
diff --git a/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-no-unique-key/expect_failure b/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-no-unique-key/expect_failure
index ae3584915dd..5e227f16a3c 100644
--- a/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-no-unique-key/expect_failure
+++ b/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-no-unique-key/expect_failure
@@ -1 +1 @@
-Found no possible
+found no possible
diff --git a/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-nullable-unique-key/alter b/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-nullable-unique-key/alter
new file mode 100644
index 00000000000..0d2477f5801
--- /dev/null
+++ b/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-nullable-unique-key/alter
@@ -0,0 +1 @@
+add column v varchar(32)
diff --git a/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-nullable-unique-key/create.sql b/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-nullable-unique-key/create.sql
new file mode 100644
index 00000000000..71f112d33c2
--- /dev/null +++ b/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-nullable-unique-key/create.sql @@ -0,0 +1,11 @@ +drop table if exists onlineddl_test; +create table onlineddl_test ( + id int, + i int not null, + ts timestamp default current_timestamp, + dt datetime, + key i_idx(i), + unique key id_uidx(id) +) auto_increment=1; + +drop event if exists onlineddl_test; diff --git a/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-nullable-unique-key/expect_failure b/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-nullable-unique-key/expect_failure new file mode 100644 index 00000000000..5e227f16a3c --- /dev/null +++ b/go/test/endtoend/onlineddl/vrepl_suite/testdata/fail-nullable-unique-key/expect_failure @@ -0,0 +1 @@ +found no possible diff --git a/go/vt/schemadiff/capability.go b/go/vt/schemadiff/capability.go index 2a3e2d97c9b..1471599d390 100644 --- a/go/vt/schemadiff/capability.go +++ b/go/vt/schemadiff/capability.go @@ -61,18 +61,6 @@ func alterOptionCapableOfInstantDDL(alterOption sqlparser.AlterOption, createTab } } - isGeneratedColumn := func(col *sqlparser.ColumnDefinition) (bool, sqlparser.ColumnStorage) { - if col == nil { - return false, 0 - } - if col.Type.Options == nil { - return false, 0 - } - if col.Type.Options.As == nil { - return false, 0 - } - return true, col.Type.Options.Storage - } colStringStrippedDown := func(col *sqlparser.ColumnDefinition, stripEnum bool) string { strippedCol := sqlparser.Clone(col) // strip `default` @@ -153,7 +141,7 @@ func alterOptionCapableOfInstantDDL(alterOption sqlparser.AlterOption, createTab return false, nil } for _, column := range opt.Columns { - if isGenerated, storage := isGeneratedColumn(column); isGenerated { + if isGenerated, storage := IsGeneratedColumn(column); isGenerated { if storage == sqlparser.StoredStorage { // Adding a generated "STORED" column is unsupported return false, nil @@ -188,7 +176,7 @@ func alterOptionCapableOfInstantDDL(alterOption sqlparser.AlterOption, createTab // not supported if the column is part of an index return false, nil } - if isGenerated, _ := isGeneratedColumn(col); isGenerated { + if isGenerated, _ := IsGeneratedColumn(col); isGenerated { // supported by all 8.0 versions // Note: according to the docs dropping a STORED generated column is not INSTANT-able, // but in practice this is supported. This is why we don't test for STORED here, like diff --git a/go/vt/schemadiff/column.go b/go/vt/schemadiff/column.go index 7e55192cb06..63181cef9cb 100644 --- a/go/vt/schemadiff/column.go +++ b/go/vt/schemadiff/column.go @@ -20,6 +20,7 @@ import ( "strings" "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/ptr" "vitess.io/vitess/go/vt/sqlparser" ) @@ -71,63 +72,100 @@ func NewModifyColumnDiffByDefinition(definition *sqlparser.ColumnDefinition) *Mo } type ColumnDefinitionEntity struct { - columnDefinition *sqlparser.ColumnDefinition + ColumnDefinition *sqlparser.ColumnDefinition + inPK bool // Does this column appear in the primary key? 
 	tableCharsetCollate *charsetCollate
 	Env                 *Environment
 }
 
-func NewColumnDefinitionEntity(env *Environment, c *sqlparser.ColumnDefinition, tableCharsetCollate *charsetCollate) *ColumnDefinitionEntity {
+func NewColumnDefinitionEntity(env *Environment, c *sqlparser.ColumnDefinition, inPK bool, tableCharsetCollate *charsetCollate) *ColumnDefinitionEntity {
 	return &ColumnDefinitionEntity{
-		columnDefinition:    c,
+		ColumnDefinition:    c,
+		inPK:                inPK,
 		tableCharsetCollate: tableCharsetCollate,
 		Env:                 env,
 	}
 }
 
+func (c *ColumnDefinitionEntity) Name() string {
+	return c.ColumnDefinition.Name.String()
+}
+
+func (c *ColumnDefinitionEntity) NameLowered() string {
+	return c.ColumnDefinition.Name.Lowered()
+}
+
 func (c *ColumnDefinitionEntity) Clone() *ColumnDefinitionEntity {
 	clone := &ColumnDefinitionEntity{
-		columnDefinition:    sqlparser.Clone(c.columnDefinition),
+		ColumnDefinition:    sqlparser.Clone(c.ColumnDefinition),
+		inPK:                c.inPK,
 		tableCharsetCollate: c.tableCharsetCollate,
 		Env:                 c.Env,
 	}
 	return clone
 }
 
+// SetExplicitDefaultAndNull sets:
+// - NOT NULL, if the column is part of the PRIMARY KEY
+// - DEFAULT NULL, if the column is NULLable and no DEFAULT was mentioned
+// Normally in schemadiff we work the opposite way: we strive to have the minimal equivalent representation
+// of a definition. But this function can be used (often in conjunction with Clone()) to enrich a column definition
+// so as to have an explicit and authoritative view on any particular column.
+func (c *ColumnDefinitionEntity) SetExplicitDefaultAndNull() {
+	if c.inPK {
+		// Any column in the primary key is implicitly NOT NULL.
+		c.ColumnDefinition.Type.Options.Null = ptr.Of(false)
+	}
+	if c.ColumnDefinition.Type.Options.Null == nil || *c.ColumnDefinition.Type.Options.Null {
+		// Nullable column, let's see if there's already a DEFAULT.
+		if c.ColumnDefinition.Type.Options.Default == nil {
+			// nope, let's add a DEFAULT NULL
+			c.ColumnDefinition.Type.Options.Default = &sqlparser.NullVal{}
+		}
+	}
+}
+
 // SetExplicitCharsetCollate enriches this column definition with collation and charset. Those may be
 // already present, or perhaps just one of them is present (in which case we use the one to populate the other),
 // or both might be missing, in which case we use the table's charset/collation.
+// Normally in schemadiff we work the opposite way: we strive to have the minimal equivalent representation
+// of a definition. But this function can be used (often in conjunction with Clone()) to enrich a column definition
+// so as to have an explicit and authoritative view on any particular column.
 func (c *ColumnDefinitionEntity) SetExplicitCharsetCollate() error {
 	if !c.IsTextual() {
 		return nil
 	}
 	// We will now denormalize the columns charset & collate as needed (if empty, populate from table.)
 	// Normalizing _this_ column definition:
-	if c.columnDefinition.Type.Charset.Name != "" && c.columnDefinition.Type.Options.Collate == "" {
+	if c.ColumnDefinition.Type.Charset.Name != "" && c.ColumnDefinition.Type.Options.Collate == "" {
 		// Charset defined without collation. Assign the default collation for that charset.
-		collation := c.Env.CollationEnv().DefaultCollationForCharset(c.columnDefinition.Type.Charset.Name)
+		collation := c.Env.CollationEnv().DefaultCollationForCharset(c.ColumnDefinition.Type.Charset.Name)
 		if collation == collations.Unknown {
-			return &UnknownColumnCharsetCollationError{Column: c.columnDefinition.Name.String(), Charset: c.tableCharsetCollate.charset}
+			return &UnknownColumnCharsetCollationError{Column: c.ColumnDefinition.Name.String(), Charset: c.tableCharsetCollate.charset}
 		}
-		c.columnDefinition.Type.Options.Collate = c.Env.CollationEnv().LookupName(collation)
+		c.ColumnDefinition.Type.Options.Collate = c.Env.CollationEnv().LookupName(collation)
 	}
-	if c.columnDefinition.Type.Charset.Name == "" && c.columnDefinition.Type.Options.Collate != "" {
+	if c.ColumnDefinition.Type.Charset.Name == "" && c.ColumnDefinition.Type.Options.Collate != "" {
 		// Column has explicit collation but no charset. We can infer the charset from the collation.
-		collationID := c.Env.CollationEnv().LookupByName(c.columnDefinition.Type.Options.Collate)
+		collationID := c.Env.CollationEnv().LookupByName(c.ColumnDefinition.Type.Options.Collate)
 		charset := c.Env.CollationEnv().LookupCharsetName(collationID)
 		if charset == "" {
-			return &UnknownColumnCollationCharsetError{Column: c.columnDefinition.Name.String(), Collation: c.columnDefinition.Type.Options.Collate}
+			return &UnknownColumnCollationCharsetError{Column: c.ColumnDefinition.Name.String(), Collation: c.ColumnDefinition.Type.Options.Collate}
 		}
-		c.columnDefinition.Type.Charset.Name = charset
+		c.ColumnDefinition.Type.Charset.Name = charset
 	}
-	if c.columnDefinition.Type.Charset.Name == "" {
+	if c.ColumnDefinition.Type.Charset.Name == "" {
 		// Still nothing? Assign the table's charset/collation.
-		c.columnDefinition.Type.Charset.Name = c.tableCharsetCollate.charset
-		if c.columnDefinition.Type.Options.Collate = c.tableCharsetCollate.collate; c.columnDefinition.Type.Options.Collate == "" {
+		c.ColumnDefinition.Type.Charset.Name = c.tableCharsetCollate.charset
+		if c.ColumnDefinition.Type.Options.Collate == "" {
+			c.ColumnDefinition.Type.Options.Collate = c.tableCharsetCollate.collate
+		}
+		if c.ColumnDefinition.Type.Options.Collate == "" {
 			collation := c.Env.CollationEnv().DefaultCollationForCharset(c.tableCharsetCollate.charset)
 			if collation == collations.Unknown {
-				return &UnknownColumnCharsetCollationError{Column: c.columnDefinition.Name.String(), Charset: c.tableCharsetCollate.charset}
+				return &UnknownColumnCharsetCollationError{Column: c.ColumnDefinition.Name.String(), Charset: c.tableCharsetCollate.charset}
 			}
-			c.columnDefinition.Type.Options.Collate = c.Env.CollationEnv().LookupName(collation)
+			c.ColumnDefinition.Type.Options.Collate = c.Env.CollationEnv().LookupName(collation)
 		}
 	}
 	return nil
@@ -168,7 +206,7 @@ func (c *ColumnDefinitionEntity) ColumnDiff(
 		}
 	}
 
-	if sqlparser.Equals.RefOfColumnDefinition(cClone.columnDefinition, otherClone.columnDefinition) {
+	if sqlparser.Equals.RefOfColumnDefinition(cClone.ColumnDefinition, otherClone.ColumnDefinition) {
 		return nil, nil
 	}
 
@@ -181,19 +219,228 @@ func (c *ColumnDefinitionEntity) ColumnDiff(
 	}
 	switch hints.EnumReorderStrategy {
 	case EnumReorderStrategyReject:
-		otherEnumValuesMap := getEnumValuesMap(otherClone.columnDefinition.Type.EnumValues)
-		for ordinal, enumValue := range cClone.columnDefinition.Type.EnumValues {
+		otherEnumValuesMap := getEnumValuesMap(otherClone.ColumnDefinition.Type.EnumValues)
+		for ordinal, enumValue := range cClone.ColumnDefinition.Type.EnumValues {
 			if otherOrdinal, ok := otherEnumValuesMap[enumValue]; ok {
 				if ordinal != otherOrdinal {
-					return nil, &EnumValueOrdinalChangedError{Table: tableName, Column: cClone.columnDefinition.Name.String(), Value: enumValue, Ordinal: ordinal, NewOrdinal: otherOrdinal}
+					return nil, &EnumValueOrdinalChangedError{Table: tableName, Column: cClone.ColumnDefinition.Name.String(), Value: enumValue, Ordinal: ordinal, NewOrdinal: otherOrdinal}
 				}
 			}
 		}
 	}
-	return NewModifyColumnDiffByDefinition(other.columnDefinition), nil
+	return NewModifyColumnDiffByDefinition(other.ColumnDefinition), nil
+}
+
+// Type returns the column's type
+func (c *ColumnDefinitionEntity) Type() string {
+	return c.ColumnDefinition.Type.Type
 }
 
 // IsTextual returns true when this column is of textual type, and is capable of having a character set property
 func (c *ColumnDefinitionEntity) IsTextual() bool {
-	return charsetTypes[strings.ToLower(c.columnDefinition.Type.Type)]
+	return charsetTypes[strings.ToLower(c.Type())]
+}
+
+// IsGeneratedColumn returns true when the given column is generated, and indicates the storage type (virtual/stored)
+func IsGeneratedColumn(col *sqlparser.ColumnDefinition) (bool, sqlparser.ColumnStorage) {
+	if col == nil {
+		return false, 0
+	}
+	if col.Type.Options == nil {
+		return false, 0
+	}
+	if col.Type.Options.As == nil {
+		return false, 0
+	}
+	return true, col.Type.Options.Storage
+}
+
+// IsGenerated returns true when this column is generated
+func (c *ColumnDefinitionEntity) IsGenerated() bool {
+	isGenerated, _ := IsGeneratedColumn(c.ColumnDefinition)
+	return isGenerated
+}
+
+// IsNullable returns true when this column is NULLable
+func (c *ColumnDefinitionEntity) IsNullable() bool {
+	if c.inPK {
+		return false
+	}
+	return c.ColumnDefinition.Type.Options.Null == nil || *c.ColumnDefinition.Type.Options.Null
+}
+
+// IsDefaultNull returns true when this column has DEFAULT NULL
+func (c *ColumnDefinitionEntity) IsDefaultNull() bool {
+	if !c.IsNullable() {
+		return false
+	}
+	_, ok := c.ColumnDefinition.Type.Options.Default.(*sqlparser.NullVal)
+	return ok
+}
+
+// HasDefault returns true when this column has a DEFAULT value, including DEFAULT NULL
+func (c *ColumnDefinitionEntity) HasDefault() bool {
+	return c.ColumnDefinition.Type.Options.Default != nil
+}
+
+// IsAutoIncrement returns true when this column is AUTO_INCREMENT
+func (c *ColumnDefinitionEntity) IsAutoIncrement() bool {
+	return c.ColumnDefinition.Type.Options.Autoincrement
+}
+
+// IsUnsigned returns true when this column is UNSIGNED
+func (c *ColumnDefinitionEntity) IsUnsigned() bool {
+	return c.ColumnDefinition.Type.Unsigned
+}
+
+// IsIntegralType returns true when this column is an integral type
+func (c *ColumnDefinitionEntity) IsIntegralType() bool {
+	return IsIntegralType(c.Type())
+}
+
+// IsFloatingPointType returns true when this column is a floating point type
+func (c *ColumnDefinitionEntity) IsFloatingPointType() bool {
+	return IsFloatingPointType(c.Type())
+}
+
+// IsDecimalType returns true when this column is a decimal type
+func (c *ColumnDefinitionEntity) IsDecimalType() bool {
+	return IsDecimalType(c.Type())
+}
+
+// HasBlobTypeStorage returns true when this column is a text/blob type
+func (c *ColumnDefinitionEntity) HasBlobTypeStorage() bool {
+	return BlobTypeStorage(c.Type()) != 0
+}
+
+// Charset returns the column's charset
+func (c *ColumnDefinitionEntity) Charset() string {
+	return c.ColumnDefinition.Type.Charset.Name
+}
+
+// Collate returns the column's collation
+func (c *ColumnDefinitionEntity) Collate() string {
+	return c.ColumnDefinition.Type.Options.Collate
+}
+
+func (c *ColumnDefinitionEntity) EnumValues() []string {
+	return c.ColumnDefinition.Type.EnumValues
+}
+
+func (c *ColumnDefinitionEntity) HasEnumValues() bool {
+	return len(c.EnumValues()) > 0
+}
+
+// EnumValuesOrdinals returns a map of enum values to their ordinals
+func (c *ColumnDefinitionEntity) EnumValuesOrdinals() map[string]int {
+	m := make(map[string]int, len(c.ColumnDefinition.Type.EnumValues))
+	for i, enumValue := range c.ColumnDefinition.Type.EnumValues {
+		m[enumValue] = i + 1
+	}
+	return m
+}
+
+// EnumOrdinalValues returns a map of enum ordinals to their values
+func (c *ColumnDefinitionEntity) EnumOrdinalValues() map[int]string {
+	m := make(map[int]string, len(c.ColumnDefinition.Type.EnumValues))
+	for i, enumValue := range c.ColumnDefinition.Type.EnumValues {
+		// SET and ENUM values are 1 indexed.
+		m[i+1] = enumValue
+	}
+	return m
+}
+
+// Length returns the type length (e.g. 17 for VARCHAR(17), 10 for DECIMAL(10,2), 6 for TIMESTAMP(6), etc.)
+func (c *ColumnDefinitionEntity) Length() int {
+	if c.ColumnDefinition.Type.Length == nil {
+		return 0
+	}
+	return *c.ColumnDefinition.Type.Length
+}
+
+// Scale returns the type scale (e.g. 2 for DECIMAL(10,2))
+func (c *ColumnDefinitionEntity) Scale() int {
+	if c.ColumnDefinition.Type.Scale == nil {
+		return 0
+	}
+	return *c.ColumnDefinition.Type.Scale
+}
+
+// ColumnDefinitionEntityList is a formalized list of ColumnDefinitionEntity, with some
+// utility functions.
+type ColumnDefinitionEntityList struct {
+	Entities []*ColumnDefinitionEntity
+	byName   map[string]*ColumnDefinitionEntity
+}
+
+func NewColumnDefinitionEntityList(entities []*ColumnDefinitionEntity) *ColumnDefinitionEntityList {
+	list := &ColumnDefinitionEntityList{
+		Entities: entities,
+		byName:   make(map[string]*ColumnDefinitionEntity),
+	}
+	for _, entity := range entities {
+		list.byName[entity.Name()] = entity
+		list.byName[entity.NameLowered()] = entity
+	}
+	return list
+}
+
+func (l *ColumnDefinitionEntityList) Len() int {
+	return len(l.Entities)
+}
+
+// Names returns the names of all the columns in this list
+func (l *ColumnDefinitionEntityList) Names() []string {
+	names := make([]string, len(l.Entities))
+	for i, entity := range l.Entities {
+		names[i] = entity.Name()
+	}
+	return names
+}
+
+// GetColumn returns the column with the given name, or nil if not found
+func (l *ColumnDefinitionEntityList) GetColumn(name string) *ColumnDefinitionEntity {
+	return l.byName[name]
+}
+
+// Contains returns true when this list contains all the entities from the other list
+func (l *ColumnDefinitionEntityList) Contains(other *ColumnDefinitionEntityList) bool {
+	for _, entity := range other.Entities {
+		if l.GetColumn(entity.NameLowered()) == nil {
+			return false
+		}
+	}
+	return true
+}
+
+// Union returns a new ColumnDefinitionEntityList with all the entities from this list and the other list
+func (l *ColumnDefinitionEntityList) Union(other *ColumnDefinitionEntityList) *ColumnDefinitionEntityList {
+	// Copy into a fresh slice so the result never aliases this list's backing array.
+	entities := make([]*ColumnDefinitionEntity, 0, len(l.Entities)+len(other.Entities))
+	entities = append(entities, l.Entities...)
+	entities = append(entities, other.Entities...)
+ return NewColumnDefinitionEntityList(entities) +} + +// Clone creates a copy of this list, with copies of the entities +func (l *ColumnDefinitionEntityList) Clone() *ColumnDefinitionEntityList { + entities := make([]*ColumnDefinitionEntity, len(l.Entities)) + for i, entity := range l.Entities { + entities[i] = entity.Clone() + } + return NewColumnDefinitionEntityList(entities) +} + +// Filter returns a new subset ColumnDefinitionEntityList with only the entities that pass the filter +func (l *ColumnDefinitionEntityList) Filter(include func(entity *ColumnDefinitionEntity) bool) *ColumnDefinitionEntityList { + var entities []*ColumnDefinitionEntity + for _, entity := range l.Entities { + if include(entity) { + entities = append(entities, entity) + } + } + return NewColumnDefinitionEntityList(entities) } diff --git a/go/vt/schemadiff/column_test.go b/go/vt/schemadiff/column_test.go new file mode 100644 index 00000000000..f1b8f9e4f75 --- /dev/null +++ b/go/vt/schemadiff/column_test.go @@ -0,0 +1,497 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package schemadiff + +import ( + "fmt" + "testing" + + "golang.org/x/exp/maps" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestColumnFunctions(t *testing.T) { + table := ` + create table t ( + id int, + col1 int, + col2 int not null, + col3 int default null, + col4 int default 1, + COL5 int default 1, + ts1 timestamp, + ts2 timestamp(3) null, + ts3 timestamp(6) not null, + primary key (id) + )` + env := NewTestEnv() + createTableEntity, err := NewCreateTableEntityFromSQL(env, table) + require.NoError(t, err) + m := createTableEntity.ColumnDefinitionEntitiesMap() + for _, col := range m { + col.SetExplicitDefaultAndNull() + err := col.SetExplicitCharsetCollate() + require.NoError(t, err) + } + + t.Run("nullable", func(t *testing.T) { + assert.False(t, m["id"].IsNullable()) + assert.True(t, m["col1"].IsNullable()) + assert.False(t, m["col2"].IsNullable()) + assert.True(t, m["col3"].IsNullable()) + assert.True(t, m["col4"].IsNullable()) + assert.True(t, m["col5"].IsNullable()) + assert.True(t, m["ts1"].IsNullable()) + assert.True(t, m["ts2"].IsNullable()) + assert.False(t, m["ts3"].IsNullable()) + }) + t.Run("default null", func(t *testing.T) { + assert.False(t, m["id"].IsDefaultNull()) + assert.True(t, m["col1"].IsDefaultNull()) + assert.False(t, m["col2"].IsDefaultNull()) + assert.True(t, m["col3"].IsDefaultNull()) + assert.False(t, m["col4"].IsDefaultNull()) + assert.False(t, m["col5"].IsDefaultNull()) + assert.True(t, m["ts1"].IsDefaultNull()) + assert.True(t, m["ts2"].IsDefaultNull()) + assert.False(t, m["ts3"].IsDefaultNull()) + }) + t.Run("has default", func(t *testing.T) { + assert.False(t, m["id"].HasDefault()) + assert.True(t, m["col1"].HasDefault()) + assert.False(t, m["col2"].HasDefault()) + assert.True(t, m["col3"].HasDefault()) + assert.True(t, m["col4"].HasDefault()) + assert.True(t, m["col5"].HasDefault()) + assert.True(t, m["ts1"].HasDefault()) + assert.True(t, 
m["ts2"].HasDefault()) + assert.False(t, m["ts3"].HasDefault()) + }) +} + +func TestExpands(t *testing.T) { + tcases := []struct { + source string + target string + expands bool + msg string + }{ + { + source: "int", + target: "int", + }, + { + source: "int", + target: "smallint", + }, + { + source: "int", + target: "smallint unsigned", + }, + { + source: "int unsigned", + target: "tinyint", + expands: true, + msg: "source is unsigned, target is signed", + }, + { + source: "int unsigned", + target: "tinyint signed", + expands: true, + msg: "source is unsigned, target is signed", + }, + { + source: "int", + target: "tinyint", + }, + { + source: "int", + target: "bigint", + expands: true, + msg: "increased integer range", + }, + { + source: "int", + target: "bigint unsigned", + expands: true, + msg: "increased integer range", + }, + { + source: "int", + target: "int unsigned", + expands: true, + msg: "target unsigned value exceeds source unsigned value", + }, + { + source: "int unsigned", + target: "int", + expands: true, + msg: "source is unsigned, target is signed", + }, + { + source: "int", + target: "int default null", + }, + { + source: "int default null", + target: "int", + }, + { + source: "int", + target: "int not null", + }, + { + source: "int not null", + target: "int", + expands: true, + msg: "target is NULL-able, source is not", + }, + { + source: "int not null", + target: "int default null", + expands: true, + msg: "target is NULL-able, source is not", + }, + { + source: "float", + target: "int", + }, + { + source: "int", + target: "float", + expands: true, + msg: "target is floating point, source is not", + }, + { + source: "float", + target: "double", + expands: true, + msg: "increased floating point range", + }, + { + source: "decimal(5,2)", + target: "float", + expands: true, + msg: "target is floating point, source is not", + }, + { + source: "int", + target: "decimal", + expands: true, + msg: "target is decimal, source is not", + }, + { + source: "int", + target: "decimal(5,2)", + expands: true, + msg: "increased length", + }, + { + source: "int", + target: "decimal(5,0)", + expands: true, + msg: "increased length", + }, + { + source: "decimal(5,2)", // 123.45 + target: "decimal(3,2)", // 1.23 + }, + { + source: "decimal(5,2)", // 123.45 + target: "decimal(4,1)", // 123.4 + }, + { + source: "decimal(5,2)", // 123.45 + target: "decimal(5,1)", // 1234.5 + expands: true, + msg: "increased decimal range", + }, + { + source: "char(7)", + target: "char(7)", + }, + { + source: "char(7)", + target: "varchar(7)", + }, + { + source: "char(7)", + target: "varchar(5)", + }, + { + source: "char(5)", + target: "varchar(7)", + expands: true, + msg: "increased length", + }, + { + source: "varchar(5)", + target: "char(7)", + expands: true, + msg: "increased length", + }, + { + source: "tinytext", + target: "tinytext", + }, + { + source: "tinytext", + target: "tinyblob", + }, + { + source: "mediumtext", + target: "tinytext", + }, + { + source: "mediumblob", + target: "tinytext", + }, + { + source: "tinytext", + target: "text", + expands: true, + msg: "increased blob range", + }, + { + source: "tinytext", + target: "mediumblob", + expands: true, + msg: "increased blob range", + }, + { + source: "timestamp", + target: "timestamp", + }, + { + source: "timestamp", + target: "time", + }, + { + source: "datetime", + target: "timestamp", + }, + { + source: "datetime", + target: "date", + }, + { + source: "time", + target: "timestamp", + expands: true, + msg: "target is expanded data type of 
source", + }, + { + source: "timestamp", + target: "datetime", + expands: true, + msg: "target is expanded data type of source", + }, + { + source: "date", + target: "datetime", + expands: true, + msg: "target is expanded data type of source", + }, + { + source: "timestamp", + target: "timestamp(3)", + expands: true, + msg: "increased length", + }, + { + source: "timestamp", + target: "timestamp(6)", + expands: true, + msg: "increased length", + }, + { + source: "timestamp(3)", + target: "timestamp(6)", + expands: true, + msg: "increased length", + }, + { + source: "timestamp(6)", + target: "timestamp(3)", + }, + { + source: "timestamp(6)", + target: "timestamp", + }, + { + source: "timestamp", + target: "time(3)", + expands: true, + msg: "increased length", + }, + { + source: "datetime", + target: "time(3)", + expands: true, + msg: "increased length", + }, + { + source: "enum('a','b')", + target: "enum('a','b')", + }, + { + source: "enum('a','b')", + target: "enum('a')", + }, + { + source: "enum('a','b')", + target: "enum('b')", + expands: true, + msg: "target enum/set expands or reorders source enum/set", + }, + { + source: "enum('a','b')", + target: "enum('a','b','c')", + expands: true, + msg: "target enum/set expands or reorders source enum/set", + }, + { + source: "enum('a','b')", + target: "enum('a','x')", + expands: true, + msg: "target enum/set expands or reorders source enum/set", + }, + { + source: "set('a','b')", + target: "set('a','b')", + }, + { + source: "set('a','b')", + target: "set('a','b','c')", + expands: true, + msg: "target enum/set expands or reorders source enum/set", + }, + } + env := NewTestEnv() + for _, tcase := range tcases { + t.Run(tcase.source+" -> "+tcase.target, func(t *testing.T) { + fromCreateTableSQL := fmt.Sprintf("create table t (col %s)", tcase.source) + from, err := NewCreateTableEntityFromSQL(env, fromCreateTableSQL) + require.NoError(t, err) + + toCreateTableSQL := fmt.Sprintf("create table t (col %s)", tcase.target) + to, err := NewCreateTableEntityFromSQL(env, toCreateTableSQL) + require.NoError(t, err) + + require.Len(t, from.ColumnDefinitionEntities(), 1) + fromCol := from.ColumnDefinitionEntities()[0] + require.Len(t, to.ColumnDefinitionEntities(), 1) + toCol := to.ColumnDefinitionEntities()[0] + + expands, message := ColumnChangeExpandsDataRange(fromCol, toCol) + assert.Equal(t, tcase.expands, expands, message) + if expands { + require.NotEmpty(t, tcase.msg, message) + } + assert.Contains(t, message, tcase.msg) + }) + } +} + +func TestColumnDefinitionEntityList(t *testing.T) { + table := ` + create table t ( + id int, + col1 int, + Col2 int not null, + primary key (id) + )` + env := NewTestEnv() + createTableEntity, err := NewCreateTableEntityFromSQL(env, table) + require.NoError(t, err) + entities := createTableEntity.ColumnDefinitionEntities() + require.NotEmpty(t, entities) + list := NewColumnDefinitionEntityList(entities) + assert.NotNil(t, list.GetColumn("id")) + assert.NotNil(t, list.GetColumn("col1")) + assert.NotNil(t, list.GetColumn("Col2")) + assert.NotNil(t, list.GetColumn("col2")) // we also allow lower case + assert.Nil(t, list.GetColumn("COL2")) + assert.Nil(t, list.GetColumn("ID")) + assert.Nil(t, list.GetColumn("Col1")) + assert.Nil(t, list.GetColumn("col3")) +} + +func TestColumnDefinitionEntityListSubset(t *testing.T) { + table1 := ` + create table t ( + ID int, + col1 int, + Col2 int not null, + primary key (id) + )` + table2 := ` + create table t ( + id int, + Col1 int, + primary key (id) + )` + env := NewTestEnv() + 
createTableEntity1, err := NewCreateTableEntityFromSQL(env, table1)
+	require.NoError(t, err)
+	entities1 := createTableEntity1.ColumnDefinitionEntities()
+	require.NotEmpty(t, entities1)
+	list1 := NewColumnDefinitionEntityList(entities1)
+
+	createTableEntity2, err := NewCreateTableEntityFromSQL(env, table2)
+	require.NoError(t, err)
+	entities2 := createTableEntity2.ColumnDefinitionEntities()
+	require.NotEmpty(t, entities2)
+	list2 := NewColumnDefinitionEntityList(entities2)
+
+	assert.True(t, list1.Contains(list2))
+	assert.False(t, list2.Contains(list1))
+}
+
+func TestColumnDefinitionEntity(t *testing.T) {
+	table1 := `
+	create table t (
+		id int,
+		e enum('a','b','c'),
+		primary key (id)
+	)`
+	env := NewTestEnv()
+	createTableEntity1, err := NewCreateTableEntityFromSQL(env, table1)
+	require.NoError(t, err)
+	entities1 := createTableEntity1.ColumnDefinitionEntities()
+	require.NotEmpty(t, entities1)
+	list1 := NewColumnDefinitionEntityList(entities1)
+
+	t.Run("enum", func(t *testing.T) {
+		enumCol := list1.GetColumn("e")
+		require.NotNil(t, enumCol)
+		assert.Equal(t, []string{"'a'", "'b'", "'c'"}, enumCol.EnumValues())
+
+		{
+			ordinalsMap := enumCol.EnumValuesOrdinals()
+			assert.ElementsMatch(t, []int{1, 2, 3}, maps.Values(ordinalsMap))
+			assert.ElementsMatch(t, []string{"'a'", "'b'", "'c'"}, maps.Keys(ordinalsMap))
+		}
+		{
+			valuesMap := enumCol.EnumOrdinalValues()
+			assert.ElementsMatch(t, []int{1, 2, 3}, maps.Keys(valuesMap))
+			assert.ElementsMatch(t, []string{"'a'", "'b'", "'c'"}, maps.Values(valuesMap))
+		}
+	})
+}
diff --git a/go/vt/schemadiff/key.go b/go/vt/schemadiff/key.go
new file mode 100644
index 00000000000..865073a5a98
--- /dev/null
+++ b/go/vt/schemadiff/key.go
@@ -0,0 +1,162 @@
+/*
+Copyright 2024 The Vitess Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+	http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package schemadiff
+
+import (
+	"vitess.io/vitess/go/vt/sqlparser"
+)
+
+// IndexDefinitionEntity represents an index definition in a CREATE TABLE statement,
+// and includes the list of columns that are part of the index.
+type IndexDefinitionEntity struct {
+	IndexDefinition *sqlparser.IndexDefinition
+	ColumnList      *ColumnDefinitionEntityList
+	Env             *Environment
+}
+
+func NewIndexDefinitionEntity(env *Environment, indexDefinition *sqlparser.IndexDefinition, columnDefinitionEntitiesList *ColumnDefinitionEntityList) *IndexDefinitionEntity {
+	return &IndexDefinitionEntity{
+		IndexDefinition: indexDefinition,
+		ColumnList:      columnDefinitionEntitiesList,
+		Env:             env,
+	}
+}
+
+func (i *IndexDefinitionEntity) Name() string {
+	return i.IndexDefinition.Info.Name.String()
+}
+
+func (i *IndexDefinitionEntity) NameLowered() string {
+	return i.IndexDefinition.Info.Name.Lowered()
+}
+
+// Clone returns a copy of this index definition entity, with a copy of its column list.
+func (i *IndexDefinitionEntity) Clone() *IndexDefinitionEntity { + clone := &IndexDefinitionEntity{ + IndexDefinition: sqlparser.Clone(i.IndexDefinition), + ColumnList: i.ColumnList.Clone(), + Env: i.Env, + } + return clone +} + +func (i *IndexDefinitionEntity) Len() int { + return len(i.IndexDefinition.Columns) +} + +// IsPrimary returns true if the index is a primary key. +func (i *IndexDefinitionEntity) IsPrimary() bool { + return i.IndexDefinition.Info.Type == sqlparser.IndexTypePrimary +} + +// IsUnique returns true if the index is a unique key. +func (i *IndexDefinitionEntity) IsUnique() bool { + return i.IndexDefinition.Info.IsUnique() +} + +// HasNullable returns true if any of the columns in the index are nullable. +func (i *IndexDefinitionEntity) HasNullable() bool { + for _, col := range i.ColumnList.Entities { + if col.IsNullable() { + return true + } + } + return false +} + +// HasFloat returns true if any of the columns in the index are floating point types. +func (i *IndexDefinitionEntity) HasFloat() bool { + for _, col := range i.ColumnList.Entities { + if col.IsFloatingPointType() { + return true + } + } + return false +} + +// HasColumnPrefix returns true if any of the columns in the index have a length prefix. +func (i *IndexDefinitionEntity) HasColumnPrefix() bool { + for _, col := range i.IndexDefinition.Columns { + if col.Length != nil { + return true + } + } + return false +} + +// ColumnNames returns the names of the columns in the index. +func (i *IndexDefinitionEntity) ColumnNames() []string { + names := make([]string, 0, len(i.IndexDefinition.Columns)) + for _, col := range i.IndexDefinition.Columns { + names = append(names, col.Column.String()) + } + return names +} + +// ContainsColumns returns true if the index contains all the columns in the given list. +func (i *IndexDefinitionEntity) ContainsColumns(columns *ColumnDefinitionEntityList) bool { + return i.ColumnList.Contains(columns) +} + +// CoveredByColumns returns true if the index is covered by the given list of columns. +func (i *IndexDefinitionEntity) CoveredByColumns(columns *ColumnDefinitionEntityList) bool { + return columns.Contains(i.ColumnList) +} + +// IndexDefinitionEntityList is a formalized list of IndexDefinitionEntity objects with a few +// utility methods. +type IndexDefinitionEntityList struct { + Entities []*IndexDefinitionEntity +} + +func NewIndexDefinitionEntityList(entities []*IndexDefinitionEntity) *IndexDefinitionEntityList { + return &IndexDefinitionEntityList{ + Entities: entities, + } +} + +func (l *IndexDefinitionEntityList) Len() int { + return len(l.Entities) +} + +// Names returns the names of the indexes in the list. +func (l *IndexDefinitionEntityList) Names() []string { + names := make([]string, len(l.Entities)) + for i, entity := range l.Entities { + names[i] = entity.Name() + } + return names +} + +// SubsetCoveredByColumns returns a new list of indexes that are covered by the given list of columns. +func (l *IndexDefinitionEntityList) SubsetCoveredByColumns(columns *ColumnDefinitionEntityList) *IndexDefinitionEntityList { + var subset []*IndexDefinitionEntity + for _, entity := range l.Entities { + if entity.CoveredByColumns(columns) { + subset = append(subset, entity) + } + } + return NewIndexDefinitionEntityList(subset) +} + +// First returns the first index in the list, or nil if the list is empty. 
+func (l *IndexDefinitionEntityList) First() *IndexDefinitionEntity { + if len(l.Entities) == 0 { + return nil + } + return l.Entities[0] +} diff --git a/go/vt/schemadiff/key_test.go b/go/vt/schemadiff/key_test.go new file mode 100644 index 00000000000..f11d5589ab3 --- /dev/null +++ b/go/vt/schemadiff/key_test.go @@ -0,0 +1,185 @@ +/* +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package schemadiff + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestIndexDefinitionEntityMap(t *testing.T) { + table := ` + create table t ( + id int, + col1 int, + Col2 int not null, + col3 int not null default 3, + f float not null, + v varchar(32), + primary key (id), + unique key ukid (id), + unique key uk1 (col1), + unique key uk2 (Col2), + unique key uk3 (col3), + key k1 (col1), + key k2 (Col2), + key k3 (col3), + key kf (f), + key kf2 (f, Col2), + key kv (v), + key kv1 (v, col1), + key kv2 (v(10), Col2), + unique key uk12 (col1, Col2), + unique key uk21 (col2, Col1), + unique key uk23 (col2, col3), + unique key ukid3 (id, col3) + )` + tcases := []struct { + key string + unique bool + columns []string + nullable bool + float bool + prefix bool + }{ + { + key: "primary", + unique: true, + columns: []string{"id"}, + nullable: false, + }, + { + key: "ukid", + unique: true, + columns: []string{"id"}, + nullable: false, + }, + { + key: "uk1", + unique: true, + columns: []string{"col1"}, + nullable: true, + }, + { + key: "uk2", + unique: true, + columns: []string{"Col2"}, + nullable: false, + }, + { + key: "uk3", + unique: true, + columns: []string{"col3"}, + nullable: false, + }, + { + key: "k1", + unique: false, + columns: []string{"col1"}, + nullable: true, + }, + { + key: "k2", + unique: false, + columns: []string{"Col2"}, + nullable: false, + }, + { + key: "k3", + unique: false, + columns: []string{"col3"}, + nullable: false, + }, + { + key: "kf", + unique: false, + columns: []string{"f"}, + nullable: false, + float: true, + }, + { + key: "kf2", + unique: false, + columns: []string{"f", "Col2"}, + nullable: false, + float: true, + }, + { + key: "kv", + unique: false, + columns: []string{"v"}, + nullable: true, + }, + { + key: "kv1", + unique: false, + columns: []string{"v", "col1"}, + nullable: true, + }, + { + key: "kv2", + unique: false, + columns: []string{"v", "Col2"}, + nullable: true, + prefix: true, + }, + { + key: "uk12", + unique: true, + columns: []string{"col1", "Col2"}, + nullable: true, + }, + { + key: "uk21", + unique: true, + columns: []string{"col2", "Col1"}, + nullable: true, + }, + { + key: "uk23", + unique: true, + columns: []string{"col2", "col3"}, + nullable: false, + }, + { + key: "ukid3", + unique: true, + columns: []string{"id", "col3"}, + nullable: false, + }, + } + env := NewTestEnv() + createTableEntity, err := NewCreateTableEntityFromSQL(env, table) + require.NoError(t, err) + err = createTableEntity.validate() + require.NoError(t, err) + m := createTableEntity.IndexDefinitionEntitiesMap() + 
require.NotEmpty(t, m) + for _, tcase := range tcases { + t.Run(tcase.key, func(t *testing.T) { + key := m[tcase.key] + require.NotNil(t, key) + assert.Equal(t, tcase.unique, key.IsUnique()) + assert.Equal(t, tcase.columns, key.ColumnNames()) + assert.Equal(t, tcase.nullable, key.HasNullable()) + assert.Equal(t, tcase.float, key.HasFloat()) + assert.Equal(t, tcase.prefix, key.HasColumnPrefix()) + }) + } +} diff --git a/go/vt/schemadiff/mysql.go b/go/vt/schemadiff/mysql.go index 624897e2e43..65adcc1b7a1 100644 --- a/go/vt/schemadiff/mysql.go +++ b/go/vt/schemadiff/mysql.go @@ -21,20 +21,26 @@ var engineCasing = map[string]string{ "MYISAM": "MyISAM", } -var integralTypes = map[string]bool{ - "tinyint": true, - "smallint": true, - "mediumint": true, - "int": true, - "bigint": true, +// integralTypes maps known integer types to their byte storage size +var integralTypes = map[string]int{ + "tinyint": 1, + "smallint": 2, + "mediumint": 3, + "int": 4, + "bigint": 8, } -var floatTypes = map[string]bool{ - "float": true, - "float4": true, - "float8": true, - "double": true, - "real": true, +var floatTypes = map[string]int{ + "float": 4, + "float4": 4, + "float8": 8, + "double": 8, + "real": 8, +} + +var decimalTypes = map[string]bool{ + "decimal": true, + "numeric": true, } var charsetTypes = map[string]bool{ @@ -48,6 +54,56 @@ var charsetTypes = map[string]bool{ "set": true, } +var blobStorageExponent = map[string]int{ + "tinyblob": 8, + "tinytext": 8, + "blob": 16, + "text": 16, + "mediumblob": 24, + "mediumtext": 24, + "longblob": 32, + "longtext": 32, +} + +func IsFloatingPointType(columnType string) bool { + _, ok := floatTypes[columnType] + return ok +} + +func FloatingPointTypeStorage(columnType string) int { + return floatTypes[columnType] +} + func IsIntegralType(columnType string) bool { + _, ok := integralTypes[columnType] + return ok +} + +func IntegralTypeStorage(columnType string) int { return integralTypes[columnType] } + +func IsDecimalType(columnType string) bool { + return decimalTypes[columnType] +} + +func BlobTypeStorage(columnType string) int { + return blobStorageExponent[columnType] +} + +// expandedDataTypes maps some known and difficult-to-compute by INFORMATION_SCHEMA data types which expand other data types. +// For example, in "date:datetime", datetime expands date because it has more precision. In "timestamp:date" date expands timestamp +// because it can contain years not covered by timestamp. +var expandedDataTypes = map[string]bool{ + "time:datetime": true, + "date:datetime": true, + "timestamp:datetime": true, + "time:timestamp": true, + "date:timestamp": true, + "timestamp:date": true, +} + +func IsExpandingDataType(sourceType string, targetType string) bool { + _, ok := expandedDataTypes[sourceType+":"+targetType] + return ok +} diff --git a/go/vt/schemadiff/onlineddl.go b/go/vt/schemadiff/onlineddl.go new file mode 100644 index 00000000000..66908e502f5 --- /dev/null +++ b/go/vt/schemadiff/onlineddl.go @@ -0,0 +1,590 @@ +/* +Copyright 2022 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +*/ + +package schemadiff + +import ( + "fmt" + "math" + "sort" + "strings" + + "vitess.io/vitess/go/vt/sqlparser" +) + +// ColumnChangeExpandsDataRange sees if target column has any value set/range that is impossible in source column. +func ColumnChangeExpandsDataRange(source *ColumnDefinitionEntity, target *ColumnDefinitionEntity) (bool, string) { + if target.IsNullable() && !source.IsNullable() { + return true, "target is NULL-able, source is not" + } + if target.Length() > source.Length() { + return true, "increased length" + } + if target.Scale() > source.Scale() { + return true, "increased scale" + } + if source.IsUnsigned() && !target.IsUnsigned() { + return true, "source is unsigned, target is signed" + } + if IntegralTypeStorage(target.Type()) > IntegralTypeStorage(source.Type()) && IntegralTypeStorage(source.Type()) != 0 { + return true, "increased integer range" + } + if IntegralTypeStorage(source.Type()) <= IntegralTypeStorage(target.Type()) && + !source.IsUnsigned() && target.IsUnsigned() { + // e.g. INT SIGNED => INT UNSIGNED, INT SIGNED => BIGINT UNSIGNED + return true, "target unsigned value exceeds source unsigned value" + } + if FloatingPointTypeStorage(target.Type()) > FloatingPointTypeStorage(source.Type()) && FloatingPointTypeStorage(source.Type()) != 0 { + return true, "increased floating point range" + } + if target.IsFloatingPointType() && !source.IsFloatingPointType() { + return true, "target is floating point, source is not" + } + if target.IsDecimalType() && !source.IsDecimalType() { + return true, "target is decimal, source is not" + } + if target.IsDecimalType() && source.IsDecimalType() { + if target.Length()-target.Scale() > source.Length()-source.Scale() { + return true, "increased decimal range" + } + } + if IsExpandingDataType(source.Type(), target.Type()) { + return true, "target is expanded data type of source" + } + if BlobTypeStorage(target.Type()) > BlobTypeStorage(source.Type()) && BlobTypeStorage(source.Type()) != 0 { + return true, "increased blob range" + } + if source.Charset() != target.Charset() { + if target.Charset() == "utf8mb4" { + return true, "expand character set to utf8mb4" + } + if strings.HasPrefix(target.Charset(), "utf8") && !strings.HasPrefix(source.Charset(), "utf8") { + // not utf to utf + return true, "expand character set to utf8" + } + } + for _, colType := range []string{"enum", "set"} { + // enums and sets have very similar properties, and are practically identical in our analysis + if source.Type() == colType { + // this is an enum or a set + if target.Type() != colType { + return true, "conversion from enum/set to non-enum/set adds potential values" + } + // target is an enum or a set. See if all values on target exist in source + sourceEnumTokensMap := source.EnumOrdinalValues() + targetEnumTokensMap := target.EnumOrdinalValues() + for k, v := range targetEnumTokensMap { + if sourceEnumTokensMap[k] != v { + return true, "target enum/set expands or reorders source enum/set" + } + } + } + } + return false, "" +} + +// IsValidIterationKey returns true if the key is eligible for Online DDL iteration. 
+func IsValidIterationKey(key *IndexDefinitionEntity) bool { + if key == nil { + return false + } + if !key.IsUnique() { + return false + } + if key.HasFloat() { + return false + } + if key.HasColumnPrefix() { + return false + } + if key.HasNullable() { + return false + } + return true +} + +// PrioritizedUniqueKeys returns all unique keys on given table, ordered from "best" to "worst", +// for Online DDL purposes. The list of keys includes some that are not eligible for Online DDL +// iteration. +func PrioritizedUniqueKeys(createTableEntity *CreateTableEntity) *IndexDefinitionEntityList { + uniqueKeys := []*IndexDefinitionEntity{} + for _, key := range createTableEntity.IndexDefinitionEntities() { + if !key.IsUnique() { + continue + } + uniqueKeys = append(uniqueKeys, key) + } + sort.SliceStable(uniqueKeys, func(i, j int) bool { + if uniqueKeys[i].IsPrimary() { + // PRIMARY is always first + return true + } + if uniqueKeys[j].IsPrimary() { + // PRIMARY is always first + return false + } + if !uniqueKeys[i].HasNullable() && uniqueKeys[j].HasNullable() { + // Non NULLable comes first + return true + } + if uniqueKeys[i].HasNullable() && !uniqueKeys[j].HasNullable() { + // NULLable come last + return false + } + if !uniqueKeys[i].HasColumnPrefix() && uniqueKeys[j].HasColumnPrefix() { + // Non prefix comes first + return true + } + if uniqueKeys[i].HasColumnPrefix() && !uniqueKeys[j].HasColumnPrefix() { + // Prefix comes last + return false + } + iFirstColEntity := uniqueKeys[i].ColumnList.Entities[0] + jFirstColEntity := uniqueKeys[j].ColumnList.Entities[0] + if iFirstColEntity.IsIntegralType() && !jFirstColEntity.IsIntegralType() { + // Prioritize integers + return true + } + if !iFirstColEntity.IsIntegralType() && jFirstColEntity.IsIntegralType() { + // Prioritize integers + return false + } + if !iFirstColEntity.HasBlobTypeStorage() && jFirstColEntity.HasBlobTypeStorage() { + return true + } + if iFirstColEntity.HasBlobTypeStorage() && !jFirstColEntity.HasBlobTypeStorage() { + return false + } + if !iFirstColEntity.IsTextual() && jFirstColEntity.IsTextual() { + return true + } + if iFirstColEntity.IsTextual() && !jFirstColEntity.IsTextual() { + return false + } + if storageDiff := IntegralTypeStorage(iFirstColEntity.Type()) - IntegralTypeStorage(jFirstColEntity.Type()); storageDiff != 0 { + return storageDiff < 0 + } + if lenDiff := len(uniqueKeys[i].ColumnList.Entities) - len(uniqueKeys[j].ColumnList.Entities); lenDiff != 0 { + return lenDiff < 0 + } + return false + }) + return NewIndexDefinitionEntityList(uniqueKeys) +} + +// RemovedForeignKeyNames returns the names of removed foreign keys, ignoring mere name changes +func RemovedForeignKeyNames(source *CreateTableEntity, target *CreateTableEntity) (names []string, err error) { + if source == nil || target == nil { + return nil, nil + } + diffHints := DiffHints{ + ConstraintNamesStrategy: ConstraintNamesIgnoreAll, + } + diff, err := source.Diff(target, &diffHints) + if err != nil { + return nil, err + } + names = []string{} + validateWalk := func(node sqlparser.SQLNode) (kontinue bool, err error) { + switch node := node.(type) { + case *sqlparser.DropKey: + if node.Type == sqlparser.ForeignKeyType { + names = append(names, node.Name.String()) + } + } + return true, nil + } + _ = sqlparser.Walk(validateWalk, diff.Statement()) // We never return an error + return names, nil +} + +// AlterTableAnalysis contains useful Online DDL information about an AlterTable statement +type AlterTableAnalysis struct { + ColumnRenameMap map[string]string 
+	DroppedColumnsMap              map[string]bool
+	IsRenameTable                  bool
+	IsAutoIncrementChangeRequested bool
+}
+
+// OnlineDDLAlterTableAnalysis looks for specific changes in the AlterTable statement that are relevant
+// to OnlineDDL/VReplication
+func OnlineDDLAlterTableAnalysis(alterTable *sqlparser.AlterTable) *AlterTableAnalysis {
+	analysis := &AlterTableAnalysis{
+		ColumnRenameMap:   make(map[string]string),
+		DroppedColumnsMap: make(map[string]bool),
+	}
+	if alterTable == nil {
+		return analysis
+	}
+	for _, opt := range alterTable.AlterOptions {
+		switch opt := opt.(type) {
+		case *sqlparser.RenameTableName:
+			analysis.IsRenameTable = true
+		case *sqlparser.DropColumn:
+			analysis.DroppedColumnsMap[opt.Name.Name.String()] = true
+		case *sqlparser.ChangeColumn:
+			if opt.OldColumn != nil && opt.NewColDefinition != nil {
+				oldName := opt.OldColumn.Name.String()
+				newName := opt.NewColDefinition.Name.String()
+				analysis.ColumnRenameMap[oldName] = newName
+			}
+		case sqlparser.TableOptions:
+			for _, tableOption := range opt {
+				if strings.ToUpper(tableOption.Name) == "AUTO_INCREMENT" {
+					analysis.IsAutoIncrementChangeRequested = true
+				}
+			}
+		}
+	}
+	return analysis
+}
+
+// GetExpandedColumns is given source and target shared columns, and returns the list of columns whose data type is expanded.
+// An expanded data type is one where the target can have a value which the source does not. Examples:
+// - any NOT NULL to NULLable (a NULL in the target cannot appear on source)
+// - INT -> BIGINT (obvious)
+// - BIGINT UNSIGNED -> INT SIGNED (negative values)
+// - TIMESTAMP -> TIMESTAMP(3)
+// etc.
+func GetExpandedColumns(
+	sourceColumns *ColumnDefinitionEntityList,
+	targetColumns *ColumnDefinitionEntityList,
+) (
+	expandedColumns *ColumnDefinitionEntityList,
+	expandedDescriptions map[string]string,
+	err error,
+) {
+	if len(sourceColumns.Entities) != len(targetColumns.Entities) {
+		return nil, nil, fmt.Errorf("source and target columns must be of same length")
+	}
+
+	expandedEntities := []*ColumnDefinitionEntity{}
+	expandedDescriptions = map[string]string{}
+	for i := range sourceColumns.Entities {
+		// source and target columns assumed to be mapped 1:1, same length
+		sourceColumn := sourceColumns.Entities[i]
+		targetColumn := targetColumns.Entities[i]
+
+		if isExpanded, description := ColumnChangeExpandsDataRange(sourceColumn, targetColumn); isExpanded {
+			expandedEntities = append(expandedEntities, sourceColumn)
+			expandedDescriptions[sourceColumn.Name()] = description
+		}
+	}
+	return NewColumnDefinitionEntityList(expandedEntities), expandedDescriptions, nil
+}
+
+// AnalyzeSharedColumns returns the intersection of two lists of columns, in the same order as the first list
+func AnalyzeSharedColumns(
+	sourceColumns, targetColumns *ColumnDefinitionEntityList,
+	alterTableAnalysis *AlterTableAnalysis,
+) (
+	sourceSharedColumns *ColumnDefinitionEntityList,
+	targetSharedColumns *ColumnDefinitionEntityList,
+	droppedSourceNonGeneratedColumns *ColumnDefinitionEntityList,
+	sharedColumnsMap map[string]string,
+) {
+	sharedColumnsMap = map[string]string{}
+	sourceShared := []*ColumnDefinitionEntity{}
+	targetShared := []*ColumnDefinitionEntity{}
+	droppedNonGenerated := []*ColumnDefinitionEntity{}
+
+	for _, sourceColumn := range sourceColumns.Entities {
+		if sourceColumn.IsGenerated() {
+			continue
+		}
+		isDroppedFromSource := false
+		// Note to a future engineer: you may be tempted to remove this loop based on the
+		// assumption that the later `targetColumn := targetColumns.GetColumn(expectedTargetName)`
+		// check is sufficient. It is not. It is possible that a column was explicitly dropped
+		// and added (`DROP COLUMN c, ADD COLUMN c INT`) in the same ALTER TABLE statement.
+		// Without checking the ALTER TABLE statement, we would be fooled to believe that column
+		// `c` is unchanged in the target, when in fact it was dropped and re-added.
+		for droppedColumn := range alterTableAnalysis.DroppedColumnsMap {
+			if strings.EqualFold(sourceColumn.Name(), droppedColumn) {
+				isDroppedFromSource = true
+				break
+			}
+		}
+		if isDroppedFromSource {
+			droppedNonGenerated = append(droppedNonGenerated, sourceColumn)
+			// Column was dropped, hence cannot be a shared column
			continue
+		}
+		expectedTargetName := sourceColumn.NameLowered()
+		if mappedName := alterTableAnalysis.ColumnRenameMap[sourceColumn.Name()]; mappedName != "" {
+			expectedTargetName = mappedName
+		}
+		targetColumn := targetColumns.GetColumn(expectedTargetName)
+		if targetColumn == nil {
+			// Column not found in target
+			droppedNonGenerated = append(droppedNonGenerated, sourceColumn)
+			continue
+		}
+		if targetColumn.IsGenerated() {
+			// virtual/generated columns are silently skipped.
+			continue
+		}
+		// OK, the column is shared (possibly renamed) between source and target.
+		sharedColumnsMap[sourceColumn.Name()] = targetColumn.Name()
+		sourceShared = append(sourceShared, sourceColumn)
+		targetShared = append(targetShared, targetColumn)
+	}
+	return NewColumnDefinitionEntityList(sourceShared),
+		NewColumnDefinitionEntityList(targetShared),
+		NewColumnDefinitionEntityList(droppedNonGenerated),
+		sharedColumnsMap
+}
+
+// KeyAtLeastConstrainedAs returns 'true' when sourceUniqueKey is at least as constrained as targetUniqueKey.
+// "More constrained" means the uniqueness constraint is "stronger". Thus, if sourceUniqueKey is as-or-more constrained than targetUniqueKey, then
+// rows valid under sourceUniqueKey must also be valid in targetUniqueKey. The opposite is not necessarily so: rows that are valid in targetUniqueKey
+// may cause a unique key violation under sourceUniqueKey
+func KeyAtLeastConstrainedAs(
+	sourceUniqueKey *IndexDefinitionEntity,
+	targetUniqueKey *IndexDefinitionEntity,
+	columnRenameMap map[string]string,
+) bool {
+	if !sourceUniqueKey.IsUnique() {
+		return false
+	}
+	if !targetUniqueKey.IsUnique() {
+		return true
+	}
+	sourceKeyLengths := map[string]int{}
+	for _, col := range sourceUniqueKey.IndexDefinition.Columns {
+		if col.Length == nil {
+			sourceKeyLengths[col.Column.Lowered()] = math.MaxInt64
+		} else {
+			sourceKeyLengths[col.Column.Lowered()] = *col.Length
+		}
+	}
+	targetKeyLengths := map[string]int{}
+	for _, col := range targetUniqueKey.IndexDefinition.Columns {
+		if col.Length == nil {
+			targetKeyLengths[col.Column.Lowered()] = math.MaxInt64
+		} else {
+			targetKeyLengths[col.Column.Lowered()] = *col.Length
+		}
+	}
+	// source is more constrained than target if every column in source is also in target, order is immaterial
+	for _, sourceCol := range sourceUniqueKey.ColumnList.Entities {
+		mappedColName, ok := columnRenameMap[sourceCol.Name()]
+		if !ok {
+			mappedColName = sourceCol.NameLowered()
+		}
+		targetCol := targetUniqueKey.ColumnList.GetColumn(mappedColName)
+		if targetCol == nil {
+			// source can't be more constrained if it covers *more* columns
+			return false
+		}
+		// We now know that sourceCol maps to targetCol
+		if sourceKeyLengths[sourceCol.NameLowered()] > targetKeyLengths[targetCol.NameLowered()] {
+			// source column covers a larger prefix than target column. It is therefore less constrained.
+			return false
+		}
+	}
+	return true
+}
+
+// IntroducedUniqueConstraints returns the unique key constraints added in target.
+// This does not necessarily mean that the unique key itself is new, but rather that
+// there's a new, stricter constraint on a set of columns that didn't exist before. Example:
+//
+//	before:
+//	  unique key my_key (c1, c2, c3)
+//	after:
+//	  unique key `other_key` (c1, c2)
+//
+// Synopsis: the constraint on (c1, c2) is new; `other_key` in the target table is considered a new key.
+// Order of columns is immaterial to uniqueness of column combination.
+func IntroducedUniqueConstraints(sourceUniqueKeys *IndexDefinitionEntityList, targetUniqueKeys *IndexDefinitionEntityList, columnRenameMap map[string]string) *IndexDefinitionEntityList {
+	introducedUniqueConstraints := []*IndexDefinitionEntity{}
+	for _, targetUniqueKey := range targetUniqueKeys.Entities {
+		foundSourceKeyAtLeastAsConstrained := func() bool {
+			for _, sourceUniqueKey := range sourceUniqueKeys.Entities {
+				if KeyAtLeastConstrainedAs(sourceUniqueKey, targetUniqueKey, columnRenameMap) {
+					// target key does not add a new constraint
+					return true
+				}
+			}
+			return false
+		}
+		if !foundSourceKeyAtLeastAsConstrained() {
+			introducedUniqueConstraints = append(introducedUniqueConstraints, targetUniqueKey)
+		}
+	}
+	return NewIndexDefinitionEntityList(introducedUniqueConstraints)
+}
+
+// RemovedUniqueConstraints returns the list of unique key constraints _removed_ going from source to target.
+func RemovedUniqueConstraints(sourceUniqueKeys *IndexDefinitionEntityList, targetUniqueKeys *IndexDefinitionEntityList, columnRenameMap map[string]string) *IndexDefinitionEntityList {
+	reverseColumnRenameMap := map[string]string{}
+	for k, v := range columnRenameMap {
+		reverseColumnRenameMap[v] = k
+	}
+	return IntroducedUniqueConstraints(targetUniqueKeys, sourceUniqueKeys, reverseColumnRenameMap)
+}
+
+// IterationKeysByColumns returns the Online DDL compliant unique keys from the given list,
+// whose columns are all covered by the given column list.
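+// For example (mirroring TestUniqueKeysCoveredByColumns below): given keys PRIMARY (id) and
+// uk1 (c1), a column list containing just `id` yields PRIMARY alone, provided the key also
+// satisfies IsValidIterationKey.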
+func IterationKeysByColumns(keys *IndexDefinitionEntityList, columns *ColumnDefinitionEntityList) *IndexDefinitionEntityList {
+	subset := []*IndexDefinitionEntity{}
+	for _, key := range keys.SubsetCoveredByColumns(columns).Entities {
+		if IsValidIterationKey(key) {
+			subset = append(subset, key)
+		}
+	}
+	return NewIndexDefinitionEntityList(subset)
+}
+
+// MappedColumnNames returns the names of the given columns, replacing any name that has an
+// entry in the given map with its mapped value.
+func MappedColumnNames(columnsList *ColumnDefinitionEntityList, columnNamesMap map[string]string) []string {
+	names := columnsList.Names()
+	for i := range names {
+		if mappedName, ok := columnNamesMap[names[i]]; ok {
+			names[i] = mappedName
+		}
+	}
+	return names
+}
+
+// MigrationTablesAnalysis contains useful Online DDL information about the source and target
+// tables of a migration.
+type MigrationTablesAnalysis struct {
+	SourceSharedColumns     *ColumnDefinitionEntityList
+	TargetSharedColumns     *ColumnDefinitionEntityList
+	DroppedNoDefaultColumns *ColumnDefinitionEntityList
+	ExpandedColumns         *ColumnDefinitionEntityList
+	SharedColumnsMap        map[string]string
+	ChosenSourceUniqueKey   *IndexDefinitionEntity
+	ChosenTargetUniqueKey   *IndexDefinitionEntity
+	AddedUniqueKeys         *IndexDefinitionEntityList
+	RemovedUniqueKeys       *IndexDefinitionEntityList
+	RemovedForeignKeyNames  []string
+	IntToEnumMap            map[string]bool
+	SourceAutoIncrement     uint64
+	RevertibleNotes         []string
+}
+
+// OnlineDDLMigrationTablesAnalysis analyzes the source and target tables of an Online DDL
+// migration and returns the resulting MigrationTablesAnalysis.
+func OnlineDDLMigrationTablesAnalysis(
+	sourceCreateTableEntity *CreateTableEntity,
+	targetCreateTableEntity *CreateTableEntity,
+	alterTableAnalysis *AlterTableAnalysis,
+) (analysis *MigrationTablesAnalysis, err error) {
+	analysis = &MigrationTablesAnalysis{
+		IntToEnumMap:    make(map[string]bool),
+		RevertibleNotes: []string{},
+	}
+	// columns:
+	generatedColumns := func(columns *ColumnDefinitionEntityList) *ColumnDefinitionEntityList {
+		return columns.Filter(func(col *ColumnDefinitionEntity) bool {
+			return col.IsGenerated()
+		})
+	}
+	noDefaultColumns := func(columns *ColumnDefinitionEntityList) *ColumnDefinitionEntityList {
+		return columns.Filter(func(col *ColumnDefinitionEntity) bool {
+			return !col.HasDefault()
+		})
+	}
+	sourceColumns := sourceCreateTableEntity.ColumnDefinitionEntitiesList()
+	targetColumns := targetCreateTableEntity.ColumnDefinitionEntitiesList()
+
+	var droppedSourceNonGeneratedColumns *ColumnDefinitionEntityList
+	analysis.SourceSharedColumns, analysis.TargetSharedColumns, droppedSourceNonGeneratedColumns, analysis.SharedColumnsMap = AnalyzeSharedColumns(sourceColumns, targetColumns, alterTableAnalysis)
+
+	// unique keys
+	sourceUniqueKeys := PrioritizedUniqueKeys(sourceCreateTableEntity)
+	if sourceUniqueKeys.Len() == 0 {
+		return nil, fmt.Errorf("found no possible unique key on `%s`", sourceCreateTableEntity.Name())
+	}
+
+	targetUniqueKeys := PrioritizedUniqueKeys(targetCreateTableEntity)
+	if targetUniqueKeys.Len() == 0 {
+		return nil, fmt.Errorf("found no possible unique key on `%s`", targetCreateTableEntity.Name())
+	}
+	// VReplication supports completely different unique keys on source and target, covering
+	// some/completely different columns. The condition is that the key on source
+	// must use columns which all exist in the target table.
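+	// Illustrative example (an assumption for clarity, not spelled out in this change): with
+	// source `PRIMARY KEY (id)` and target `PRIMARY KEY (id2)` where `id` is renamed to `id2`,
+	// both keys remain eligible, because the renamed column is shared between the two tables.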
+ eligibleSourceColumnsForUniqueKey := analysis.SourceSharedColumns.Union(generatedColumns(sourceColumns)) + analysis.ChosenSourceUniqueKey = IterationKeysByColumns(sourceUniqueKeys, eligibleSourceColumnsForUniqueKey).First() + if analysis.ChosenSourceUniqueKey == nil { + return nil, fmt.Errorf("found no possible unique key on `%s` whose columns are in target table `%s`", sourceCreateTableEntity.Name(), targetCreateTableEntity.Name()) + } + + eligibleTargetColumnsForUniqueKey := analysis.TargetSharedColumns.Union(generatedColumns(targetColumns)) + analysis.ChosenTargetUniqueKey = IterationKeysByColumns(targetUniqueKeys, eligibleTargetColumnsForUniqueKey).First() + if analysis.ChosenTargetUniqueKey == nil { + return nil, fmt.Errorf("found no possible unique key on `%s` whose columns are in source table `%s`", targetCreateTableEntity.Name(), sourceCreateTableEntity.Name()) + } + + analysis.AddedUniqueKeys = IntroducedUniqueConstraints(sourceUniqueKeys, targetUniqueKeys, alterTableAnalysis.ColumnRenameMap) + analysis.RemovedUniqueKeys = RemovedUniqueConstraints(sourceUniqueKeys, targetUniqueKeys, alterTableAnalysis.ColumnRenameMap) + analysis.RemovedForeignKeyNames, err = RemovedForeignKeyNames(sourceCreateTableEntity, targetCreateTableEntity) + if err != nil { + return nil, err + } + + formalizeColumns := func(columnsLists ...*ColumnDefinitionEntityList) error { + for _, colList := range columnsLists { + for _, col := range colList.Entities { + col.SetExplicitDefaultAndNull() + if err := col.SetExplicitCharsetCollate(); err != nil { + return err + } + } + } + return nil + } + + if err := formalizeColumns(analysis.SourceSharedColumns, analysis.TargetSharedColumns, droppedSourceNonGeneratedColumns); err != nil { + return nil, err + } + + for i := range analysis.SourceSharedColumns.Entities { + sourceColumn := analysis.SourceSharedColumns.Entities[i] + mappedColumn := analysis.TargetSharedColumns.Entities[i] + + if sourceColumn.IsIntegralType() && mappedColumn.Type() == "enum" { + analysis.IntToEnumMap[sourceColumn.Name()] = true + } + } + + analysis.DroppedNoDefaultColumns = noDefaultColumns(droppedSourceNonGeneratedColumns) + var expandedDescriptions map[string]string + analysis.ExpandedColumns, expandedDescriptions, err = GetExpandedColumns(analysis.SourceSharedColumns, analysis.TargetSharedColumns) + if err != nil { + return nil, err + } + + analysis.SourceAutoIncrement, err = sourceCreateTableEntity.AutoIncrementValue() + if err != nil { + return nil, err + } + + for _, uk := range analysis.RemovedUniqueKeys.Names() { + analysis.RevertibleNotes = append(analysis.RevertibleNotes, fmt.Sprintf("unique constraint removed: %s", uk)) + } + for _, name := range analysis.DroppedNoDefaultColumns.Names() { + analysis.RevertibleNotes = append(analysis.RevertibleNotes, fmt.Sprintf("column %s dropped, and had no default value", name)) + } + for _, name := range analysis.ExpandedColumns.Names() { + analysis.RevertibleNotes = append(analysis.RevertibleNotes, fmt.Sprintf("column %s: %s", name, expandedDescriptions[name])) + } + for _, name := range analysis.RemovedForeignKeyNames { + analysis.RevertibleNotes = append(analysis.RevertibleNotes, fmt.Sprintf("foreign key %s dropped", name)) + } + + return analysis, nil +} diff --git a/go/vt/schemadiff/onlineddl_test.go b/go/vt/schemadiff/onlineddl_test.go new file mode 100644 index 00000000000..bd08bedfe8a --- /dev/null +++ b/go/vt/schemadiff/onlineddl_test.go @@ -0,0 +1,960 @@ +/* +Copyright 2024 The Vitess Authors. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package schemadiff + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/vt/sqlparser" +) + +func TestPrioritizedUniqueKeys(t *testing.T) { + table := ` + create table t ( + idsha varchar(64), + col1 int, + col2 int not null, + col3 bigint not null default 3, + col4 smallint not null, + f float not null, + v varchar(32) not null, + primary key (idsha), + unique key ukidsha (idsha), + unique key uk1 (col1), + unique key uk2 (col2), + unique key uk3 (col3), + key k1 (col1), + key kf (f), + key k1f (col1, f), + key kv (v), + unique key ukv (v), + unique key ukvprefix (v(10)), + unique key uk2vprefix (col2, v(10)), + unique key uk1f (col1, f), + unique key uk41 (col4, col1), + unique key uk42 (col4, col2) + )` + env := NewTestEnv() + createTableEntity, err := NewCreateTableEntityFromSQL(env, table) + require.NoError(t, err) + err = createTableEntity.validate() + require.NoError(t, err) + + keys := PrioritizedUniqueKeys(createTableEntity) + require.NotEmpty(t, keys) + names := make([]string, 0, len(keys.Entities)) + for _, key := range keys.Entities { + names = append(names, key.Name()) + } + expect := []string{ + "PRIMARY", + "uk42", + "uk2", + "uk3", + "ukidsha", + "ukv", + "uk2vprefix", + "ukvprefix", + "uk41", + "uk1", + "uk1f", + } + assert.Equal(t, expect, names) +} + +func TestRemovedForeignKeyNames(t *testing.T) { + env := NewTestEnv() + + tcases := []struct { + before string + after string + names []string + }{ + { + before: "create table t (id int primary key)", + after: "create table t (id2 int primary key, i int)", + }, + { + before: "create table t (id int primary key)", + after: "create table t2 (id2 int primary key, i int)", + }, + { + before: "create table t (id int primary key, i int, constraint f foreign key (i) references parent (id) on delete cascade)", + after: "create table t (id int primary key, i int, constraint f foreign key (i) references parent (id) on delete cascade)", + }, + { + before: "create table t (id int primary key, i int, constraint f1 foreign key (i) references parent (id) on delete cascade)", + after: "create table t (id int primary key, i int, constraint f2 foreign key (i) references parent (id) on delete cascade)", + }, + { + before: "create table t (id int primary key, i int, constraint f foreign key (i) references parent (id) on delete cascade)", + after: "create table t (id int primary key, i int)", + names: []string{"f"}, + }, + { + before: "create table t (id int primary key, i int, i2 int, constraint f1 foreign key (i) references parent (id) on delete cascade, constraint fi2 foreign key (i2) references parent (id) on delete cascade)", + after: "create table t (id int primary key, i int, i2 int, constraint f2 foreign key (i) references parent (id) on delete cascade)", + names: []string{"fi2"}, + }, + { + before: "create table t1 (id int primary key, i int, constraint `check1` CHECK ((`i` < 5)))", + after: "create table t2 
(id int primary key, i int)", + }, + } + for _, tcase := range tcases { + t.Run(tcase.before, func(t *testing.T) { + before, err := NewCreateTableEntityFromSQL(env, tcase.before) + require.NoError(t, err) + err = before.validate() + require.NoError(t, err) + + after, err := NewCreateTableEntityFromSQL(env, tcase.after) + require.NoError(t, err) + err = after.validate() + require.NoError(t, err) + + names, err := RemovedForeignKeyNames(before, after) + assert.NoError(t, err) + if tcase.names == nil { + tcase.names = []string{} + } + assert.Equal(t, tcase.names, names) + }) + } +} + +func TestGetAlterTableAnalysis(t *testing.T) { + tcases := []struct { + alter string + renames map[string]string + drops map[string]bool + isrename bool + autoinc bool + }{ + { + alter: "alter table t add column t int, engine=innodb", + }, + { + alter: "alter table t add column t int, change ts ts timestamp, engine=innodb", + renames: map[string]string{"ts": "ts"}, + }, + { + alter: "alter table t AUTO_INCREMENT=7", + autoinc: true, + }, + { + alter: "alter table t add column t int, change ts ts timestamp, auto_increment=7 engine=innodb", + renames: map[string]string{"ts": "ts"}, + autoinc: true, + }, + { + alter: "alter table t add column t int, change ts ts timestamp, CHANGE f `f` float, engine=innodb", + renames: map[string]string{"ts": "ts", "f": "f"}, + }, + { + alter: `alter table t add column b bigint, change f fl float, change i count int, engine=innodb`, + renames: map[string]string{"f": "fl", "i": "count"}, + }, + { + alter: "alter table t add column b bigint, change column `f` fl float, change `i` `count` int, engine=innodb", + renames: map[string]string{"f": "fl", "i": "count"}, + }, + { + alter: "alter table t add column b bigint, change column `f` fl float, change `i` `count` int, change ts ts timestamp, engine=innodb", + renames: map[string]string{"f": "fl", "i": "count", "ts": "ts"}, + }, + { + alter: "alter table t drop column b", + drops: map[string]bool{"b": true}, + }, + { + alter: "alter table t drop column b, drop key c_idx, drop column `d`", + drops: map[string]bool{"b": true, "d": true}, + }, + { + alter: "alter table t drop column b, drop key c_idx, drop column `d`, drop `e`, drop primary key, drop foreign key fk_1", + drops: map[string]bool{"b": true, "d": true, "e": true}, + }, + { + alter: "alter table t rename as something_else", + isrename: true, + }, + { + alter: "alter table t drop column b, rename as something_else", + isrename: true, + drops: map[string]bool{"b": true}, + }, + } + for _, tcase := range tcases { + t.Run(tcase.alter, func(t *testing.T) { + if tcase.renames == nil { + tcase.renames = make(map[string]string) + } + if tcase.drops == nil { + tcase.drops = make(map[string]bool) + } + stmt, err := sqlparser.NewTestParser().ParseStrictDDL(tcase.alter) + require.NoError(t, err) + alter, ok := stmt.(*sqlparser.AlterTable) + require.True(t, ok) + + analysis := OnlineDDLAlterTableAnalysis(alter) + require.NotNil(t, analysis) + assert.Equal(t, tcase.isrename, analysis.IsRenameTable) + assert.Equal(t, tcase.autoinc, analysis.IsAutoIncrementChangeRequested) + assert.Equal(t, tcase.renames, analysis.ColumnRenameMap) + assert.Equal(t, tcase.drops, analysis.DroppedColumnsMap) + }) + } +} + +func TestAnalyzeSharedColumns(t *testing.T) { + sourceTable := ` + create table t ( + id int, + cint int, + cgen1 int generated always as (cint + 1) stored, + cgen2 int generated always as (cint + 2) stored, + cchar char(1), + cremoved int not null default 7, + cnullable int, + cnodefault int 
not null, + extra1 int, + primary key (id) + ) + ` + targetTable := ` + create table t ( + id int, + cint int, + cgen1 int generated always as (cint + 1) stored, + cchar_alternate char(1), + cnullable int, + cnodefault int not null, + extra2 int, + primary key (id) + ) + ` + tcases := []struct { + name string + sourceTable string + targetTable string + renameMap map[string]string + expectSourceSharedColNames []string + expectTargetSharedColNames []string + expectDroppedSourceNonGeneratedColNames []string + expectSharedColumnsMap map[string]string + }{ + { + name: "rename map empty", + renameMap: map[string]string{}, + expectSourceSharedColNames: []string{"id", "cint", "cnullable", "cnodefault"}, + expectTargetSharedColNames: []string{"id", "cint", "cnullable", "cnodefault"}, + expectDroppedSourceNonGeneratedColNames: []string{"cchar", "cremoved", "extra1"}, + expectSharedColumnsMap: map[string]string{"id": "id", "cint": "cint", "cnullable": "cnullable", "cnodefault": "cnodefault"}, + }, + { + name: "renamed column", + renameMap: map[string]string{"cchar": "cchar_alternate"}, + expectSourceSharedColNames: []string{"id", "cint", "cchar", "cnullable", "cnodefault"}, + expectTargetSharedColNames: []string{"id", "cint", "cchar_alternate", "cnullable", "cnodefault"}, + expectDroppedSourceNonGeneratedColNames: []string{"cremoved", "extra1"}, + expectSharedColumnsMap: map[string]string{"id": "id", "cint": "cint", "cchar": "cchar_alternate", "cnullable": "cnullable", "cnodefault": "cnodefault"}, + }, + } + + env := NewTestEnv() + alterTableAnalysis := OnlineDDLAlterTableAnalysis(nil) // empty + for _, tcase := range tcases { + t.Run(tcase.name, func(t *testing.T) { + if tcase.sourceTable == "" { + tcase.sourceTable = sourceTable + } + if tcase.targetTable == "" { + tcase.targetTable = targetTable + } + if tcase.renameMap != nil { + alterTableAnalysis.ColumnRenameMap = tcase.renameMap + } + + sourceEntity, err := NewCreateTableEntityFromSQL(env, tcase.sourceTable) + require.NoError(t, err) + err = sourceEntity.validate() + require.NoError(t, err) + + targetEntity, err := NewCreateTableEntityFromSQL(env, tcase.targetTable) + require.NoError(t, err) + err = targetEntity.validate() + require.NoError(t, err) + + sourceSharedCols, targetSharedCols, droppedNonGeneratedCols, sharedColumnsMap := AnalyzeSharedColumns( + sourceEntity.ColumnDefinitionEntitiesList(), + targetEntity.ColumnDefinitionEntitiesList(), + alterTableAnalysis, + ) + assert.Equal(t, tcase.expectSourceSharedColNames, sourceSharedCols.Names()) + assert.Equal(t, tcase.expectTargetSharedColNames, targetSharedCols.Names()) + assert.Equal(t, tcase.expectDroppedSourceNonGeneratedColNames, droppedNonGeneratedCols.Names()) + assert.Equal(t, tcase.expectSharedColumnsMap, sharedColumnsMap) + }) + } +} + +func TestKeyAtLeastConstrainedAs(t *testing.T) { + env := NewTestEnv() + sourceTable := ` + create table source_table ( + id int, + c1 int, + c2 int, + c3 int, + c9 int, + v varchar(32), + primary key (id), + unique key uk1 (c1), + unique key uk2 (c2), + unique key uk3 (c3), + unique key uk9 (c9), + unique key uk12 (c1, c2), + unique key uk13 (c1, c3), + unique key uk23 (c2, c3), + unique key uk123 (c1, c2, c3), + unique key uk21 (c2, c1), + unique key ukv (v), + unique key ukv3 (v(3)), + unique key ukv5 (v(5)), + unique key uk2v5 (c2, v(5)) + )` + targetTable := ` + create table target_table ( + id int, + c1 int, + c2 int, + c3_renamed int, + v varchar(32), + primary key (id), + unique key uk1 (c1), + unique key uk2 (c2), + unique key uk3 
(c3_renamed), + unique key uk12 (c1, c2), + unique key uk13 (c1, c3_renamed), + unique key uk23 (c2, c3_renamed), + unique key uk123 (c1, c2, c3_renamed), + unique key uk21 (c2, c1), + unique key ukv (v), + unique key ukv3 (v(3)), + unique key ukv5 (v(5)), + unique key uk2v5 (c2, v(5)) + )` + renameMap := map[string]string{ + "c3": "c3_renamed", + } + tcases := []struct { + sourceKey string + targetKey string + renameMap map[string]string + expect bool + }{ + { + sourceKey: "uk1", + targetKey: "uk1", + expect: true, + }, + { + sourceKey: "uk2", + targetKey: "uk2", + expect: true, + }, + { + sourceKey: "uk3", + targetKey: "uk3", + expect: false, // c3 is renamed + }, + { + sourceKey: "uk2", + targetKey: "uk1", + expect: false, + }, + { + sourceKey: "uk12", + targetKey: "uk1", + expect: false, + }, + { + sourceKey: "uk1", + targetKey: "uk12", + expect: true, + }, + { + sourceKey: "uk1", + targetKey: "uk21", + expect: true, + }, + { + sourceKey: "uk12", + targetKey: "uk21", + expect: true, + }, + { + sourceKey: "uk123", + targetKey: "uk21", + expect: false, + }, + { + sourceKey: "uk123", + targetKey: "uk123", + expect: false, // c3 is renamed + }, + { + sourceKey: "uk1", + targetKey: "uk123", + expect: true, // c3 is renamed but not referenced + }, + { + sourceKey: "uk21", + targetKey: "uk123", + expect: true, // c3 is renamed but not referenced + }, + { + sourceKey: "uk9", + targetKey: "uk123", + expect: false, // c9 not in target + }, + { + sourceKey: "uk3", + targetKey: "uk3", + renameMap: renameMap, + expect: true, + }, + { + sourceKey: "uk123", + targetKey: "uk123", + renameMap: renameMap, + expect: true, + }, + { + sourceKey: "uk3", + targetKey: "uk123", + renameMap: renameMap, + expect: true, + }, + { + sourceKey: "ukv", + targetKey: "ukv", + expect: true, + }, + { + sourceKey: "ukv3", + targetKey: "ukv3", + expect: true, + }, + { + sourceKey: "ukv", + targetKey: "ukv3", + expect: false, + }, + { + sourceKey: "ukv5", + targetKey: "ukv3", + expect: false, + }, + { + sourceKey: "ukv3", + targetKey: "ukv5", + expect: true, + }, + { + sourceKey: "ukv3", + targetKey: "ukv", + expect: true, + }, + { + sourceKey: "uk2", + targetKey: "uk2v5", + expect: true, + }, + { + sourceKey: "ukv5", + targetKey: "uk2v5", + expect: true, + }, + { + sourceKey: "ukv3", + targetKey: "uk2v5", + expect: true, + }, + { + sourceKey: "ukv", + targetKey: "uk2v5", + expect: false, + }, + { + sourceKey: "uk2v5", + targetKey: "ukv5", + expect: false, + }, + } + + sourceEntity, err := NewCreateTableEntityFromSQL(env, sourceTable) + require.NoError(t, err) + err = sourceEntity.validate() + require.NoError(t, err) + sourceKeys := sourceEntity.IndexDefinitionEntitiesMap() + + targetEntity, err := NewCreateTableEntityFromSQL(env, targetTable) + require.NoError(t, err) + err = targetEntity.validate() + require.NoError(t, err) + targetKeys := targetEntity.IndexDefinitionEntitiesMap() + + for _, tcase := range tcases { + t.Run(tcase.sourceKey+"/"+tcase.targetKey, func(t *testing.T) { + if tcase.renameMap == nil { + tcase.renameMap = make(map[string]string) + } + sourceKey := sourceKeys[tcase.sourceKey] + require.NotNil(t, sourceKey) + + targetKey := targetKeys[tcase.targetKey] + require.NotNil(t, targetKey) + + result := KeyAtLeastConstrainedAs(sourceKey, targetKey, tcase.renameMap) + assert.Equal(t, tcase.expect, result) + }) + } +} + +func TestIntroducedUniqueConstraints(t *testing.T) { + env := NewTestEnv() + tcases := []struct { + sourceTable string + targetTable string + expectIntroduced []string + expectRemoved []string 
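+		// expectIntroduced and expectRemoved name the unique constraints expected to be
+		// introduced in, or removed from, the target table relative to the source table.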
+ }{ + { + sourceTable: ` + create table source_table ( + id int, + c1 int, + c2 int, + c3 int, + primary key (id), + unique key uk1 (c1), + unique key uk2 (c2), + unique key uk31 (c3, c1), + key k1 (c1) + )`, + targetTable: ` + create table source_table ( + id int, + c1 int, + c2 int, + c3 int, + primary key (id), + unique key uk1 (c1), + unique key uk3 (c3), + unique key uk31_alias (c3, c1), + key k2 (c2) + )`, + expectIntroduced: []string{"uk3"}, + expectRemoved: []string{"uk2"}, + }, + { + sourceTable: ` + create table source_table ( + id int, + c1 int, + c2 int, + c3 int, + primary key (id), + unique key uk1 (c1), + unique key uk2 (c2), + unique key uk31 (c3, c1), + key k1 (c1) + )`, + targetTable: ` + create table source_table ( + id int, + c1 int, + c2 int, + c3 int, + primary key (id), + unique key uk1 (c1), + unique key uk3 (c3), + key k2 (c2) + )`, + expectIntroduced: []string{"uk3"}, // uk31 (c3, c1) not considered removed because the new "uk3" is even more constrained + expectRemoved: []string{"uk2"}, + }, + { + sourceTable: ` + create table source_table ( + id int, + c1 int, + c2 int, + v varchar(128), + primary key (id), + unique key uk12 (c1, c2), + unique key ukv5 (v(5)), + key k1 (c1) + )`, + targetTable: ` + create table source_table ( + id int, + c1 int, + c2 int, + c3 int, + v varchar(128), + primary key (id), + unique key uk1v2 (c1, v(2)), + unique key uk1v7 (c1, v(7)), + unique key ukv3 (v(3)), + key k2 (c2) + )`, + expectIntroduced: []string{"uk1v2", "ukv3"}, + expectRemoved: []string{"uk12"}, + }, + } + for _, tcase := range tcases { + t.Run("", func(t *testing.T) { + sourceEntity, err := NewCreateTableEntityFromSQL(env, tcase.sourceTable) + require.NoError(t, err) + err = sourceEntity.validate() + require.NoError(t, err) + sourceUniqueKeys := PrioritizedUniqueKeys(sourceEntity) + + targetEntity, err := NewCreateTableEntityFromSQL(env, tcase.targetTable) + require.NoError(t, err) + err = targetEntity.validate() + require.NoError(t, err) + targetUniqueKeys := PrioritizedUniqueKeys(targetEntity) + + introduced := IntroducedUniqueConstraints(sourceUniqueKeys, targetUniqueKeys, nil) + assert.Equal(t, tcase.expectIntroduced, introduced.Names()) + }) + } +} + +func TestUniqueKeysCoveredByColumns(t *testing.T) { + env := NewTestEnv() + table := ` + create table t ( + id int, + c1 int not null, + c2 int not null, + c3 int not null, + c9 int, + v varchar(32) not null, + primary key (id), + unique key uk1 (c1), + unique key uk3 (c3), + unique key uk9 (c9), + key k3 (c3), + unique key uk12 (c1, c2), + unique key uk13 (c1, c3), + unique key uk23 (c2, c3), + unique key uk123 (c1, c2, c3), + unique key uk21 (c2, c1), + unique key ukv (v), + unique key ukv3 (v(3)), + unique key uk2v5 (c2, v(5)), + unique key uk3v (c3, v) + ) + ` + tcases := []struct { + columns []string + expect []string + }{ + { + columns: []string{"id"}, + expect: []string{"PRIMARY"}, + }, + { + columns: []string{"c1"}, + expect: []string{"uk1"}, + }, + { + columns: []string{"id", "c1"}, + expect: []string{"PRIMARY", "uk1"}, + }, + { + columns: []string{"c1", "id"}, + expect: []string{"PRIMARY", "uk1"}, + }, + { + columns: []string{"c9"}, + expect: []string{}, // nullable column + }, + { + columns: []string{"v"}, + expect: []string{"ukv"}, + }, + { + columns: []string{"v", "c9"}, + expect: []string{"ukv"}, + }, + { + columns: []string{"v", "c2"}, + expect: []string{"ukv"}, + }, + { + columns: []string{"v", "c2", "c3"}, + expect: []string{"uk3", "uk23", "uk3v", "ukv"}, + }, + { + columns: []string{"id", "c1", 
"c2", "c3", "v"}, + expect: []string{"PRIMARY", "uk1", "uk3", "uk12", "uk13", "uk23", "uk21", "uk3v", "uk123", "ukv"}, + }, + } + + entity, err := NewCreateTableEntityFromSQL(env, table) + require.NoError(t, err) + err = entity.validate() + require.NoError(t, err) + tableColumns := entity.ColumnDefinitionEntitiesList() + tableKeys := PrioritizedUniqueKeys(entity) + assert.Equal(t, []string{ + "PRIMARY", + "uk1", + "uk3", + "uk12", + "uk13", + "uk23", + "uk21", + "uk3v", + "uk123", + "ukv", + "uk2v5", + "ukv3", + "uk9", + }, tableKeys.Names()) + + for _, tcase := range tcases { + t.Run(strings.Join(tcase.columns, ","), func(t *testing.T) { + columns := []*ColumnDefinitionEntity{} + for _, tcaseCol := range tcase.columns { + col := tableColumns.GetColumn(tcaseCol) + require.NotNil(t, col) + columns = append(columns, col) + } + columnsList := NewColumnDefinitionEntityList(columns) + + covered := IterationKeysByColumns(tableKeys, columnsList) + assert.Equal(t, tcase.expect, covered.Names()) + }) + } +} + +func TestRevertible(t *testing.T) { + + type revertibleTestCase struct { + name string + fromSchema string + toSchema string + // expectProblems bool + removedForeignKeyNames string + removedUniqueKeyNames string + droppedNoDefaultColumnNames string + expandedColumnNames string + } + + var testCases = []revertibleTestCase{ + { + name: "identical schemas", + fromSchema: `id int primary key, i1 int not null default 0`, + toSchema: `id int primary key, i2 int not null default 0`, + }, + { + name: "different schemas, nothing to note", + fromSchema: `id int primary key, i1 int not null default 0, unique key i1_uidx(i1)`, + toSchema: `id int primary key, i1 int not null default 0, i2 int not null default 0, unique key i1_uidx(i1)`, + }, + { + name: "removed non-nullable unique key", + fromSchema: `id int primary key, i1 int not null default 0, unique key i1_uidx(i1)`, + toSchema: `id int primary key, i2 int not null default 0`, + removedUniqueKeyNames: `i1_uidx`, + }, + { + name: "removed nullable unique key", + fromSchema: `id int primary key, i1 int default null, unique key i1_uidx(i1)`, + toSchema: `id int primary key, i2 int default null`, + removedUniqueKeyNames: `i1_uidx`, + }, + { + name: "expanding unique key removes unique constraint", + fromSchema: `id int primary key, i1 int default null, unique key i1_uidx(i1)`, + toSchema: `id int primary key, i1 int default null, unique key i1_uidx(i1, id)`, + removedUniqueKeyNames: `i1_uidx`, + }, + { + name: "expanding unique key prefix removes unique constraint", + fromSchema: `id int primary key, v varchar(100) default null, unique key v_uidx(v(20))`, + toSchema: `id int primary key, v varchar(100) default null, unique key v_uidx(v(21))`, + removedUniqueKeyNames: `v_uidx`, + }, + { + name: "reducing unique key does not remove unique constraint", + fromSchema: `id int primary key, i1 int default null, unique key i1_uidx(i1, id)`, + toSchema: `id int primary key, i1 int default null, unique key i1_uidx(i1)`, + removedUniqueKeyNames: ``, + }, + { + name: "reducing unique key does not remove unique constraint", + fromSchema: `id int primary key, v varchar(100) default null, unique key v_uidx(v(21))`, + toSchema: `id int primary key, v varchar(100) default null, unique key v_uidx(v(20))`, + }, + { + name: "removed foreign key", + fromSchema: "id int primary key, i int, constraint some_fk_1 foreign key (i) references parent (id) on delete cascade", + toSchema: "id int primary key, i int", + removedForeignKeyNames: "some_fk_1", + }, + + { + name: 
"renamed foreign key", + fromSchema: "id int primary key, i int, constraint f1 foreign key (i) references parent (id) on delete cascade", + toSchema: "id int primary key, i int, constraint f2 foreign key (i) references parent (id) on delete cascade", + }, + { + name: "remove column without default", + fromSchema: `id int primary key, i1 int not null, i2 int not null default 0, i3 int default null`, + toSchema: `id int primary key, i4 int not null default 0`, + droppedNoDefaultColumnNames: `i1`, + }, + { + name: "expanded: nullable", + fromSchema: `id int primary key, i1 int not null, i2 int default null`, + toSchema: `id int primary key, i1 int default null, i2 int not null`, + expandedColumnNames: `i1`, + }, + { + name: "expanded: longer text", + fromSchema: `id int primary key, i1 int default null, v1 varchar(40) not null, v2 varchar(5), v3 varchar(3)`, + toSchema: `id int primary key, i1 int not null, v1 varchar(100) not null, v2 char(3), v3 char(5)`, + expandedColumnNames: `v1,v3`, + }, + { + name: "expanded: int numeric precision and scale", + fromSchema: `id int primary key, i1 int, i2 tinyint, i3 mediumint, i4 bigint`, + toSchema: `id int primary key, i1 int, i2 mediumint, i3 int, i4 tinyint`, + expandedColumnNames: `i2,i3`, + }, + { + name: "expanded: floating point", + fromSchema: `id int primary key, i1 int, n2 bigint, n3 bigint, n4 float, n5 double`, + toSchema: `id int primary key, i1 int, n2 float, n3 double, n4 double, n5 float`, + expandedColumnNames: `n2,n3,n4`, + }, + { + name: "expanded: decimal numeric precision and scale", + fromSchema: `id int primary key, i1 int, d1 decimal(10,2), d2 decimal (10,2), d3 decimal (10,2)`, + toSchema: `id int primary key, i1 int, d1 decimal(11,2), d2 decimal (9,1), d3 decimal (10,3)`, + expandedColumnNames: `d1,d3`, + }, + { + name: "expanded: signed, unsigned", + fromSchema: `id int primary key, i1 bigint signed, i2 int unsigned, i3 bigint unsigned`, + toSchema: `id int primary key, i1 int signed, i2 int signed, i3 int signed`, + expandedColumnNames: `i2,i3`, + }, + { + name: "expanded: signed, unsigned: range", + fromSchema: `id int primary key, i1 int signed, i2 bigint signed, i3 int signed`, + toSchema: `id int primary key, i1 int unsigned, i2 int unsigned, i3 bigint unsigned`, + expandedColumnNames: `i1,i3`, + }, + { + name: "expanded: datetime precision", + fromSchema: `id int primary key, dt1 datetime, ts1 timestamp, ti1 time, dt2 datetime(3), dt3 datetime(6), ts2 timestamp(3)`, + toSchema: `id int primary key, dt1 datetime(3), ts1 timestamp(6), ti1 time(3), dt2 datetime(6), dt3 datetime(3), ts2 timestamp`, + expandedColumnNames: `dt1,ts1,ti1,dt2`, + }, + { + name: "expanded: strange data type changes", + fromSchema: `id int primary key, dt1 datetime, ts1 timestamp, i1 int, d1 date, e1 enum('a', 'b')`, + toSchema: `id int primary key, dt1 char(32), ts1 varchar(32), i1 tinytext, d1 char(2), e1 varchar(2)`, + expandedColumnNames: `dt1,ts1,i1,d1,e1`, + }, + { + name: "expanded: temporal types", + fromSchema: `id int primary key, t1 time, t2 timestamp, t3 date, t4 datetime, t5 time, t6 date`, + toSchema: `id int primary key, t1 datetime, t2 datetime, t3 timestamp, t4 timestamp, t5 timestamp, t6 datetime`, + expandedColumnNames: `t1,t2,t3,t5,t6`, + }, + { + name: "expanded: character sets", + fromSchema: `id int primary key, c1 char(3) charset utf8, c2 char(3) charset utf8mb4, c3 char(3) charset ascii, c4 char(3) charset utf8mb4, c5 char(3) charset utf8, c6 char(3) charset latin1`, + toSchema: `id int primary key, c1 char(3) charset 
utf8mb4, c2 char(3) charset utf8, c3 char(3) charset utf8, c4 char(3) charset ascii, c5 char(3) charset utf8, c6 char(3) charset utf8mb4`, + expandedColumnNames: `c1,c3,c6`, + }, + { + name: "expanded: enum", + fromSchema: `id int primary key, e1 enum('a', 'b'), e2 enum('a', 'b'), e3 enum('a', 'b'), e4 enum('a', 'b'), e5 enum('a', 'b'), e6 enum('a', 'b'), e7 enum('a', 'b'), e8 enum('a', 'b')`, + toSchema: `id int primary key, e1 enum('a', 'b'), e2 enum('a'), e3 enum('a', 'b', 'c'), e4 enum('a', 'x'), e5 enum('a', 'x', 'b'), e6 enum('b'), e7 varchar(1), e8 tinyint`, + expandedColumnNames: `e3,e4,e5,e6,e7,e8`, + }, + { + name: "expanded: set", + fromSchema: `id int primary key, e1 set('a', 'b'), e2 set('a', 'b'), e3 set('a', 'b'), e4 set('a', 'b'), e5 set('a', 'b'), e6 set('a', 'b'), e7 set('a', 'b'), e8 set('a', 'b')`, + toSchema: `id int primary key, e1 set('a', 'b'), e2 set('a'), e3 set('a', 'b', 'c'), e4 set('a', 'x'), e5 set('a', 'x', 'b'), e6 set('b'), e7 varchar(1), e8 tinyint`, + expandedColumnNames: `e3,e4,e5,e6,e7,e8`, + }, + } + + var ( + createTableWrapper = `CREATE TABLE t (%s)` + ) + + env := NewTestEnv() + diffHints := &DiffHints{} + for _, tcase := range testCases { + t.Run(tcase.name, func(t *testing.T) { + tcase.fromSchema = fmt.Sprintf(createTableWrapper, tcase.fromSchema) + sourceTableEntity, err := NewCreateTableEntityFromSQL(env, tcase.fromSchema) + require.NoError(t, err) + + tcase.toSchema = fmt.Sprintf(createTableWrapper, tcase.toSchema) + targetTableEntity, err := NewCreateTableEntityFromSQL(env, tcase.toSchema) + require.NoError(t, err) + + diff, err := sourceTableEntity.TableDiff(targetTableEntity, diffHints) + require.NoError(t, err) + alterTableAnalysis := OnlineDDLAlterTableAnalysis(diff.AlterTable()) + + analysis, err := OnlineDDLMigrationTablesAnalysis(sourceTableEntity, targetTableEntity, alterTableAnalysis) + require.NoError(t, err) + + toStringSlice := func(s string) []string { + if s == "" { + return []string{} + } + return strings.Split(s, ",") + } + assert.Equal(t, toStringSlice(tcase.removedForeignKeyNames), analysis.RemovedForeignKeyNames) + assert.Equal(t, toStringSlice(tcase.removedUniqueKeyNames), analysis.RemovedUniqueKeys.Names()) + assert.Equal(t, toStringSlice(tcase.droppedNoDefaultColumnNames), analysis.DroppedNoDefaultColumns.Names()) + assert.Equal(t, toStringSlice(tcase.expandedColumnNames), analysis.ExpandedColumns.Names()) + }) + } +} diff --git a/go/vt/schemadiff/schema_diff_test.go b/go/vt/schemadiff/schema_diff_test.go index 10ad260100b..8088cc896ed 100644 --- a/go/vt/schemadiff/schema_diff_test.go +++ b/go/vt/schemadiff/schema_diff_test.go @@ -1321,7 +1321,6 @@ func TestSchemaDiff(t *testing.T) { instantCapability := schemaDiff.InstantDDLCapability() assert.Equal(t, tc.instantCapability, instantCapability, "for instant capability") }) - } } diff --git a/go/vt/schemadiff/table.go b/go/vt/schemadiff/table.go index 5629210b6c1..c326b2763b3 100644 --- a/go/vt/schemadiff/table.go +++ b/go/vt/schemadiff/table.go @@ -445,6 +445,18 @@ type CreateTableEntity struct { Env *Environment } +func NewCreateTableEntityFromSQL(env *Environment, sql string) (*CreateTableEntity, error) { + stmt, err := env.Parser().ParseStrictDDL(sql) + if err != nil { + return nil, err + } + createTable, ok := stmt.(*sqlparser.CreateTable) + if !ok { + return nil, ErrExpectedCreateTable + } + return NewCreateTableEntity(env, createTable) +} + func NewCreateTableEntity(env *Environment, c *sqlparser.CreateTable) (*CreateTableEntity, error) { if !c.IsFullyParsed() { 
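+		// Downstream analysis relies on a fully parsed AST; reject a partially parsed
+		// statement rather than guessing at its intent.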
 		return nil, &NotFullyParsedError{Entity: c.Table.Name.String(), Statement: sqlparser.CanonicalString(c)}
@@ -454,15 +466,64 @@ func NewCreateTableEntity(env *Environment, c *sqlparser.CreateTa
 	return entity, nil
 }
 
+// ColumnDefinitionEntities returns the list of column entities for the table.
 func (c *CreateTableEntity) ColumnDefinitionEntities() []*ColumnDefinitionEntity {
 	cc := getTableCharsetCollate(c.Env, &c.CreateTable.TableSpec.Options)
+	pkColumnsMap := c.primaryKeyColumnsMap()
 	entities := make([]*ColumnDefinitionEntity, len(c.CreateTable.TableSpec.Columns))
 	for i := range c.CreateTable.TableSpec.Columns {
-		entities[i] = NewColumnDefinitionEntity(c.Env, c.CreateTable.TableSpec.Columns[i], cc)
+		col := c.CreateTable.TableSpec.Columns[i]
+		_, inPK := pkColumnsMap[col.Name.Lowered()]
+		entities[i] = NewColumnDefinitionEntity(c.Env, col, inPK, cc)
 	}
 	return entities
 }
 
+// ColumnDefinitionEntitiesList returns the table's column entities as a ColumnDefinitionEntityList.
+func (c *CreateTableEntity) ColumnDefinitionEntitiesList() *ColumnDefinitionEntityList {
+	return NewColumnDefinitionEntityList(c.ColumnDefinitionEntities())
+}
+
+// ColumnDefinitionEntitiesMap returns column entities mapped by their lower cased name.
+func (c *CreateTableEntity) ColumnDefinitionEntitiesMap() map[string]*ColumnDefinitionEntity {
+	entities := c.ColumnDefinitionEntities()
+	m := make(map[string]*ColumnDefinitionEntity, len(entities))
+	for _, entity := range entities {
+		m[entity.NameLowered()] = entity
+	}
+	return m
+}
+
+// IndexDefinitionEntities returns the list of index entities for the table.
+func (c *CreateTableEntity) IndexDefinitionEntities() []*IndexDefinitionEntity {
+	colMap := c.ColumnDefinitionEntitiesMap()
+	keys := c.CreateTable.TableSpec.Indexes
+	entities := make([]*IndexDefinitionEntity, len(keys))
+	for i, key := range keys {
+		colEntities := make([]*ColumnDefinitionEntity, len(key.Columns))
+		for j, keyCol := range key.Columns {
+			colEntities[j] = colMap[keyCol.Column.Lowered()]
+		}
+		entities[i] = NewIndexDefinitionEntity(c.Env, key, NewColumnDefinitionEntityList(colEntities))
+	}
+	return entities
+}
+
+// IndexDefinitionEntitiesList returns the list of index entities for the table.
+func (c *CreateTableEntity) IndexDefinitionEntitiesList() *IndexDefinitionEntityList {
+	return NewIndexDefinitionEntityList(c.IndexDefinitionEntities())
+}
+
+// IndexDefinitionEntitiesMap returns index entities mapped by their lower cased name.
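+//
+// A minimal usage sketch (an illustrative assumption, not part of this change):
+//
+//	m := entity.IndexDefinitionEntitiesMap()
+//	pk := m["primary"] // nil when the table defines no PRIMARY KEY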
+func (c *CreateTableEntity) IndexDefinitionEntitiesMap() map[string]*IndexDefinitionEntity {
+	entities := c.IndexDefinitionEntities()
+	m := make(map[string]*IndexDefinitionEntity, len(entities))
+	for _, entity := range entities {
+		m[entity.NameLowered()] = entity
+	}
+	return m
+}
+
 // normalize cleans up the table definition:
 // - setting names to all keys
 // - table option case (upper/lower/special)
@@ -1740,8 +1801,8 @@ func (c *CreateTableEntity) diffColumns(alterTable *sqlparser.AlterTable,
 			t2ColName := t2Col.Name.Lowered()
 			// we know that column exists in both tables
 			t1Col := t1ColumnsMap[t2ColName]
-			t1ColEntity := NewColumnDefinitionEntity(c.Env, t1Col.col, t1cc)
-			t2ColEntity := NewColumnDefinitionEntity(c.Env, t2Col, t2cc)
+			t1ColEntity := NewColumnDefinitionEntity(c.Env, t1Col.col, false, t1cc)
+			t2ColEntity := NewColumnDefinitionEntity(c.Env, t2Col, false, t2cc)
 
 			// check diff between before/after columns:
 			modifyColumnDiff, err := t1ColEntity.ColumnDiff(c.Env, c.Name(), t2ColEntity, hints)
@@ -1892,6 +1953,15 @@ func (c *CreateTableEntity) primaryKeyColumns() []*sqlparser.IndexColumn {
 	return nil
 }
 
+// primaryKeyColumnsMap returns the table's primary key columns mapped by their lower cased name.
+func (c *CreateTableEntity) primaryKeyColumnsMap() map[string]*sqlparser.IndexColumn {
+	columns := c.primaryKeyColumns()
+	m := make(map[string]*sqlparser.IndexColumn, len(columns))
+	for _, col := range columns {
+		m[col.Column.Lowered()] = col
+	}
+	return m
+}
+
 // Create implements Entity interface
 func (c *CreateTableEntity) Create() EntityDiff {
 	if c == nil {
@@ -2648,3 +2718,18 @@ func (c *CreateTableEntity) identicalOtherThanName(other *CreateTableEntity) boo
 	return sqlparser.Equals.RefOfTableSpec(c.TableSpec, other.TableSpec) &&
 		sqlparser.Equals.RefOfParsedComments(c.Comments, other.Comments)
 }
+
+// AutoIncrementValue returns the value of the AUTO_INCREMENT option, or zero if none exists.
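+// For example, per TestNormalize below, a table created with
+//
+//	create table t (id int auto_increment primary key, i int) auto_increment = 123
+//
+// yields 123, while a table without an AUTO_INCREMENT table option yields 0.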
+func (c *CreateTableEntity) AutoIncrementValue() (autoIncrement uint64, err error) { + for _, option := range c.CreateTable.TableSpec.Options { + if strings.ToUpper(option.Name) == "AUTO_INCREMENT" { + autoIncrement, err := strconv.ParseUint(option.Value.Val, 10, 64) + if err != nil { + return 0, err + } + return autoIncrement, nil + } + } + // Auto increment not found + return 0, nil +} diff --git a/go/vt/schemadiff/table_test.go b/go/vt/schemadiff/table_test.go index 1168f53f3b6..389e55f447c 100644 --- a/go/vt/schemadiff/table_test.go +++ b/go/vt/schemadiff/table_test.go @@ -2781,9 +2781,10 @@ func TestValidate(t *testing.T) { func TestNormalize(t *testing.T) { tt := []struct { - name string - from string - to string + name string + from string + to string + autoinc uint64 }{ { name: "basic table", @@ -2795,6 +2796,17 @@ func TestNormalize(t *testing.T) { from: "create table t (id int primary key, i int)", to: "CREATE TABLE `t` (\n\t`id` int,\n\t`i` int,\n\tPRIMARY KEY (`id`)\n)", }, + { + name: "basic table, auto increment", + from: "create table t (id int auto_increment primary key, i int)", + to: "CREATE TABLE `t` (\n\t`id` int AUTO_INCREMENT,\n\t`i` int,\n\tPRIMARY KEY (`id`)\n)", + }, + { + name: "basic table, auto increment val", + from: "create table t (id int auto_increment primary key, i int) auto_increment = 123", + to: "CREATE TABLE `t` (\n\t`id` int AUTO_INCREMENT,\n\t`i` int,\n\tPRIMARY KEY (`id`)\n) AUTO_INCREMENT 123", + autoinc: 123, + }, { name: "removes default null", from: "create table t (id int, i int default null, primary key (id))", @@ -3067,6 +3079,10 @@ func TestNormalize(t *testing.T) { from, err := NewCreateTableEntity(env, fromCreateTable) require.NoError(t, err) assert.Equal(t, ts.to, sqlparser.CanonicalString(from)) + + autoinc, err := from.AutoIncrementValue() + require.NoError(t, err) + assert.EqualValues(t, ts.autoinc, autoinc) }) } } diff --git a/go/vt/schemadiff/view.go b/go/vt/schemadiff/view.go index d2dc4dfb76f..8783f1803bb 100644 --- a/go/vt/schemadiff/view.go +++ b/go/vt/schemadiff/view.go @@ -310,6 +310,18 @@ func NewCreateViewEntity(env *Environment, c *sqlparser.CreateView) (*CreateView return entity, nil } +func NewCreateViewEntityFromSQL(env *Environment, sql string) (*CreateViewEntity, error) { + stmt, err := env.Parser().ParseStrictDDL(sql) + if err != nil { + return nil, err + } + createView, ok := stmt.(*sqlparser.CreateView) + if !ok { + return nil, ErrExpectedCreateTable + } + return NewCreateViewEntity(env, createView) +} + func (c *CreateViewEntity) normalize() { // Drop the default algorithm if strings.EqualFold(c.CreateView.Algorithm, "undefined") { diff --git a/go/vt/schemadiff/view_test.go b/go/vt/schemadiff/view_test.go index d1a26c3cdaa..d020649b17e 100644 --- a/go/vt/schemadiff/view_test.go +++ b/go/vt/schemadiff/view_test.go @@ -150,19 +150,16 @@ func TestCreateViewDiff(t *testing.T) { for _, ts := range tt { t.Run(ts.name, func(t *testing.T) { fromStmt, err := env.Parser().ParseStrictDDL(ts.from) - assert.NoError(t, err) + require.NoError(t, err) fromCreateView, ok := fromStmt.(*sqlparser.CreateView) assert.True(t, ok) - toStmt, err := env.Parser().ParseStrictDDL(ts.to) - assert.NoError(t, err) - toCreateView, ok := toStmt.(*sqlparser.CreateView) - assert.True(t, ok) - c, err := NewCreateViewEntity(env, fromCreateView) require.NoError(t, err) - other, err := NewCreateViewEntity(env, toCreateView) + // Test from SQL: + other, err := NewCreateViewEntityFromSQL(env, ts.to) require.NoError(t, err) + alter, err := c.Diff(other, 
hints) switch { case ts.isError: diff --git a/go/vt/vttablet/onlineddl/executor.go b/go/vt/vttablet/onlineddl/executor.go index 757caa711b7..0d43d52d7f4 100644 --- a/go/vt/vttablet/onlineddl/executor.go +++ b/go/vt/vttablet/onlineddl/executor.go @@ -616,7 +616,7 @@ func (e *Executor) getCreateTableStatement(ctx context.Context, tableName string } createTable, ok := stmt.(*sqlparser.CreateTable) if !ok { - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "expected CREATE TABLE. Got %v", sqlparser.CanonicalString(stmt)) + return nil, schemadiff.ErrExpectedCreateTable } return createTable, nil } @@ -1518,7 +1518,10 @@ func (e *Executor) initVreplicationOriginalMigration(ctx context.Context, online return v, err } - v = NewVRepl(e.env.Environment(), onlineDDL.UUID, e.keyspace, e.shard, e.dbName, onlineDDL.Table, vreplTableName, originalCreateTable, vreplCreateTable, alterTable, onlineDDL.StrategySetting().IsAnalyzeTableFlag()) + v, err = NewVRepl(e.env.Environment(), onlineDDL.UUID, e.keyspace, e.shard, e.dbName, originalCreateTable, vreplCreateTable, alterTable, onlineDDL.StrategySetting().IsAnalyzeTableFlag()) + if err != nil { + return v, err + } return v, nil } @@ -1526,7 +1529,7 @@ func (e *Executor) initVreplicationOriginalMigration(ctx context.Context, online // This function is called after both source and target tables have been analyzed, so there's more information // about the two, and about the transition between the two. func (e *Executor) postInitVreplicationOriginalMigration(ctx context.Context, onlineDDL *schema.OnlineDDL, v *VRepl, conn *dbconnpool.DBConnection) (err error) { - if v.sourceAutoIncrement > 0 && !v.parser.IsAutoIncrementDefined() { + if v.analysis.SourceAutoIncrement > 0 && !v.alterTableAnalysis.IsAutoIncrementChangeRequested { restoreSQLModeFunc, err := e.initMigrationSQLMode(ctx, onlineDDL, conn) defer restoreSQLModeFunc() if err != nil { @@ -1534,9 +1537,9 @@ func (e *Executor) postInitVreplicationOriginalMigration(ctx context.Context, on } // Apply ALTER TABLE AUTO_INCREMENT=? 
- parsed := sqlparser.BuildParsedQuery(sqlAlterTableAutoIncrement, v.targetTable, ":auto_increment") + parsed := sqlparser.BuildParsedQuery(sqlAlterTableAutoIncrement, v.targetTableName(), ":auto_increment") bindVars := map[string]*querypb.BindVariable{ - "auto_increment": sqltypes.Uint64BindVariable(v.sourceAutoIncrement), + "auto_increment": sqltypes.Uint64BindVariable(v.analysis.SourceAutoIncrement), } bound, err := parsed.GenerateQuery(bindVars, nil) if err != nil { @@ -1572,7 +1575,18 @@ func (e *Executor) initVreplicationRevertMigration(ctx context.Context, onlineDD if err := e.updateArtifacts(ctx, onlineDDL.UUID, vreplTableName); err != nil { return v, err } - v = NewVRepl(e.env.Environment(), onlineDDL.UUID, e.keyspace, e.shard, e.dbName, onlineDDL.Table, vreplTableName, nil, nil, nil, false) + originalCreateTable, err := e.getCreateTableStatement(ctx, onlineDDL.Table) + if err != nil { + return v, err + } + vreplCreateTable, err := e.getCreateTableStatement(ctx, vreplTableName) + if err != nil { + return v, err + } + v, err = NewVRepl(e.env.Environment(), onlineDDL.UUID, e.keyspace, e.shard, e.dbName, originalCreateTable, vreplCreateTable, nil, false) + if err != nil { + return v, err + } v.pos = revertStream.pos return v, nil } @@ -1614,19 +1628,15 @@ func (e *Executor) ExecuteWithVReplication(ctx context.Context, onlineDDL *schem if err := e.updateMigrationTableRows(ctx, onlineDDL.UUID, v.tableRows); err != nil { return err } - removedUniqueKeyNames := []string{} - for _, uniqueKey := range v.removedUniqueKeys { - removedUniqueKeyNames = append(removedUniqueKeyNames, uniqueKey.Name) - } if err := e.updateSchemaAnalysis(ctx, onlineDDL.UUID, - len(v.addedUniqueKeys), - len(v.removedUniqueKeys), - strings.Join(sqlescape.EscapeIDs(removedUniqueKeyNames), ","), - strings.Join(sqlescape.EscapeIDs(v.removedForeignKeyNames), ","), - strings.Join(sqlescape.EscapeIDs(v.droppedNoDefaultColumnNames), ","), - strings.Join(sqlescape.EscapeIDs(v.expandedColumnNames), ","), - v.revertibleNotes, + v.analysis.AddedUniqueKeys.Len(), + v.analysis.RemovedUniqueKeys.Len(), + strings.Join(sqlescape.EscapeIDs(v.analysis.RemovedUniqueKeys.Names()), ","), + strings.Join(sqlescape.EscapeIDs(v.analysis.RemovedForeignKeyNames), ","), + strings.Join(sqlescape.EscapeIDs(v.analysis.DroppedNoDefaultColumns.Names()), ","), + strings.Join(sqlescape.EscapeIDs(v.analysis.ExpandedColumns.Names()), ","), + v.analysis.RevertibleNotes, ); err != nil { return err } @@ -1654,7 +1664,7 @@ func (e *Executor) ExecuteWithVReplication(ctx context.Context, onlineDDL *schem } // create vreplication entry - insertVReplicationQuery, err := v.generateInsertStatement(ctx) + insertVReplicationQuery, err := v.generateInsertStatement() if err != nil { return err } @@ -1671,7 +1681,7 @@ func (e *Executor) ExecuteWithVReplication(ctx context.Context, onlineDDL *schem } } // start stream! 
- startVReplicationQuery, err := v.generateStartStatement(ctx) + startVReplicationQuery, err := v.generateStartStatement() if err != nil { return err } @@ -2967,7 +2977,7 @@ func (e *Executor) analyzeDropDDLActionMigration(ctx context.Context, onlineDDL // Write analysis: } if err := e.updateSchemaAnalysis(ctx, onlineDDL.UUID, - 0, 0, "", strings.Join(sqlescape.EscapeIDs(removedForeignKeyNames), ","), "", "", "", + 0, 0, "", strings.Join(sqlescape.EscapeIDs(removedForeignKeyNames), ","), "", "", nil, ); err != nil { return err } @@ -4492,7 +4502,8 @@ func (e *Executor) updateSchemaAnalysis(ctx context.Context, uuid string, addedUniqueKeys, removedUniqueKeys int, removedUniqueKeyNames string, removedForeignKeyNames string, droppedNoDefaultColumnNames string, expandedColumnNames string, - revertibleNotes string) error { + revertibleNotes []string) error { + notes := strings.Join(revertibleNotes, "\n") query, err := sqlparser.ParseAndBind(sqlUpdateSchemaAnalysis, sqltypes.Int64BindVariable(int64(addedUniqueKeys)), sqltypes.Int64BindVariable(int64(removedUniqueKeys)), @@ -4500,7 +4511,7 @@ func (e *Executor) updateSchemaAnalysis(ctx context.Context, uuid string, sqltypes.StringBindVariable(removedForeignKeyNames), sqltypes.StringBindVariable(droppedNoDefaultColumnNames), sqltypes.StringBindVariable(expandedColumnNames), - sqltypes.StringBindVariable(revertibleNotes), + sqltypes.StringBindVariable(notes), sqltypes.StringBindVariable(uuid), ) if err != nil { diff --git a/go/vt/vttablet/onlineddl/schema.go b/go/vt/vttablet/onlineddl/schema.go index 4f65864cbfa..28e32e7dab4 100644 --- a/go/vt/vttablet/onlineddl/schema.go +++ b/go/vt/vttablet/onlineddl/schema.go @@ -461,16 +461,6 @@ const ( AND ACTION_TIMING='AFTER' AND LEFT(TRIGGER_NAME, 7)='pt_osc_' ` - sqlSelectColumnTypes = ` - select - *, - COLUMN_DEFAULT IS NULL AS is_default_null - from - information_schema.columns - where - table_schema=%a - and table_name=%a - ` selSelectCountFKParentConstraints = ` SELECT COUNT(*) as num_fk_constraints @@ -487,75 +477,10 @@ const ( TABLE_SCHEMA=%a AND TABLE_NAME=%a AND REFERENCED_TABLE_NAME IS NOT NULL ` - sqlSelectUniqueKeys = ` - SELECT - COLUMNS.TABLE_SCHEMA as table_schema, - COLUMNS.TABLE_NAME as table_name, - COLUMNS.COLUMN_NAME as column_name, - UNIQUES.INDEX_NAME as index_name, - UNIQUES.COLUMN_NAMES as column_names, - UNIQUES.COUNT_COLUMN_IN_INDEX as count_column_in_index, - COLUMNS.DATA_TYPE as data_type, - COLUMNS.CHARACTER_SET_NAME as character_set_name, - LOCATE('auto_increment', EXTRA) > 0 as is_auto_increment, - (DATA_TYPE='float' OR DATA_TYPE='double') AS is_float, - has_subpart, - has_nullable - FROM INFORMATION_SCHEMA.COLUMNS INNER JOIN ( - SELECT - TABLE_SCHEMA, - TABLE_NAME, - INDEX_NAME, - COUNT(*) AS COUNT_COLUMN_IN_INDEX, - GROUP_CONCAT(COLUMN_NAME ORDER BY SEQ_IN_INDEX ASC) AS COLUMN_NAMES, - SUBSTRING_INDEX(GROUP_CONCAT(COLUMN_NAME ORDER BY SEQ_IN_INDEX ASC), ',', 1) AS FIRST_COLUMN_NAME, - SUM(SUB_PART IS NOT NULL) > 0 AS has_subpart, - SUM(NULLABLE='YES') > 0 AS has_nullable - FROM INFORMATION_SCHEMA.STATISTICS - WHERE - NON_UNIQUE=0 - AND TABLE_SCHEMA=%a - AND TABLE_NAME=%a - GROUP BY TABLE_SCHEMA, TABLE_NAME, INDEX_NAME - ) AS UNIQUES - ON ( - COLUMNS.COLUMN_NAME = UNIQUES.FIRST_COLUMN_NAME - ) - WHERE - COLUMNS.TABLE_SCHEMA=%a - AND COLUMNS.TABLE_NAME=%a - ORDER BY - COLUMNS.TABLE_SCHEMA, COLUMNS.TABLE_NAME, - CASE UNIQUES.INDEX_NAME - WHEN 'PRIMARY' THEN 0 - ELSE 1 - END, - CASE has_nullable - WHEN 0 THEN 0 - ELSE 1 - END, - CASE has_subpart - WHEN 0 THEN 0 - ELSE 1 - END, - 
CASE IFNULL(CHARACTER_SET_NAME, '') - WHEN '' THEN 0 - ELSE 1 - END, - CASE DATA_TYPE - WHEN 'tinyint' THEN 0 - WHEN 'smallint' THEN 1 - WHEN 'int' THEN 2 - WHEN 'bigint' THEN 3 - ELSE 100 - END, - COUNT_COLUMN_IN_INDEX - ` sqlDropTrigger = "DROP TRIGGER IF EXISTS `%a`.`%a`" sqlShowTablesLike = "SHOW TABLES LIKE '%a'" sqlDropTable = "DROP TABLE `%a`" sqlDropTableIfExists = "DROP TABLE IF EXISTS `%a`" - sqlShowColumnsFrom = "SHOW COLUMNS FROM `%a`" sqlShowTableStatus = "SHOW TABLE STATUS LIKE '%a'" sqlAnalyzeTable = "ANALYZE NO_WRITE_TO_BINLOG TABLE `%a`" sqlShowCreateTable = "SHOW CREATE TABLE `%a`" @@ -563,23 +488,14 @@ const ( sqlShowVariablesLikeFastAnalyzeTable = "show global variables like 'fast_analyze_table'" sqlEnableFastAnalyzeTable = "set @@fast_analyze_table = 1" sqlDisableFastAnalyzeTable = "set @@fast_analyze_table = 0" - sqlGetAutoIncrement = ` - SELECT - AUTO_INCREMENT - FROM INFORMATION_SCHEMA.TABLES - WHERE - TABLES.TABLE_SCHEMA=%a - AND TABLES.TABLE_NAME=%a - AND AUTO_INCREMENT IS NOT NULL - ` - sqlAlterTableAutoIncrement = "ALTER TABLE `%s` AUTO_INCREMENT=%a" - sqlAlterTableExchangePartition = "ALTER TABLE `%a` EXCHANGE PARTITION `%a` WITH TABLE `%a`" - sqlAlterTableRemovePartitioning = "ALTER TABLE `%a` REMOVE PARTITIONING" - sqlAlterTableDropPartition = "ALTER TABLE `%a` DROP PARTITION `%a`" - sqlStartVReplStream = "UPDATE _vt.vreplication set state='Running' where db_name=%a and workflow=%a" - sqlStopVReplStream = "UPDATE _vt.vreplication set state='Stopped' where db_name=%a and workflow=%a" - sqlDeleteVReplStream = "DELETE FROM _vt.vreplication where db_name=%a and workflow=%a" - sqlReadVReplStream = `SELECT + sqlAlterTableAutoIncrement = "ALTER TABLE `%s` AUTO_INCREMENT=%a" + sqlAlterTableExchangePartition = "ALTER TABLE `%a` EXCHANGE PARTITION `%a` WITH TABLE `%a`" + sqlAlterTableRemovePartitioning = "ALTER TABLE `%a` REMOVE PARTITIONING" + sqlAlterTableDropPartition = "ALTER TABLE `%a` DROP PARTITION `%a`" + sqlStartVReplStream = "UPDATE _vt.vreplication set state='Running' where db_name=%a and workflow=%a" + sqlStopVReplStream = "UPDATE _vt.vreplication set state='Stopped' where db_name=%a and workflow=%a" + sqlDeleteVReplStream = "DELETE FROM _vt.vreplication where db_name=%a and workflow=%a" + sqlReadVReplStream = `SELECT id, workflow, source, diff --git a/go/vt/vttablet/onlineddl/vrepl.go b/go/vt/vttablet/onlineddl/vrepl.go index 42fe33a855f..14c52d352bf 100644 --- a/go/vt/vttablet/onlineddl/vrepl.go +++ b/go/vt/vttablet/onlineddl/vrepl.go @@ -38,11 +38,10 @@ import ( "vitess.io/vitess/go/textutil" "vitess.io/vitess/go/vt/dbconnpool" "vitess.io/vitess/go/vt/log" - "vitess.io/vitess/go/vt/schema" + "vitess.io/vitess/go/vt/schemadiff" "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtenv" "vitess.io/vitess/go/vt/vterrors" - "vitess.io/vitess/go/vt/vttablet/onlineddl/vrepl" "vitess.io/vitess/go/vt/vttablet/tabletmanager/vreplication" binlogdatapb "vitess.io/vitess/go/vt/proto/binlogdata" @@ -99,43 +98,23 @@ func (v *VReplStream) hasError() (isTerminal bool, vreplError error) { // VRepl is an online DDL helper for VReplication based migrations (ddl_strategy="online") type VRepl struct { - workflow string - keyspace string - shard string - dbName string - sourceTable string - targetTable string - pos string - alterQuery *sqlparser.AlterTable - tableRows int64 - - originalCreateTable *sqlparser.CreateTable - vreplCreateTable *sqlparser.CreateTable + workflow string + keyspace string + shard string + dbName string + pos string + tableRows int64 - 
analyzeTable bool - - sourceSharedColumns *vrepl.ColumnList - targetSharedColumns *vrepl.ColumnList - droppedSourceNonGeneratedColumns *vrepl.ColumnList - droppedNoDefaultColumnNames []string - expandedColumnNames []string - sharedColumnsMap map[string]string - sourceAutoIncrement uint64 + sourceCreateTableEntity *schemadiff.CreateTableEntity + targetCreateTableEntity *schemadiff.CreateTableEntity + analysis *schemadiff.MigrationTablesAnalysis - chosenSourceUniqueKey *vrepl.UniqueKey - chosenTargetUniqueKey *vrepl.UniqueKey - - addedUniqueKeys []*vrepl.UniqueKey - removedUniqueKeys []*vrepl.UniqueKey - removedForeignKeyNames []string + analyzeTable bool - revertibleNotes string - filterQuery string - enumToTextMap map[string]string - intToEnumMap map[string]bool - bls *binlogdatapb.BinlogSource + filterQuery string + bls *binlogdatapb.BinlogSource - parser *vrepl.AlterTableParser + alterTableAnalysis *schemadiff.AlterTableAnalysis convertCharset map[string](*binlogdatapb.CharsetConversion) @@ -149,110 +128,40 @@ func NewVRepl( keyspace string, shard string, dbName string, - sourceTable string, - targetTable string, - originalCreateTable *sqlparser.CreateTable, - vreplCreateTable *sqlparser.CreateTable, + sourceCreateTable *sqlparser.CreateTable, + targetCreateTable *sqlparser.CreateTable, alterQuery *sqlparser.AlterTable, analyzeTable bool, -) *VRepl { - return &VRepl{ - env: env, - workflow: workflow, - keyspace: keyspace, - shard: shard, - dbName: dbName, - sourceTable: sourceTable, - targetTable: targetTable, - originalCreateTable: originalCreateTable, - vreplCreateTable: vreplCreateTable, - alterQuery: alterQuery, - analyzeTable: analyzeTable, - parser: vrepl.NewAlterTableParser(), - enumToTextMap: map[string]string{}, - intToEnumMap: map[string]bool{}, - convertCharset: map[string](*binlogdatapb.CharsetConversion){}, - } -} - -// readAutoIncrement reads the AUTO_INCREMENT value, if any, for a give ntable -func (v *VRepl) readAutoIncrement(ctx context.Context, conn *dbconnpool.DBConnection, tableName string) (autoIncrement uint64, err error) { - query, err := sqlparser.ParseAndBind(sqlGetAutoIncrement, - sqltypes.StringBindVariable(v.dbName), - sqltypes.StringBindVariable(tableName), - ) +) (*VRepl, error) { + v := &VRepl{ + env: env, + workflow: workflow, + keyspace: keyspace, + shard: shard, + dbName: dbName, + alterTableAnalysis: schemadiff.OnlineDDLAlterTableAnalysis(alterQuery), + analyzeTable: analyzeTable, + convertCharset: map[string](*binlogdatapb.CharsetConversion){}, + } + senv := schemadiff.NewEnv(v.env, v.env.CollationEnv().DefaultConnectionCharset()) + var err error + v.sourceCreateTableEntity, err = schemadiff.NewCreateTableEntity(senv, sourceCreateTable) if err != nil { - return 0, err + return nil, err } - - rs, err := conn.ExecuteFetch(query, -1, true) + v.targetCreateTableEntity, err = schemadiff.NewCreateTableEntity(senv, targetCreateTable) if err != nil { - return 0, err - } - for _, row := range rs.Named().Rows { - autoIncrement = row.AsUint64("AUTO_INCREMENT", 0) + return nil, err } - - return autoIncrement, nil + return v, nil } -// readTableColumns reads column list from given table -func (v *VRepl) readTableColumns(ctx context.Context, conn *dbconnpool.DBConnection, tableName string) (columns *vrepl.ColumnList, virtualColumns *vrepl.ColumnList, pkColumns *vrepl.ColumnList, err error) { - parsed := sqlparser.BuildParsedQuery(sqlShowColumnsFrom, tableName) - rs, err := conn.ExecuteFetch(parsed.Query, -1, true) - if err != nil { - return nil, nil, nil, err - } - 
columnNames := []string{} - virtualColumnNames := []string{} - pkColumnNames := []string{} - for _, row := range rs.Named().Rows { - columnName := row.AsString("Field", "") - columnNames = append(columnNames, columnName) - - extra := row.AsString("Extra", "") - if strings.Contains(extra, "STORED GENERATED") || strings.Contains(extra, "VIRTUAL GENERATED") { - virtualColumnNames = append(virtualColumnNames, columnName) - } - - key := row.AsString("Key", "") - if key == "PRI" { - pkColumnNames = append(pkColumnNames, columnName) - } - } - if len(columnNames) == 0 { - return nil, nil, nil, fmt.Errorf("Found 0 columns on `%s`", tableName) - } - return vrepl.NewColumnList(columnNames), vrepl.NewColumnList(virtualColumnNames), vrepl.NewColumnList(pkColumnNames), nil +func (v *VRepl) sourceTableName() string { + return v.sourceCreateTableEntity.Name() } -// readTableUniqueKeys reads all unique keys from a given table, by order of usefulness/performance: PRIMARY first, integers are better, non-null are better -func (v *VRepl) readTableUniqueKeys(ctx context.Context, conn *dbconnpool.DBConnection, tableName string) (uniqueKeys []*vrepl.UniqueKey, err error) { - query, err := sqlparser.ParseAndBind(sqlSelectUniqueKeys, - sqltypes.StringBindVariable(v.dbName), - sqltypes.StringBindVariable(tableName), - sqltypes.StringBindVariable(v.dbName), - sqltypes.StringBindVariable(tableName), - ) - if err != nil { - return nil, err - } - rs, err := conn.ExecuteFetch(query, -1, true) - if err != nil { - return nil, err - } - for _, row := range rs.Named().Rows { - uniqueKey := &vrepl.UniqueKey{ - Name: row.AsString("index_name", ""), - Columns: *vrepl.ParseColumnList(row.AsString("column_names", "")), - HasNullable: row.AsBool("has_nullable", false), - HasSubpart: row.AsBool("has_subpart", false), - HasFloat: row.AsBool("is_float", false), - IsAutoIncrement: row.AsBool("is_auto_increment", false), - } - uniqueKeys = append(uniqueKeys, uniqueKey) - } - return uniqueKeys, nil +func (v *VRepl) targetTableName() string { + return v.targetCreateTableEntity.Name() } // isFastAnalyzeTableSupported checks if the underlying MySQL server supports 'fast_analyze_table', @@ -307,255 +216,64 @@ func (v *VRepl) readTableStatus(ctx context.Context, conn *dbconnpool.DBConnecti return tableRows, err } -// applyColumnTypes -func (v *VRepl) applyColumnTypes(ctx context.Context, conn *dbconnpool.DBConnection, tableName string, columnsLists ...*vrepl.ColumnList) error { - query, err := sqlparser.ParseAndBind(sqlSelectColumnTypes, - sqltypes.StringBindVariable(v.dbName), - sqltypes.StringBindVariable(tableName), - ) - if err != nil { - return err - } - rs, err := conn.ExecuteFetch(query, -1, true) - if err != nil { - return err - } - for _, row := range rs.Named().Rows { - columnName := row["COLUMN_NAME"].ToString() - columnType := row["COLUMN_TYPE"].ToString() - columnOctetLength := row.AsUint64("CHARACTER_OCTET_LENGTH", 0) - - for _, columnsList := range columnsLists { - column := columnsList.GetColumn(columnName) - if column == nil { - continue - } - - column.DataType = row.AsString("DATA_TYPE", "") // a more canonical form of column_type - column.IsNullable = (row.AsString("IS_NULLABLE", "") == "YES") - column.IsDefaultNull = row.AsBool("is_default_null", false) - - column.CharacterMaximumLength = row.AsInt64("CHARACTER_MAXIMUM_LENGTH", 0) - column.NumericPrecision = row.AsInt64("NUMERIC_PRECISION", 0) - column.NumericScale = row.AsInt64("NUMERIC_SCALE", 0) - column.DateTimePrecision = row.AsInt64("DATETIME_PRECISION", -1) - - 
column.Type = vrepl.UnknownColumnType - if strings.Contains(columnType, "unsigned") { - column.IsUnsigned = true - } - if strings.Contains(columnType, "mediumint") { - column.SetTypeIfUnknown(vrepl.MediumIntColumnType) - } - if strings.Contains(columnType, "timestamp") { - column.SetTypeIfUnknown(vrepl.TimestampColumnType) - } - if strings.Contains(columnType, "datetime") { - column.SetTypeIfUnknown(vrepl.DateTimeColumnType) - } - if strings.Contains(columnType, "json") { - column.SetTypeIfUnknown(vrepl.JSONColumnType) - } - if strings.Contains(columnType, "float") { - column.SetTypeIfUnknown(vrepl.FloatColumnType) - } - if strings.Contains(columnType, "double") { - column.SetTypeIfUnknown(vrepl.DoubleColumnType) - } - if strings.HasPrefix(columnType, "enum") { - column.SetTypeIfUnknown(vrepl.EnumColumnType) - column.EnumValues = schema.ParseEnumValues(columnType) - } - if strings.HasPrefix(columnType, "set(") { - column.SetTypeIfUnknown(vrepl.SetColumnType) - column.EnumValues = schema.ParseSetValues(columnType) - } - if strings.HasPrefix(columnType, "binary") { - column.SetTypeIfUnknown(vrepl.BinaryColumnType) - column.BinaryOctetLength = columnOctetLength - } - if charset := row.AsString("CHARACTER_SET_NAME", ""); charset != "" { - column.Charset = charset - } - if collation := row.AsString("COLLATION_NAME", ""); collation != "" { - column.SetTypeIfUnknown(vrepl.StringColumnType) - column.Collation = collation +// formalizeColumns +func formalizeColumns(columnsLists ...*schemadiff.ColumnDefinitionEntityList) error { + for _, colList := range columnsLists { + for _, col := range colList.Entities { + col.SetExplicitDefaultAndNull() + if err := col.SetExplicitCharsetCollate(); err != nil { + return err } } } return nil } -func (v *VRepl) analyzeAlter(ctx context.Context) error { - if v.alterQuery == nil { - // Happens for REVERT - return nil - } - v.parser.AnalyzeAlter(v.alterQuery) - if v.parser.IsRenameTable() { - return fmt.Errorf("Renaming the table is not supported in ALTER TABLE: %s", sqlparser.CanonicalString(v.alterQuery)) +func (v *VRepl) analyzeAlter() error { + if v.alterTableAnalysis.IsRenameTable { + return fmt.Errorf("renaming the table is not supported in ALTER TABLE") } return nil } -func (v *VRepl) analyzeTables(ctx context.Context, conn *dbconnpool.DBConnection) (err error) { - if v.analyzeTable { - if err := v.executeAnalyzeTable(ctx, conn, v.sourceTable); err != nil { - return err - } - } - v.tableRows, err = v.readTableStatus(ctx, conn, v.sourceTable) +func (v *VRepl) analyzeTables() (err error) { + analysis, err := schemadiff.OnlineDDLMigrationTablesAnalysis(v.sourceCreateTableEntity, v.targetCreateTableEntity, v.alterTableAnalysis) if err != nil { return err } - // columns: - sourceColumns, sourceVirtualColumns, sourcePKColumns, err := v.readTableColumns(ctx, conn, v.sourceTable) - if err != nil { - return err - } - targetColumns, targetVirtualColumns, targetPKColumns, err := v.readTableColumns(ctx, conn, v.targetTable) - if err != nil { - return err - } - v.sourceSharedColumns, v.targetSharedColumns, v.droppedSourceNonGeneratedColumns, v.sharedColumnsMap = vrepl.GetSharedColumns(sourceColumns, targetColumns, sourceVirtualColumns, targetVirtualColumns, v.parser) + v.analysis = analysis - // unique keys - sourceUniqueKeys, err := v.readTableUniqueKeys(ctx, conn, v.sourceTable) - if err != nil { - return err - } - if len(sourceUniqueKeys) == 0 { - return fmt.Errorf("Found no possible unique key on `%s`", v.sourceTable) - } - targetUniqueKeys, err := 
v.readTableUniqueKeys(ctx, conn, v.targetTable) - if err != nil { - return err - } - if len(targetUniqueKeys) == 0 { - return fmt.Errorf("Found no possible unique key on `%s`", v.targetTable) - } - v.chosenSourceUniqueKey, v.chosenTargetUniqueKey = vrepl.GetSharedUniqueKeys(sourceUniqueKeys, targetUniqueKeys, v.parser.ColumnRenameMap()) - if v.chosenSourceUniqueKey == nil { - // VReplication supports completely different unique keys on source and target, covering - // some/completely different columns. The condition is that the key on source - // must use columns which all exist on target table. - v.chosenSourceUniqueKey = vrepl.GetUniqueKeyCoveredByColumns(sourceUniqueKeys, v.sourceSharedColumns) - if v.chosenSourceUniqueKey == nil { - // Still no luck. - return fmt.Errorf("Found no possible unique key on `%s` whose columns are in target table `%s`", v.sourceTable, v.targetTable) - } - } - if v.chosenTargetUniqueKey == nil { - // VReplication supports completely different unique keys on source and target, covering - // some/completely different columns. The condition is that the key on target - // must use columns which all exist on source table. - v.chosenTargetUniqueKey = vrepl.GetUniqueKeyCoveredByColumns(targetUniqueKeys, v.targetSharedColumns) - if v.chosenTargetUniqueKey == nil { - // Still no luck. - return fmt.Errorf("Found no possible unique key on `%s` whose columns are in source table `%s`", v.targetTable, v.sourceTable) - } - } - if v.chosenSourceUniqueKey == nil || v.chosenTargetUniqueKey == nil { - return fmt.Errorf("Found no shared, not nullable, unique keys between `%s` and `%s`", v.sourceTable, v.targetTable) - } - v.addedUniqueKeys = vrepl.AddedUniqueKeys(sourceUniqueKeys, targetUniqueKeys, v.parser.ColumnRenameMap()) - v.removedUniqueKeys = vrepl.RemovedUniqueKeys(sourceUniqueKeys, targetUniqueKeys, v.parser.ColumnRenameMap()) - v.removedForeignKeyNames, err = vrepl.RemovedForeignKeyNames(v.env, v.originalCreateTable, v.vreplCreateTable) - if err != nil { - return err - } - - // chosen source & target unique keys have exact columns in same order - sharedPKColumns := &v.chosenSourceUniqueKey.Columns - - if err := v.applyColumnTypes(ctx, conn, v.sourceTable, sourceColumns, sourceVirtualColumns, sourcePKColumns, v.sourceSharedColumns, sharedPKColumns, v.droppedSourceNonGeneratedColumns); err != nil { - return err - } - if err := v.applyColumnTypes(ctx, conn, v.targetTable, targetColumns, targetVirtualColumns, targetPKColumns, v.targetSharedColumns); err != nil { - return err - } - - for _, sourcePKColumn := range sharedPKColumns.Columns() { - mappedColumn := v.targetSharedColumns.GetColumn(sourcePKColumn.Name) - if sourcePKColumn.Type == vrepl.EnumColumnType && mappedColumn.Type == vrepl.EnumColumnType { - // An ENUM as part of PRIMARY KEY. We must convert it to text because OMG that's complicated. 
- // There's a scenario where a query may modify the enum value (and it's bad practice, seeing - // that it's part of the PK, but it's still valid), and in that case we must have the string value - // to be able to DELETE the old row - v.targetSharedColumns.SetEnumToTextConversion(mappedColumn.Name, sourcePKColumn.EnumValues) - v.enumToTextMap[sourcePKColumn.Name] = sourcePKColumn.EnumValues - } - } - - for i := range v.sourceSharedColumns.Columns() { - sourceColumn := v.sourceSharedColumns.Columns()[i] - mappedColumn := v.targetSharedColumns.Columns()[i] - if sourceColumn.Type == vrepl.EnumColumnType { - switch { - // Either this is an ENUM column that stays an ENUM, or it is converted to a textual type. - // We take note of the enum values, and make it available in vreplication's Filter.Rule.ConvertEnumToText. - // This, in turn, will be used by vplayer (in TablePlan) like so: - // - In the binary log, enum values are integers. - // - Upon seeing this map, PlanBuilder will convert said int to the enum's logical string value. - // - And will apply the value as a string (`StringBindVariable`) in the query. - // What this allows is for enum values to have different ordering in the before/after table schema, - // so that for example you could modify an enum column: - // - from `('red', 'green', 'blue')` to `('red', 'blue')` - // - from `('red', 'green', 'blue')` to `('blue', 'red', 'green')` - case mappedColumn.Type == vrepl.EnumColumnType: - v.enumToTextMap[sourceColumn.Name] = sourceColumn.EnumValues - case mappedColumn.Charset != "": - v.enumToTextMap[sourceColumn.Name] = sourceColumn.EnumValues - v.targetSharedColumns.SetEnumToTextConversion(mappedColumn.Name, sourceColumn.EnumValues) - } - } + return nil +} - if sourceColumn.IsIntegralType() && mappedColumn.Type == vrepl.EnumColumnType { - v.intToEnumMap[sourceColumn.Name] = true +// analyzeTableStatus reads information from SHOW TABLE STATUS +func (v *VRepl) analyzeTableStatus(ctx context.Context, conn *dbconnpool.DBConnection) (err error) { + if v.analyzeTable { + if err := v.executeAnalyzeTable(ctx, conn, v.sourceTableName()); err != nil { + return err } } - - v.droppedNoDefaultColumnNames = vrepl.GetNoDefaultColumnNames(v.droppedSourceNonGeneratedColumns) - var expandedDescriptions map[string]string - v.expandedColumnNames, expandedDescriptions = vrepl.GetExpandedColumnNames(v.sourceSharedColumns, v.targetSharedColumns) - - v.sourceAutoIncrement, err = v.readAutoIncrement(ctx, conn, v.sourceTable) - - notes := []string{} - for _, uk := range v.removedUniqueKeys { - notes = append(notes, fmt.Sprintf("unique constraint removed: %s", uk.Name)) - } - for _, name := range v.droppedNoDefaultColumnNames { - notes = append(notes, fmt.Sprintf("column %s dropped, and had no default value", name)) - } - for _, name := range v.expandedColumnNames { - notes = append(notes, fmt.Sprintf("column %s: %s", name, expandedDescriptions[name])) - } - for _, name := range v.removedForeignKeyNames { - notes = append(notes, fmt.Sprintf("foreign key %s dropped", name)) - } - v.revertibleNotes = strings.Join(notes, "\n") + v.tableRows, err = v.readTableStatus(ctx, conn, v.sourceTableName()) if err != nil { return err } - return nil } // generateFilterQuery creates a SELECT query used by vreplication as a filter. It SELECTs all // non-generated columns between source & target tables, and takes care of column renames. 
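(For illustration only, not part of the patch: a minimal, self-contained sketch of the kind of SELECT statement this function assembles. The table name, column names, and the rename below are hypothetical, and the sketch shows only the column-aliasing logic; the real implementation, in the hunk that follows, additionally handles enum/set, JSON, and charset conversions via the migration analysis.)

package main

import (
	"fmt"
	"strings"
)

// buildFilterQuery mimics the aliasing behavior: every shared source column is
// selected and aliased to its (possibly renamed) target column name.
func buildFilterQuery(table string, sourceCols []string, renameMap map[string]string) string {
	var sb strings.Builder
	sb.WriteString("select ")
	for i, name := range sourceCols {
		if i > 0 {
			sb.WriteString(", ")
		}
		target := name
		if mapped, ok := renameMap[name]; ok {
			target = mapped
		}
		sb.WriteString(fmt.Sprintf("`%s` as `%s`", name, target))
	}
	sb.WriteString(fmt.Sprintf(" from `%s`", table))
	return sb.String()
}

func main() {
	// Hypothetical example: the ALTER renamed column c2 to c2_new.
	fmt.Println(buildFilterQuery("onlineddl_test", []string{"id", "c2"}, map[string]string{"c2": "c2_new"}))
	// prints: select `id` as `id`, `c2` as `c2_new` from `onlineddl_test`
}
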
-func (v *VRepl) generateFilterQuery(ctx context.Context) error {
-	if v.sourceSharedColumns.Len() == 0 {
-		return fmt.Errorf("Empty column list")
+func (v *VRepl) generateFilterQuery() error {
+	if v.analysis.SourceSharedColumns.Len() == 0 {
+		return fmt.Errorf("empty column list")
 	}
 	var sb strings.Builder
 	sb.WriteString("select ")
-	for i, sourceCol := range v.sourceSharedColumns.Columns() {
-		name := sourceCol.Name
-		targetName := v.sharedColumnsMap[name]
+	for i, sourceCol := range v.analysis.SourceSharedColumns.Entities {
+		name := sourceCol.Name()
+		targetName := v.analysis.SharedColumnsMap[name]
 
-		targetCol := v.targetSharedColumns.GetColumn(targetName)
+		targetCol := v.analysis.TargetSharedColumns.GetColumn(targetName)
 		if targetCol == nil {
 			return vterrors.Errorf(vtrpcpb.Code_INTERNAL, "Cannot find target column %s", targetName)
 		}
@@ -564,35 +282,36 @@ func (v *VRepl) generateFilterQuery(ctx context.Context) error {
 			sb.WriteString(", ")
 		}
 		switch {
-		case sourceCol.EnumToTextConversion:
+		case sourceCol.HasEnumValues():
+			// Source is `enum` or `set`. We always take the textual representation rather than the numeric one.
 			sb.WriteString(fmt.Sprintf("CONCAT(%s)", escapeName(name)))
-		case v.intToEnumMap[name]:
+		case v.analysis.IntToEnumMap[name]:
 			sb.WriteString(fmt.Sprintf("CONCAT(%s)", escapeName(name)))
-		case sourceCol.Type == vrepl.JSONColumnType:
+		case sourceCol.Type() == "json":
 			sb.WriteString(fmt.Sprintf("convert(%s using utf8mb4)", escapeName(name)))
-		case sourceCol.Type == vrepl.StringColumnType:
+		case sourceCol.IsTextual():
 			// Check source and target charset/encoding. If needed, create
 			// a binlogdatapb.CharsetConversion entry (later written to vreplication)
-			fromCollation := v.env.CollationEnv().DefaultCollationForCharset(sourceCol.Charset)
+			fromCollation := v.env.CollationEnv().DefaultCollationForCharset(sourceCol.Charset())
 			if fromCollation == collations.Unknown {
-				return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Character set %s not supported for column %s", sourceCol.Charset, sourceCol.Name)
+				return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Character set %s not supported for column %s", sourceCol.Charset(), sourceCol.Name())
 			}
-			toCollation := v.env.CollationEnv().DefaultCollationForCharset(targetCol.Charset)
+			toCollation := v.env.CollationEnv().DefaultCollationForCharset(targetCol.Charset())
 			// Let's see if target col is at all textual
-			if targetCol.Type == vrepl.StringColumnType && toCollation == collations.Unknown {
-				return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Character set %s not supported for column %s", targetCol.Charset, targetCol.Name)
+			if targetCol.IsTextual() && toCollation == collations.Unknown {
+				return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "Character set %s not supported for column %s", targetCol.Charset(), targetCol.Name())
 			}
-			if trivialCharset(fromCollation) && trivialCharset(toCollation) && targetCol.Type != vrepl.JSONColumnType {
+			if trivialCharset(fromCollation) && trivialCharset(toCollation) && targetCol.Type() != "json" {
 				sb.WriteString(escapeName(name))
 			} else {
 				v.convertCharset[targetName] = &binlogdatapb.CharsetConversion{
-					FromCharset: sourceCol.Charset,
-					ToCharset:   targetCol.Charset,
+					FromCharset: sourceCol.Charset(),
+					ToCharset:   targetCol.Charset(),
 				}
 				sb.WriteString(fmt.Sprintf("convert(%s using utf8mb4)", escapeName(name)))
 			}
-		case targetCol.Type == vrepl.JSONColumnType && sourceCol.Type != vrepl.JSONColumnType:
+		case targetCol.Type() == "json" && sourceCol.Type() != "json":
 			// Convert any
type to JSON: encode the type as utf8mb4 text sb.WriteString(fmt.Sprintf("convert(%s using utf8mb4)", escapeName(name))) default: @@ -602,7 +321,7 @@ func (v *VRepl) generateFilterQuery(ctx context.Context) error { sb.WriteString(escapeName(targetName)) } sb.WriteString(" from ") - sb.WriteString(escapeName(v.sourceTable)) + sb.WriteString(escapeName(v.sourceTableName())) v.filterQuery = sb.String() return nil @@ -624,25 +343,22 @@ func (v *VRepl) analyzeBinlogSource(ctx context.Context) { StopAfterCopy: false, } - encodeColumns := func(columns *vrepl.ColumnList) string { - return textutil.EscapeJoin(columns.Names(), ",") + encodeColumns := func(names []string) string { + return textutil.EscapeJoin(names, ",") } rule := &binlogdatapb.Rule{ - Match: v.targetTable, + Match: v.targetTableName(), Filter: v.filterQuery, - SourceUniqueKeyColumns: encodeColumns(&v.chosenSourceUniqueKey.Columns), - TargetUniqueKeyColumns: encodeColumns(&v.chosenTargetUniqueKey.Columns), - SourceUniqueKeyTargetColumns: encodeColumns(v.chosenSourceUniqueKey.Columns.MappedNamesColumnList(v.sharedColumnsMap)), - ForceUniqueKey: url.QueryEscape(v.chosenSourceUniqueKey.Name), + SourceUniqueKeyColumns: encodeColumns(v.analysis.ChosenSourceUniqueKey.ColumnList.Names()), + TargetUniqueKeyColumns: encodeColumns(v.analysis.ChosenTargetUniqueKey.ColumnList.Names()), + SourceUniqueKeyTargetColumns: encodeColumns(schemadiff.MappedColumnNames(v.analysis.ChosenSourceUniqueKey.ColumnList, v.analysis.SharedColumnsMap)), + ForceUniqueKey: url.QueryEscape(v.analysis.ChosenSourceUniqueKey.Name()), } if len(v.convertCharset) > 0 { rule.ConvertCharset = v.convertCharset } - if len(v.enumToTextMap) > 0 { - rule.ConvertEnumToText = v.enumToTextMap - } - if len(v.intToEnumMap) > 0 { - rule.ConvertIntToEnum = v.intToEnumMap + if len(v.analysis.IntToEnumMap) > 0 { + rule.ConvertIntToEnum = v.analysis.IntToEnumMap } bls.Filter.Rules = append(bls.Filter.Rules, rule) @@ -650,13 +366,16 @@ func (v *VRepl) analyzeBinlogSource(ctx context.Context) { } func (v *VRepl) analyze(ctx context.Context, conn *dbconnpool.DBConnection) error { - if err := v.analyzeAlter(ctx); err != nil { + if err := v.analyzeAlter(); err != nil { + return err + } + if err := v.analyzeTables(); err != nil { return err } - if err := v.analyzeTables(ctx, conn); err != nil { + if err := v.generateFilterQuery(); err != nil { return err } - if err := v.generateFilterQuery(ctx); err != nil { + if err := v.analyzeTableStatus(ctx, conn); err != nil { return err } v.analyzeBinlogSource(ctx) @@ -664,7 +383,7 @@ func (v *VRepl) analyze(ctx context.Context, conn *dbconnpool.DBConnection) erro } // generateInsertStatement generates the INSERT INTO _vt.replication statement that creates the vreplication workflow -func (v *VRepl) generateInsertStatement(ctx context.Context) (string, error) { +func (v *VRepl) generateInsertStatement() (string, error) { ig := vreplication.NewInsertGenerator(binlogdatapb.VReplicationWorkflowState_Stopped, v.dbName) ig.AddRow(v.workflow, v.bls, v.pos, "", "in_order:REPLICA,PRIMARY", binlogdatapb.VReplicationWorkflowType_OnlineDDL, binlogdatapb.VReplicationWorkflowSubType_None, false) @@ -673,7 +392,7 @@ func (v *VRepl) generateInsertStatement(ctx context.Context) (string, error) { } // generateStartStatement Generates the statement to start VReplication running on the workflow -func (v *VRepl) generateStartStatement(ctx context.Context) (string, error) { +func (v *VRepl) generateStartStatement() (string, error) { return 
sqlparser.ParseAndBind(sqlStartVReplStream, sqltypes.StringBindVariable(v.dbName), sqltypes.StringBindVariable(v.workflow), diff --git a/go/vt/vttablet/onlineddl/vrepl/columns.go b/go/vt/vttablet/onlineddl/vrepl/columns.go deleted file mode 100644 index f2bb8f6d3f2..00000000000 --- a/go/vt/vttablet/onlineddl/vrepl/columns.go +++ /dev/null @@ -1,208 +0,0 @@ -/* - Copyright 2016 GitHub Inc. - See https://github.com/github/gh-ost/blob/master/LICENSE -*/ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package vrepl - -import ( - "fmt" - "strings" - - "vitess.io/vitess/go/vt/schema" -) - -// expandedDataTypes maps some known and difficult-to-compute by INFORMATION_SCHEMA data types which expand other data types. -// For example, in "date:datetime", datetime expands date because it has more precision. In "timestamp:date" date expands timestamp -// because it can contain years not covered by timestamp. -var expandedDataTypes = map[string]bool{ - "time:datetime": true, - "date:datetime": true, - "timestamp:datetime": true, - "time:timestamp": true, - "date:timestamp": true, - "timestamp:date": true, -} - -// GetSharedColumns returns the intersection of two lists of columns in same order as the first list -func GetSharedColumns( - sourceColumns, targetColumns *ColumnList, - sourceVirtualColumns, targetVirtualColumns *ColumnList, - parser *AlterTableParser, -) ( - sourceSharedColumns *ColumnList, - targetSharedColumns *ColumnList, - droppedSourceNonGeneratedColumns *ColumnList, - sharedColumnsMap map[string]string, -) { - sharedColumnNames := []string{} - droppedSourceNonGeneratedColumnsNames := []string{} - for _, sourceColumn := range sourceColumns.Names() { - isSharedColumn := false - isVirtualColumnOnSource := false - for _, targetColumn := range targetColumns.Names() { - if strings.EqualFold(sourceColumn, targetColumn) { - // both tables have this column. Good start. 
- isSharedColumn = true - break - } - if strings.EqualFold(parser.columnRenameMap[sourceColumn], targetColumn) { - // column in source is renamed in target - isSharedColumn = true - break - } - } - for droppedColumn := range parser.DroppedColumnsMap() { - if strings.EqualFold(sourceColumn, droppedColumn) { - isSharedColumn = false - break - } - } - for _, virtualColumn := range sourceVirtualColumns.Names() { - // virtual/generated columns on source are silently skipped - if strings.EqualFold(sourceColumn, virtualColumn) { - isSharedColumn = false - isVirtualColumnOnSource = true - } - } - for _, virtualColumn := range targetVirtualColumns.Names() { - // virtual/generated columns on target are silently skipped - if strings.EqualFold(sourceColumn, virtualColumn) { - isSharedColumn = false - } - } - if isSharedColumn { - sharedColumnNames = append(sharedColumnNames, sourceColumn) - } else if !isVirtualColumnOnSource { - droppedSourceNonGeneratedColumnsNames = append(droppedSourceNonGeneratedColumnsNames, sourceColumn) - } - } - sharedColumnsMap = map[string]string{} - for _, columnName := range sharedColumnNames { - if mapped, ok := parser.columnRenameMap[columnName]; ok { - sharedColumnsMap[columnName] = mapped - } else { - sharedColumnsMap[columnName] = columnName - } - } - mappedSharedColumnNames := []string{} - for _, columnName := range sharedColumnNames { - mappedSharedColumnNames = append(mappedSharedColumnNames, sharedColumnsMap[columnName]) - } - return NewColumnList(sharedColumnNames), NewColumnList(mappedSharedColumnNames), NewColumnList(droppedSourceNonGeneratedColumnsNames), sharedColumnsMap -} - -// isExpandedColumn sees if target column has any value set/range that is impossible in source column. See GetExpandedColumns comment for examples -func isExpandedColumn(sourceColumn *Column, targetColumn *Column) (bool, string) { - if targetColumn.IsNullable && !sourceColumn.IsNullable { - return true, "target is NULL-able, source is not" - } - if targetColumn.CharacterMaximumLength > sourceColumn.CharacterMaximumLength { - return true, "increased CHARACTER_MAXIMUM_LENGTH" - } - if targetColumn.NumericPrecision > sourceColumn.NumericPrecision { - return true, "increased NUMERIC_PRECISION" - } - if targetColumn.NumericScale > sourceColumn.NumericScale { - return true, "increased NUMERIC_SCALE" - } - if targetColumn.DateTimePrecision > sourceColumn.DateTimePrecision { - return true, "increased DATETIME_PRECISION" - } - if sourceColumn.IsNumeric() && targetColumn.IsNumeric() { - if sourceColumn.IsUnsigned && !targetColumn.IsUnsigned { - return true, "source is unsigned, target is signed" - } - if sourceColumn.NumericPrecision <= targetColumn.NumericPrecision && !sourceColumn.IsUnsigned && targetColumn.IsUnsigned { - // e.g. 
INT SIGNED => INT UNSIGNED, INT SIGNED => BIGINT UNSIGNED - return true, "target unsigned value exceeds source unsigned value" - } - if targetColumn.IsFloatingPoint() && !sourceColumn.IsFloatingPoint() { - return true, "target is floating point, source is not" - } - } - if expandedDataTypes[fmt.Sprintf("%s:%s", sourceColumn.DataType, targetColumn.DataType)] { - return true, "target is expanded data type of source" - } - if sourceColumn.Charset != targetColumn.Charset { - if targetColumn.Charset == "utf8mb4" { - return true, "expand character set to utf8mb4" - } - if strings.HasPrefix(targetColumn.Charset, "utf8") && !strings.HasPrefix(sourceColumn.Charset, "utf8") { - // not utf to utf - return true, "expand character set to utf8" - } - } - for _, colType := range []ColumnType{EnumColumnType, SetColumnType} { - // enums and sets have very similar properties, and are practically identical in our analysis - if sourceColumn.Type == colType { - // this is an enum or a set - if targetColumn.Type != colType { - return true, "conversion from enum/set to non-enum/set adds potential values" - } - // target is an enum or a set. See if all values on target exist in source - sourceEnumTokensMap := schema.ParseEnumOrSetTokensMap(sourceColumn.EnumValues) - targetEnumTokensMap := schema.ParseEnumOrSetTokensMap(targetColumn.EnumValues) - for k, v := range targetEnumTokensMap { - if sourceEnumTokensMap[k] != v { - return true, "target enum/set expands source enum/set" - } - } - } - } - return false, "" -} - -// GetExpandedColumnNames is given source and target shared columns, and returns the list of columns whose data type is expanded. -// An expanded data type is one where the target can have a value which the source does not. Examples: -// - any NOT NULL to NULLable (a NULL in the target cannot appear on source) -// - INT -> BIGINT (obvious) -// - BIGINT UNSIGNED -> INT SIGNED (negative values) -// - TIMESTAMP -> TIMESTAMP(3) -// etc. -func GetExpandedColumnNames( - sourceSharedColumns *ColumnList, - targetSharedColumns *ColumnList, -) ( - expandedColumnNames []string, - expandedDescriptions map[string]string, -) { - expandedDescriptions = map[string]string{} - for i := range sourceSharedColumns.Columns() { - // source and target columns assumed to be mapped 1:1, same length - sourceColumn := sourceSharedColumns.Columns()[i] - targetColumn := targetSharedColumns.Columns()[i] - - if isExpanded, description := isExpandedColumn(&sourceColumn, &targetColumn); isExpanded { - expandedColumnNames = append(expandedColumnNames, sourceColumn.Name) - expandedDescriptions[sourceColumn.Name] = description - } - } - return expandedColumnNames, expandedDescriptions -} - -// GetNoDefaultColumnNames returns names of columns which have no default value, out of given list of columns -func GetNoDefaultColumnNames(columns *ColumnList) (names []string) { - names = []string{} - for _, col := range columns.Columns() { - if !col.HasDefault() { - names = append(names, col.Name) - } - } - return names -} diff --git a/go/vt/vttablet/onlineddl/vrepl/columns_test.go b/go/vt/vttablet/onlineddl/vrepl/columns_test.go deleted file mode 100644 index 32efd104cc1..00000000000 --- a/go/vt/vttablet/onlineddl/vrepl/columns_test.go +++ /dev/null @@ -1,380 +0,0 @@ -/* - Copyright 2016 GitHub Inc. - See https://github.com/github/gh-ost/blob/master/LICENSE -*/ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package vrepl - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -var ( - columnsA = &ColumnList{ - columns: []Column{ - { - Name: "id", - }, - { - Name: "cint", - }, - { - Name: "cgen1", - }, - { - Name: "cgen2", - }, - { - Name: "cchar", - }, - { - Name: "cremoved", - }, - { - Name: "cnullable", - IsNullable: true, - }, - { - Name: "cnodefault", - IsNullable: false, - IsDefaultNull: true, - }, - }, - Ordinals: ColumnsMap{}, - } - columnsB = &ColumnList{ - columns: []Column{ - { - Name: "id", - }, - { - Name: "cint", - }, - { - Name: "cgen1", - }, - { - Name: "cchar_alternate", - }, - { - Name: "cnullable", - IsNullable: true, - }, - { - Name: "cnodefault", - IsNullable: false, - IsDefaultNull: true, - }, - }, - Ordinals: ColumnsMap{}, - } - columnsVirtual = ParseColumnList("cgen1,cgen2") -) - -func TestGetSharedColumns(t *testing.T) { - tt := []struct { - name string - sourceCols *ColumnList - targetCols *ColumnList - renameMap map[string]string - expectSourceSharedColNames []string - expectTargetSharedColNames []string - expectDroppedSourceNonGeneratedColNames []string - }{ - { - name: "rename map empty", - sourceCols: columnsA, - targetCols: columnsB, - renameMap: map[string]string{}, - expectSourceSharedColNames: []string{"id", "cint", "cnullable", "cnodefault"}, - expectTargetSharedColNames: []string{"id", "cint", "cnullable", "cnodefault"}, - expectDroppedSourceNonGeneratedColNames: []string{"cchar", "cremoved"}, - }, - { - name: "renamed column", - sourceCols: columnsA, - targetCols: columnsB, - renameMap: map[string]string{"cchar": "cchar_alternate"}, - expectSourceSharedColNames: []string{"id", "cint", "cchar", "cnullable", "cnodefault"}, - expectTargetSharedColNames: []string{"id", "cint", "cchar_alternate", "cnullable", "cnodefault"}, - expectDroppedSourceNonGeneratedColNames: []string{"cremoved"}, - }, - } - - parser := NewAlterTableParser() - for _, tc := range tt { - t.Run(tc.name, func(t *testing.T) { - parser.columnRenameMap = tc.renameMap - sourceSharedCols, targetSharedCols, droppedNonGeneratedCols, _ := GetSharedColumns( - tc.sourceCols, tc.targetCols, - columnsVirtual, columnsVirtual, - parser, - ) - assert.Equal(t, tc.expectSourceSharedColNames, sourceSharedCols.Names()) - assert.Equal(t, tc.expectTargetSharedColNames, targetSharedCols.Names()) - assert.Equal(t, tc.expectDroppedSourceNonGeneratedColNames, droppedNonGeneratedCols.Names()) - }) - } -} - -func TestGetExpandedColumnNames(t *testing.T) { - var ( - columnsA = &ColumnList{ - columns: []Column{ - { - Name: "c1", - IsNullable: true, - }, - { - Name: "c2", - IsNullable: true, - }, - { - Name: "c3", - IsNullable: false, - }, - }, - Ordinals: ColumnsMap{}, - } - columnsB = &ColumnList{ - columns: []Column{ - { - Name: "c1", - IsNullable: true, - }, - { - Name: "c2", - IsNullable: false, - }, - { - Name: "c3", - IsNullable: true, - }, - }, - Ordinals: ColumnsMap{}, - } - ) - tcases := []struct { - name string - sourceCol Column - targetCol Column - expanded bool - }{ - { - "both nullable", - Column{ - IsNullable: true, - }, - Column{ - IsNullable: true, - }, - false, - }, - { - "nullable to non 
nullable", - Column{ - IsNullable: true, - }, - Column{ - IsNullable: false, - }, - false, - }, - { - "non nullable to nullable", - Column{ - IsNullable: false, - }, - Column{ - IsNullable: true, - }, - true, - }, - { - "signed to unsigned", - Column{ - Type: IntegerColumnType, - NumericPrecision: 4, - IsUnsigned: false, - }, - Column{ - Type: IntegerColumnType, - NumericPrecision: 4, - IsUnsigned: true, - }, - true, - }, - { - "unsigned to signed", - Column{ - Type: IntegerColumnType, - NumericPrecision: 4, - IsUnsigned: true, - }, - Column{ - Type: IntegerColumnType, - NumericPrecision: 4, - IsUnsigned: false, - }, - true, - }, - { - "signed to smaller unsigned", - Column{ - Type: IntegerColumnType, - NumericPrecision: 8, - IsUnsigned: false, - }, - Column{ - Type: IntegerColumnType, - NumericPrecision: 4, - IsUnsigned: true, - }, - false, - }, - { - "same char length", - Column{ - CharacterMaximumLength: 20, - }, - Column{ - CharacterMaximumLength: 20, - }, - false, - }, - { - "reduced char length", - Column{ - CharacterMaximumLength: 20, - }, - Column{ - CharacterMaximumLength: 19, - }, - false, - }, - { - "increased char length", - Column{ - CharacterMaximumLength: 20, - }, - Column{ - CharacterMaximumLength: 21, - }, - true, - }, - { - "expand temporal", - Column{ - DataType: "time", - }, - Column{ - DataType: "timestamp", - }, - true, - }, - { - "expand temporal", - Column{ - DataType: "date", - }, - Column{ - DataType: "timestamp", - }, - true, - }, - { - "expand temporal", - Column{ - DataType: "date", - }, - Column{ - DataType: "datetime", - }, - true, - }, - { - "non expand temporal", - Column{ - DataType: "datetime", - }, - Column{ - DataType: "timestamp", - }, - false, - }, - { - "expand temporal", - Column{ - DataType: "timestamp", - }, - Column{ - DataType: "datetime", - }, - true, - }, - { - "expand enum", - Column{ - Type: EnumColumnType, - EnumValues: "'a','b'", - }, - Column{ - Type: EnumColumnType, - EnumValues: "'a','x'", - }, - true, - }, - { - "expand enum", - Column{ - Type: EnumColumnType, - EnumValues: "'a','b'", - }, - Column{ - Type: EnumColumnType, - EnumValues: "'a','b','c'", - }, - true, - }, - { - "reduce enum", - Column{ - Type: EnumColumnType, - EnumValues: "'a','b','c'", - }, - Column{ - Type: EnumColumnType, - EnumValues: "'a','b'", - }, - false, - }, - } - - expectedExpandedColumnNames := []string{"c3"} - expandedColumnNames, _ := GetExpandedColumnNames(columnsA, columnsB) - assert.Equal(t, expectedExpandedColumnNames, expandedColumnNames) - - for _, tcase := range tcases { - t.Run(tcase.name, func(t *testing.T) { - expanded, _ := isExpandedColumn(&tcase.sourceCol, &tcase.targetCol) - assert.Equal(t, tcase.expanded, expanded) - }) - } -} diff --git a/go/vt/vttablet/onlineddl/vrepl/foreign_key.go b/go/vt/vttablet/onlineddl/vrepl/foreign_key.go deleted file mode 100644 index 006beb7345c..00000000000 --- a/go/vt/vttablet/onlineddl/vrepl/foreign_key.go +++ /dev/null @@ -1,58 +0,0 @@ -/* - Copyright 2016 GitHub Inc. - See https://github.com/github/gh-ost/blob/master/LICENSE -*/ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -*/ - -package vrepl - -import ( - "vitess.io/vitess/go/vt/schemadiff" - "vitess.io/vitess/go/vt/sqlparser" - "vitess.io/vitess/go/vt/vtenv" -) - -// RemovedForeignKeyNames returns the names of removed foreign keys, ignoring mere name changes -func RemovedForeignKeyNames( - venv *vtenv.Environment, - originalCreateTable *sqlparser.CreateTable, - vreplCreateTable *sqlparser.CreateTable, -) (names []string, err error) { - if originalCreateTable == nil || vreplCreateTable == nil { - return nil, nil - } - env := schemadiff.NewEnv(venv, venv.CollationEnv().DefaultConnectionCharset()) - diffHints := schemadiff.DiffHints{ - ConstraintNamesStrategy: schemadiff.ConstraintNamesIgnoreAll, - } - diff, err := schemadiff.DiffTables(env, originalCreateTable, vreplCreateTable, &diffHints) - if err != nil { - return nil, err - } - - validateWalk := func(node sqlparser.SQLNode) (kontinue bool, err error) { - switch node := node.(type) { - case *sqlparser.DropKey: - if node.Type == sqlparser.ForeignKeyType { - names = append(names, node.Name.String()) - } - } - return true, nil - } - _ = sqlparser.Walk(validateWalk, diff.Statement()) // We never return an error - return names, nil -} diff --git a/go/vt/vttablet/onlineddl/vrepl/foreign_key_test.go b/go/vt/vttablet/onlineddl/vrepl/foreign_key_test.go deleted file mode 100644 index 66775092dcb..00000000000 --- a/go/vt/vttablet/onlineddl/vrepl/foreign_key_test.go +++ /dev/null @@ -1,91 +0,0 @@ -/* - Copyright 2016 GitHub Inc. - See https://github.com/github/gh-ost/blob/master/LICENSE -*/ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package vrepl - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "vitess.io/vitess/go/vt/sqlparser" - "vitess.io/vitess/go/vt/vtenv" -) - -func TestRemovedForeignKeyNames(t *testing.T) { - - tcases := []struct { - before string - after string - names []string - }{ - { - before: "create table t (id int primary key)", - after: "create table t (id2 int primary key, i int)", - }, - { - before: "create table t (id int primary key)", - after: "create table t2 (id2 int primary key, i int)", - }, - { - before: "create table t (id int primary key, i int, constraint f foreign key (i) references parent (id) on delete cascade)", - after: "create table t (id int primary key, i int, constraint f foreign key (i) references parent (id) on delete cascade)", - }, - { - before: "create table t (id int primary key, i int, constraint f1 foreign key (i) references parent (id) on delete cascade)", - after: "create table t (id int primary key, i int, constraint f2 foreign key (i) references parent (id) on delete cascade)", - }, - { - before: "create table t (id int primary key, i int, constraint f foreign key (i) references parent (id) on delete cascade)", - after: "create table t (id int primary key, i int)", - names: []string{"f"}, - }, - { - before: "create table t (id int primary key, i int, i2 int, constraint f1 foreign key (i) references parent (id) on delete cascade, constraint fi2 foreign key (i2) references parent (id) on delete cascade)", - after: "create table t (id int primary key, i int, i2 int, constraint f2 foreign key (i) references parent (id) on delete cascade)", - names: []string{"fi2"}, - }, - { - before: "create table t1 (id int primary key, i int, constraint `check1` CHECK ((`i` < 5)))", - after: "create table t2 (id int primary key, i int)", - }, - } - for _, tcase := range tcases { - t.Run(tcase.before, func(t *testing.T) { - env := vtenv.NewTestEnv() - beforeStmt, err := env.Parser().ParseStrictDDL(tcase.before) - require.NoError(t, err) - beforeCreateTable, ok := beforeStmt.(*sqlparser.CreateTable) - require.True(t, ok) - require.NotNil(t, beforeCreateTable) - - afterStmt, err := env.Parser().ParseStrictDDL(tcase.after) - require.NoError(t, err) - afterCreateTable, ok := afterStmt.(*sqlparser.CreateTable) - require.True(t, ok) - require.NotNil(t, afterCreateTable) - - names, err := RemovedForeignKeyNames(env, beforeCreateTable, afterCreateTable) - assert.NoError(t, err) - assert.Equal(t, tcase.names, names) - }) - } -} diff --git a/go/vt/vttablet/onlineddl/vrepl/parser.go b/go/vt/vttablet/onlineddl/vrepl/parser.go deleted file mode 100644 index f76f8735016..00000000000 --- a/go/vt/vttablet/onlineddl/vrepl/parser.go +++ /dev/null @@ -1,112 +0,0 @@ -/* - Copyright 2016 GitHub Inc. - See https://github.com/github/gh-ost/blob/master/LICENSE -*/ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package vrepl - -import ( - "strings" - - "vitess.io/vitess/go/vt/sqlparser" -) - -// AlterTableParser is a parser tool for ALTER TABLE statements -// This is imported from gh-ost. In the future, we should replace that with Vitess parsing. -type AlterTableParser struct { - columnRenameMap map[string]string - droppedColumns map[string]bool - isRenameTable bool - isAutoIncrementDefined bool -} - -// NewAlterTableParser creates a new parser -func NewAlterTableParser() *AlterTableParser { - return &AlterTableParser{ - columnRenameMap: make(map[string]string), - droppedColumns: make(map[string]bool), - } -} - -// NewParserFromAlterStatement creates a new parser with a ALTER TABLE statement -func NewParserFromAlterStatement(alterTable *sqlparser.AlterTable) *AlterTableParser { - parser := NewAlterTableParser() - parser.AnalyzeAlter(alterTable) - return parser -} - -// AnalyzeAlter looks for specific changes in the AlterTable statement, that are relevant -// to OnlineDDL/VReplication -func (p *AlterTableParser) AnalyzeAlter(alterTable *sqlparser.AlterTable) { - for _, opt := range alterTable.AlterOptions { - switch opt := opt.(type) { - case *sqlparser.RenameTableName: - p.isRenameTable = true - case *sqlparser.DropColumn: - p.droppedColumns[opt.Name.Name.String()] = true - case *sqlparser.ChangeColumn: - if opt.OldColumn != nil && opt.NewColDefinition != nil { - oldName := opt.OldColumn.Name.String() - newName := opt.NewColDefinition.Name.String() - p.columnRenameMap[oldName] = newName - } - case sqlparser.TableOptions: - for _, tableOption := range opt { - if strings.ToUpper(tableOption.Name) == "AUTO_INCREMENT" { - p.isAutoIncrementDefined = true - } - } - } - } -} - -// GetNonTrivialRenames gets a list of renamed column -func (p *AlterTableParser) GetNonTrivialRenames() map[string]string { - result := make(map[string]string) - for column, renamed := range p.columnRenameMap { - if column != renamed { - result[column] = renamed - } - } - return result -} - -// HasNonTrivialRenames is true when columns have been renamed -func (p *AlterTableParser) HasNonTrivialRenames() bool { - return len(p.GetNonTrivialRenames()) > 0 -} - -// DroppedColumnsMap returns list of dropped columns -func (p *AlterTableParser) DroppedColumnsMap() map[string]bool { - return p.droppedColumns -} - -// IsRenameTable returns true when the ALTER TABLE statement includes renaming the table -func (p *AlterTableParser) IsRenameTable() bool { - return p.isRenameTable -} - -// IsAutoIncrementDefined returns true when alter options include an explicit AUTO_INCREMENT value -func (p *AlterTableParser) IsAutoIncrementDefined() bool { - return p.isAutoIncrementDefined -} - -// ColumnRenameMap returns the renamed column mapping -func (p *AlterTableParser) ColumnRenameMap() map[string]string { - return p.columnRenameMap -} diff --git a/go/vt/vttablet/onlineddl/vrepl/parser_test.go b/go/vt/vttablet/onlineddl/vrepl/parser_test.go deleted file mode 100644 index 93e2ef25a15..00000000000 --- a/go/vt/vttablet/onlineddl/vrepl/parser_test.go +++ /dev/null @@ -1,190 +0,0 @@ -/* - Copyright 2016 GitHub Inc. - See https://github.com/github/gh-ost/blob/master/LICENSE -*/ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package vrepl - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "vitess.io/vitess/go/vt/sqlparser" -) - -func alterTableStatement(t *testing.T, sql string) *sqlparser.AlterTable { - stmt, err := sqlparser.NewTestParser().ParseStrictDDL(sql) - require.NoError(t, err) - alter, ok := stmt.(*sqlparser.AlterTable) - require.True(t, ok) - return alter -} - -func TestParseAlterStatement(t *testing.T) { - statement := "alter table t add column t int, engine=innodb" - alterStatement := alterTableStatement(t, statement) - parser := NewAlterTableParser() - parser.AnalyzeAlter(alterStatement) - assert.False(t, parser.HasNonTrivialRenames()) - assert.False(t, parser.IsAutoIncrementDefined()) -} - -func TestParseAlterStatementTrivialRename(t *testing.T) { - statement := "alter table t add column t int, change ts ts timestamp, engine=innodb" - alterStatement := alterTableStatement(t, statement) - parser := NewAlterTableParser() - parser.AnalyzeAlter(alterStatement) - assert.False(t, parser.HasNonTrivialRenames()) - assert.False(t, parser.IsAutoIncrementDefined()) - assert.Equal(t, len(parser.columnRenameMap), 1) - assert.Equal(t, parser.columnRenameMap["ts"], "ts") -} - -func TestParseAlterStatementWithAutoIncrement(t *testing.T) { - - statements := []string{ - "auto_increment=7", - "auto_increment = 7", - "AUTO_INCREMENT = 71", - "AUTO_INCREMENT 23", - "AUTO_INCREMENT 23", - "add column t int, change ts ts timestamp, auto_increment=7 engine=innodb", - "add column t int, change ts ts timestamp, auto_increment =7 engine=innodb", - "add column t int, change ts ts timestamp, AUTO_INCREMENT = 7 engine=innodb", - "add column t int, change ts ts timestamp, engine=innodb auto_increment=73425", - "add column t int, change ts ts timestamp, engine=innodb, auto_increment=73425", - "add column t int, change ts ts timestamp, engine=innodb, auto_increment 73425", - "add column t int, change ts ts timestamp, engine innodb, auto_increment 73425", - "add column t int, change ts ts timestamp, engine innodb auto_increment 73425", - } - for _, statement := range statements { - parser := NewAlterTableParser() - statement := "alter table t " + statement - alterStatement := alterTableStatement(t, statement) - parser.AnalyzeAlter(alterStatement) - assert.True(t, parser.IsAutoIncrementDefined()) - } -} - -func TestParseAlterStatementTrivialRenames(t *testing.T) { - statement := "alter table t add column t int, change ts ts timestamp, CHANGE f `f` float, engine=innodb" - alterStatement := alterTableStatement(t, statement) - parser := NewAlterTableParser() - parser.AnalyzeAlter(alterStatement) - assert.False(t, parser.HasNonTrivialRenames()) - assert.False(t, parser.IsAutoIncrementDefined()) - assert.Equal(t, len(parser.columnRenameMap), 2) - assert.Equal(t, parser.columnRenameMap["ts"], "ts") - assert.Equal(t, parser.columnRenameMap["f"], "f") -} - -func TestParseAlterStatementNonTrivial(t *testing.T) { - statements := []string{ - `add column b bigint, change f fl float, change i count int, engine=innodb`, - "add column b bigint, change column `f` fl float, change `i` 
`count` int, engine=innodb", - "add column b bigint, change column `f` fl float, change `i` `count` int, change ts ts timestamp, engine=innodb", - `change - f fl float, - CHANGE COLUMN i - count int, engine=innodb`, - } - - for _, statement := range statements { - statement := "alter table t " + statement - alterStatement := alterTableStatement(t, statement) - parser := NewAlterTableParser() - parser.AnalyzeAlter(alterStatement) - assert.False(t, parser.IsAutoIncrementDefined()) - renames := parser.GetNonTrivialRenames() - assert.Equal(t, len(renames), 2) - assert.Equal(t, renames["i"], "count") - assert.Equal(t, renames["f"], "fl") - } -} - -func TestParseAlterStatementDroppedColumns(t *testing.T) { - - { - parser := NewAlterTableParser() - statement := "alter table t drop column b" - alterStatement := alterTableStatement(t, statement) - parser.AnalyzeAlter(alterStatement) - assert.Equal(t, len(parser.droppedColumns), 1) - assert.True(t, parser.droppedColumns["b"]) - } - { - parser := NewAlterTableParser() - statement := "alter table t drop column b, drop key c_idx, drop column `d`" - alterStatement := alterTableStatement(t, statement) - parser.AnalyzeAlter(alterStatement) - assert.Equal(t, len(parser.droppedColumns), 2) - assert.True(t, parser.droppedColumns["b"]) - assert.True(t, parser.droppedColumns["d"]) - } - { - parser := NewAlterTableParser() - statement := "alter table t drop column b, drop key c_idx, drop column `d`, drop `e`, drop primary key, drop foreign key fk_1" - alterStatement := alterTableStatement(t, statement) - parser.AnalyzeAlter(alterStatement) - assert.Equal(t, len(parser.droppedColumns), 3) - assert.True(t, parser.droppedColumns["b"]) - assert.True(t, parser.droppedColumns["d"]) - assert.True(t, parser.droppedColumns["e"]) - } -} - -func TestParseAlterStatementRenameTable(t *testing.T) { - tt := []struct { - alter string - isRename bool - }{ - { - alter: "alter table t drop column b", - }, - { - alter: "alter table t rename as something_else", - isRename: true, - }, - { - alter: "alter table t rename to something_else", - isRename: true, - }, - { - alter: "alter table t drop column b, rename as something_else", - isRename: true, - }, - { - alter: "alter table t engine=innodb, rename as something_else", - isRename: true, - }, - { - alter: "alter table t rename as something_else, engine=innodb", - isRename: true, - }, - } - for _, tc := range tt { - t.Run(tc.alter, func(t *testing.T) { - parser := NewAlterTableParser() - alterStatement := alterTableStatement(t, tc.alter) - parser.AnalyzeAlter(alterStatement) - assert.Equal(t, tc.isRename, parser.isRenameTable) - }) - } -} diff --git a/go/vt/vttablet/onlineddl/vrepl/types.go b/go/vt/vttablet/onlineddl/vrepl/types.go deleted file mode 100644 index 0ca834ffdf0..00000000000 --- a/go/vt/vttablet/onlineddl/vrepl/types.go +++ /dev/null @@ -1,293 +0,0 @@ -/* - Original copyright by GitHub as follows. Additions by the Vitess authors as follows. -*/ -/* - Copyright 2016 GitHub Inc. - See https://github.com/github/gh-ost/blob/master/LICENSE -*/ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -*/ - -package vrepl - -import ( - "fmt" - "reflect" - "strings" - - "vitess.io/vitess/go/vt/schemadiff" -) - -// ColumnType indicates some MySQL data types -type ColumnType int - -const ( - UnknownColumnType ColumnType = iota - TimestampColumnType - DateTimeColumnType - EnumColumnType - SetColumnType - MediumIntColumnType - JSONColumnType - FloatColumnType - DoubleColumnType - BinaryColumnType - StringColumnType - IntegerColumnType -) - -// Column represents a table column -type Column struct { - Name string - IsUnsigned bool - Charset string - Collation string - Type ColumnType - EnumValues string - EnumToTextConversion bool - DataType string // from COLUMN_TYPE column - - IsNullable bool - IsDefaultNull bool - - CharacterMaximumLength int64 - NumericPrecision int64 - NumericScale int64 - DateTimePrecision int64 - - // BinaryOctetLength is the octet length for binary types; it fixes bytes with a "00" suffix getting clipped in the MySQL binlog. - // https://github.com/github/gh-ost/issues/909 - BinaryOctetLength uint64 -} - -// SetTypeIfUnknown sets a new column type only if the current type is unknown; otherwise it silently skips -func (c *Column) SetTypeIfUnknown(t ColumnType) { - if c.Type == UnknownColumnType { - c.Type = t - } -} - -// HasDefault returns true if the column has any default value (possibly NULL) -func (c *Column) HasDefault() bool { - if c.IsDefaultNull && !c.IsNullable { - // based on INFORMATION_SCHEMA.COLUMNS, this is the indicator for a 'NOT NULL' column with no default value. - return false - } - return true -} - -// IsNumeric returns true if the column is of a numeric type -func (c *Column) IsNumeric() bool { - return c.NumericPrecision > 0 -} - -// IsIntegralType returns true if the column is some form of an integer -func (c *Column) IsIntegralType() bool { - return schemadiff.IsIntegralType(c.DataType) -} - -// IsFloatingPoint returns true if the column is of a floating point numeric type -func (c *Column) IsFloatingPoint() bool { - return c.Type == FloatColumnType || c.Type == DoubleColumnType -} - -// IsTemporal returns true if the column is of a temporal type -func (c *Column) IsTemporal() bool { - return c.DateTimePrecision >= 0 -} - -// NewColumns creates a new column array from non-empty names -func NewColumns(names []string) []Column { - result := []Column{} - for _, name := range names { - if name == "" { - continue - } - result = append(result, Column{Name: name}) - } - return result -} - -// ParseColumns creates a new column array by parsing a comma-delimited names list -func ParseColumns(names string) []Column { - namesArray := strings.Split(names, ",") - return NewColumns(namesArray) -} - -// ColumnsMap maps a column name onto its ordinal position -type ColumnsMap map[string]int - -// NewEmptyColumnsMap creates an empty map -func NewEmptyColumnsMap() ColumnsMap { - columnsMap := make(map[string]int) - return ColumnsMap(columnsMap) -} - -// NewColumnsMap creates a column map based on an ordered list of columns -func NewColumnsMap(orderedColumns []Column) ColumnsMap { - columnsMap := NewEmptyColumnsMap() - for i, column := range orderedColumns { - columnsMap[column.Name] = i - } - return columnsMap -} - -// ColumnList makes for a named list of columns -type ColumnList struct { - columns []Column - Ordinals ColumnsMap -} - -// NewColumnList creates an object given an ordered list of column names -func NewColumnList(names []string) *ColumnList { - result := &ColumnList{ - columns: 
NewColumns(names), - } - result.Ordinals = NewColumnsMap(result.columns) - return result -} - -// ParseColumnList parses a comma-delimited list of column names -func ParseColumnList(names string) *ColumnList { - result := &ColumnList{ - columns: ParseColumns(names), - } - result.Ordinals = NewColumnsMap(result.columns) - return result -} - -// Columns returns the list of columns -func (l *ColumnList) Columns() []Column { - return l.columns -} - -// Names returns the list of column names -func (l *ColumnList) Names() []string { - names := make([]string, len(l.columns)) - for i := range l.columns { - names[i] = l.columns[i].Name - } - return names -} - -// GetColumn gets a column by name -func (l *ColumnList) GetColumn(columnName string) *Column { - if ordinal, ok := l.Ordinals[columnName]; ok { - return &l.columns[ordinal] - } - return nil -} - -// ColumnExists returns true if this column list has a column by a given name -func (l *ColumnList) ColumnExists(columnName string) bool { - _, ok := l.Ordinals[columnName] - return ok -} - -// String returns a comma-separated list of column names -func (l *ColumnList) String() string { - return strings.Join(l.Names(), ",") -} - -// Equals checks for complete (deep) identities of columns, in order. -func (l *ColumnList) Equals(other *ColumnList) bool { - // compare the column slices themselves, not the Columns() method values - return reflect.DeepEqual(l.columns, other.columns) -} - -// EqualsByNames checks if the names in this list equal the names of another list, in order. Type is ignored. -func (l *ColumnList) EqualsByNames(other *ColumnList) bool { - return reflect.DeepEqual(l.Names(), other.Names()) -} - -// IsSubsetOf returns 'true' when column names of this list are a subset of -// another list, in arbitrary order (order agnostic) -func (l *ColumnList) IsSubsetOf(other *ColumnList) bool { - for _, column := range l.columns { - if _, exists := other.Ordinals[column.Name]; !exists { - return false - } - } - return true -} - -// Difference returns a (new copy) subset of this column list, consisting of all -// columns NOT in the given list. 
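-// Columns are compared by name only; ordinal positions are ignored.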
-// The result is never nil, even if the difference is empty -func (l *ColumnList) Difference(other *ColumnList) (diff *ColumnList) { - names := []string{} - for _, column := range l.columns { - if !other.ColumnExists(column.Name) { - names = append(names, column.Name) - } - } - return NewColumnList(names) -} - -// Len returns the length of this list -func (l *ColumnList) Len() int { - return len(l.columns) -} - -// MappedNamesColumnList returns a column list based on this list, with names possibly mapped by the given map -func (l *ColumnList) MappedNamesColumnList(columnNamesMap map[string]string) *ColumnList { - names := l.Names() - for i := range names { - if mappedName, ok := columnNamesMap[names[i]]; ok { - names[i] = mappedName - } - } - return NewColumnList(names) -} - -// SetEnumToTextConversion tells this column list that an enum is converted to text -func (l *ColumnList) SetEnumToTextConversion(columnName string, enumValues string) { - l.GetColumn(columnName).EnumToTextConversion = true - l.GetColumn(columnName).EnumValues = enumValues -} - -// IsEnumToTextConversion tells whether an enum was converted to text -func (l *ColumnList) IsEnumToTextConversion(columnName string) bool { - return l.GetColumn(columnName).EnumToTextConversion -} - -// UniqueKey is the combination of a key's name and columns -type UniqueKey struct { - Name string - Columns ColumnList - HasNullable bool - HasSubpart bool - HasFloat bool - IsAutoIncrement bool -} - -// IsPrimary checks if this unique key is primary -func (k *UniqueKey) IsPrimary() bool { - return k.Name == "PRIMARY" -} - -// Len returns the number of columns covered by this key -func (k *UniqueKey) Len() int { - return k.Columns.Len() -} - -// String returns a visual representation of this key -func (k *UniqueKey) String() string { - description := k.Name - if k.IsAutoIncrement { - description = fmt.Sprintf("%s (auto_increment)", description) - } - return fmt.Sprintf("%s: %s; has nullable: %+v", description, k.Columns.Names(), k.HasNullable) -} diff --git a/go/vt/vttablet/onlineddl/vrepl/types_test.go b/go/vt/vttablet/onlineddl/vrepl/types_test.go deleted file mode 100644 index d146d286d3a..00000000000 --- a/go/vt/vttablet/onlineddl/vrepl/types_test.go +++ /dev/null @@ -1,214 +0,0 @@ -/* - Copyright 2016 GitHub Inc. - See https://github.com/github/gh-ost/blob/master/LICENSE -*/ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package vrepl - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestParseColumnList(t *testing.T) { - names := "id,category,max_len" - - columnList := ParseColumnList(names) - assert.Equal(t, columnList.Len(), 3) - assert.Equal(t, columnList.Names(), []string{"id", "category", "max_len"}) - assert.Equal(t, columnList.Ordinals["id"], 0) - assert.Equal(t, columnList.Ordinals["category"], 1) - assert.Equal(t, columnList.Ordinals["max_len"], 2) -} - -func TestGetColumn(t *testing.T) { - names := "id,category,max_len" - columnList := ParseColumnList(names) - { - column := columnList.GetColumn("category") - assert.NotNil(t, column) - assert.Equal(t, column.Name, "category") - } - { - column := columnList.GetColumn("no_such_column") - assert.True(t, column == nil) - } -} - -func TestIsSubsetOf(t *testing.T) { - tt := []struct { - columns1 *ColumnList - columns2 *ColumnList - expectSubset bool - }{ - { - columns1: ParseColumnList(""), - columns2: ParseColumnList("a,b,c"), - expectSubset: true, - }, - { - columns1: ParseColumnList("a,b,c"), - columns2: ParseColumnList("a,b,c"), - expectSubset: true, - }, - { - columns1: ParseColumnList("a,c"), - columns2: ParseColumnList("a,b,c"), - expectSubset: true, - }, - { - columns1: ParseColumnList("b,c"), - columns2: ParseColumnList("a,b,c"), - expectSubset: true, - }, - { - columns1: ParseColumnList("b"), - columns2: ParseColumnList("a,b,c"), - expectSubset: true, - }, - { - columns1: ParseColumnList(""), - columns2: ParseColumnList("a,b,c"), - expectSubset: true, - }, - { - columns1: ParseColumnList("a,d"), - columns2: ParseColumnList("a,b,c"), - expectSubset: false, - }, - { - columns1: ParseColumnList("a,b,c"), - columns2: ParseColumnList("a,b"), - expectSubset: false, - }, - { - columns1: ParseColumnList("a,b,c"), - columns2: ParseColumnList(""), - expectSubset: false, - }, - } - for _, tc := range tt { - name := fmt.Sprintf("%v:%v", tc.columns1.Names(), tc.columns2.Names()) - t.Run(name, func(t *testing.T) { - isSubset := tc.columns1.IsSubsetOf(tc.columns2) - assert.Equal(t, tc.expectSubset, isSubset) - }, - ) - } -} - -func TestDifference(t *testing.T) { - tt := []struct { - columns1 *ColumnList - columns2 *ColumnList - expect *ColumnList - }{ - { - columns1: ParseColumnList(""), - columns2: ParseColumnList("a,b,c"), - expect: ParseColumnList(""), - }, - { - columns1: ParseColumnList("a,b,c"), - columns2: ParseColumnList("a,b,c"), - expect: ParseColumnList(""), - }, - { - columns1: ParseColumnList("a,c"), - columns2: ParseColumnList("a,b,c"), - expect: ParseColumnList(""), - }, - { - columns1: ParseColumnList("b,c"), - columns2: ParseColumnList("a,b,c"), - expect: ParseColumnList(""), - }, - { - columns1: ParseColumnList("b"), - columns2: ParseColumnList("a,b,c"), - expect: ParseColumnList(""), - }, - { - columns1: ParseColumnList(""), - columns2: ParseColumnList("a,b,c"), - expect: ParseColumnList(""), - }, - { - columns1: ParseColumnList("a,d"), - columns2: ParseColumnList("a,b,c"), - expect: ParseColumnList("d"), - }, - { - columns1: ParseColumnList("a,b,c"), - columns2: ParseColumnList("a,b"), - expect: ParseColumnList("c"), - }, - { - columns1: ParseColumnList("a,b,c"), - columns2: ParseColumnList(""), - expect: ParseColumnList("a,b,c"), - }, - { - columns1: ParseColumnList("a,b,c"), - columns2: ParseColumnList("b,d,e"), - expect: ParseColumnList("a,c"), - }, - } - for _, tc := range tt { - name := fmt.Sprintf("%v:%v", tc.columns1.Names(), tc.columns2.Names()) - t.Run(name, func(t *testing.T) { - diff := 
tc.columns1.Difference(tc.columns2) - assert.Equal(t, tc.expect, diff) - }, - ) - } -} - -func TestMappedNamesColumnList(t *testing.T) { - tt := []struct { - columns *ColumnList - namesMap map[string]string - expected *ColumnList - }{ - { - columns: ParseColumnList("a,b,c"), - namesMap: map[string]string{}, - expected: ParseColumnList("a,b,c"), - }, - { - columns: ParseColumnList("a,b,c"), - namesMap: map[string]string{"x": "y"}, - expected: ParseColumnList("a,b,c"), - }, - { - columns: ParseColumnList("a,b,c"), - namesMap: map[string]string{"a": "x", "c": "y"}, - expected: ParseColumnList("x,b,y"), - }, - } - for _, tc := range tt { - name := fmt.Sprintf("%v:%v", tc.columns.Names(), tc.namesMap) - t.Run(name, func(t *testing.T) { - mappedNames := tc.columns.MappedNamesColumnList(tc.namesMap) - assert.Equal(t, tc.expected, mappedNames) - }, - ) - } -} diff --git a/go/vt/vttablet/onlineddl/vrepl/unique_key.go b/go/vt/vttablet/onlineddl/vrepl/unique_key.go deleted file mode 100644 index cc649b4ea37..00000000000 --- a/go/vt/vttablet/onlineddl/vrepl/unique_key.go +++ /dev/null @@ -1,184 +0,0 @@ -/* - Copyright 2016 GitHub Inc. - See https://github.com/github/gh-ost/blob/master/LICENSE -*/ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package vrepl - -import ( - "strings" -) - -// UniqueKeyValidForIteration returns 'false' if we should not use this unique key as the main -// iteration key in vreplication. -func UniqueKeyValidForIteration(uniqueKey *UniqueKey) bool { - if uniqueKey.HasNullable { - // NULLable columns in a unique key means the set of values is not really unique (two identical rows with NULLs are allowed). - // Thus, we cannot use this unique key for iteration. - return false - } - if uniqueKey.HasSubpart { - // vreplication does not fully support indexes on column prefixes such as: - // UNIQUE KEY `name_idx` (`name`(15)) - // "HasSubpart" means some column covered by the index has a key length spec. - return false - } - if uniqueKey.HasFloat { - // float & double data types are imprecise and we cannot use them while iterating unique keys - return false - } - return true // good to go! 
-} - -// GetSharedUniqueKeys returns the unique keys shared between the source & target tables -func GetSharedUniqueKeys(sourceUniqueKeys, targetUniqueKeys [](*UniqueKey), columnRenameMap map[string]string) (chosenSourceUniqueKey, chosenTargetUniqueKey *UniqueKey) { - type ukPair struct{ source, target *UniqueKey } - var sharedUKPairs []*ukPair - - for _, sourceUniqueKey := range sourceUniqueKeys { - if !UniqueKeyValidForIteration(sourceUniqueKey) { - continue - } - for _, targetUniqueKey := range targetUniqueKeys { - if !UniqueKeyValidForIteration(targetUniqueKey) { - continue - } - uniqueKeyMatches := func() bool { - // Compare two unique keys - if sourceUniqueKey.Columns.Len() != targetUniqueKey.Columns.Len() { - return false - } - // Expect same columns, same order, potentially column name mapping - sourceUniqueKeyNames := sourceUniqueKey.Columns.Names() - targetUniqueKeyNames := targetUniqueKey.Columns.Names() - for i := range sourceUniqueKeyNames { - sourceColumnName := sourceUniqueKeyNames[i] - targetColumnName := targetUniqueKeyNames[i] - mappedSourceColumnName := sourceColumnName - if mapped, ok := columnRenameMap[sourceColumnName]; ok { - mappedSourceColumnName = mapped - } - if !strings.EqualFold(mappedSourceColumnName, targetColumnName) { - return false - } - } - return true - } - if uniqueKeyMatches() { - sharedUKPairs = append(sharedUKPairs, &ukPair{source: sourceUniqueKey, target: targetUniqueKey}) - } - } - } - // Now that we know what the shared unique keys are, let's find the "best" shared one. - // Source and target unique keys can have different names, even though they cover the exact same - // columns in the same order. - for _, pair := range sharedUKPairs { - if pair.source.HasNullable { - continue - } - if pair.target.HasNullable { - continue - } - return pair.source, pair.target - } - return nil, nil -} - -// SourceUniqueKeyAsOrMoreConstrainedThanTarget returns 'true' when sourceUniqueKey is at least as constrained as targetUniqueKey. -// "More constrained" means the uniqueness constraint is "stronger". Thus, if sourceUniqueKey is as-or-more constrained than targetUniqueKey, then -// rows valid under sourceUniqueKey must also be valid in targetUniqueKey. The opposite is not necessarily so: rows that are valid in targetUniqueKey -// may cause a unique key violation under sourceUniqueKey -func SourceUniqueKeyAsOrMoreConstrainedThanTarget(sourceUniqueKey, targetUniqueKey *UniqueKey, columnRenameMap map[string]string) bool { - // Compare two unique keys - if sourceUniqueKey.Columns.Len() > targetUniqueKey.Columns.Len() { - // source can't be more constrained if it covers *more* columns - return false - } - // we know that len(sourceUniqueKeyNames) <= len(targetUniqueKeyNames) - sourceUniqueKeyNames := sourceUniqueKey.Columns.Names() - targetUniqueKeyNames := targetUniqueKey.Columns.Names() - // source is more constrained than target if every column in source is also in target; order is immaterial - for i := range sourceUniqueKeyNames { - sourceColumnName := sourceUniqueKeyNames[i] - mappedSourceColumnName := sourceColumnName - if mapped, ok := columnRenameMap[sourceColumnName]; ok { - mappedSourceColumnName = mapped - } - columnFoundInTarget := func() bool { - for _, targetColumnName := range targetUniqueKeyNames { - if strings.EqualFold(mappedSourceColumnName, targetColumnName) { - return true - } - } - return false - } - if !columnFoundInTarget() { - return false - } - } - return true -} - -// AddedUniqueKeys returns the unique key constraints added in target. 
This does not necessarily mean that the unique key itself is new, -// rather that there's a new, stricter constraint on a set of columns that didn't exist before. Example: -// -// before: unique key `my_key`(c1, c2, c3); after: unique key `my_key`(c1, c2) -// The constraint on (c1, c2) is new; `my_key` in the target table ("after") is considered a new key -// -// Order of columns is immaterial to the uniqueness of a column combination. -func AddedUniqueKeys(sourceUniqueKeys, targetUniqueKeys [](*UniqueKey), columnRenameMap map[string]string) (addedUKs [](*UniqueKey)) { - addedUKs = [](*UniqueKey){} - for _, targetUniqueKey := range targetUniqueKeys { - foundAsOrMoreConstrainingSourceKey := func() bool { - for _, sourceUniqueKey := range sourceUniqueKeys { - if SourceUniqueKeyAsOrMoreConstrainedThanTarget(sourceUniqueKey, targetUniqueKey, columnRenameMap) { - // target key does not add a new constraint - return true - } - } - return false - } - if !foundAsOrMoreConstrainingSourceKey() { - addedUKs = append(addedUKs, targetUniqueKey) - } - } - return addedUKs -} - -// RemovedUniqueKeys returns the list of unique key constraints _removed_ going from source to target. -func RemovedUniqueKeys(sourceUniqueKeys, targetUniqueKeys [](*UniqueKey), columnRenameMap map[string]string) (removedUKs [](*UniqueKey)) { - reverseColumnRenameMap := map[string]string{} - for k, v := range columnRenameMap { - reverseColumnRenameMap[v] = k - } - return AddedUniqueKeys(targetUniqueKeys, sourceUniqueKeys, reverseColumnRenameMap) -} - -// GetUniqueKeyCoveredByColumns returns the first unique key from the given list whose columns all appear -// in the given column list. -func GetUniqueKeyCoveredByColumns(uniqueKeys [](*UniqueKey), columns *ColumnList) (chosenUniqueKey *UniqueKey) { - for _, uniqueKey := range uniqueKeys { - if !UniqueKeyValidForIteration(uniqueKey) { - continue - } - if uniqueKey.Columns.IsSubsetOf(columns) { - return uniqueKey - } - } - return nil -} diff --git a/go/vt/vttablet/onlineddl/vrepl/unique_key_test.go b/go/vt/vttablet/onlineddl/vrepl/unique_key_test.go deleted file mode 100644 index 3364c55a308..00000000000 --- a/go/vt/vttablet/onlineddl/vrepl/unique_key_test.go +++ /dev/null @@ -1,666 +0,0 @@ -/* - Copyright 2016 GitHub Inc. - See https://github.com/github/gh-ost/blob/master/LICENSE -*/ -/* -Copyright 2021 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-*/ - -package vrepl - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -var ( - columns1 = ParseColumnList("c1") - columns12 = ParseColumnList("c1,c2") - columns123 = ParseColumnList("c1,c2,c3") - columns2 = ParseColumnList("c2") - columns21 = ParseColumnList("c2,c1") - columns12A = ParseColumnList("c1,c2,ca") -) - -func TestGetSharedUniqueKeys(t *testing.T) { - tt := []struct { - name string - sourceUKs, targetUKs [](*UniqueKey) - renameMap map[string]string - expectSourceUK, expectTargetUK *UniqueKey - }{ - { - name: "empty", - sourceUKs: []*UniqueKey{}, - targetUKs: []*UniqueKey{}, - renameMap: map[string]string{}, - expectSourceUK: nil, - expectTargetUK: nil, - }, - { - name: "half empty", - sourceUKs: []*UniqueKey{ - {Name: "PRIMARY", Columns: *columns1}, - }, - targetUKs: []*UniqueKey{}, - renameMap: map[string]string{}, - expectSourceUK: nil, - expectTargetUK: nil, - }, - { - name: "single identical", - sourceUKs: []*UniqueKey{ - {Name: "PRIMARY", Columns: *columns1}, - }, - targetUKs: []*UniqueKey{ - {Name: "PRIMARY", Columns: *columns1}, - }, - renameMap: map[string]string{}, - expectSourceUK: &UniqueKey{Name: "PRIMARY", Columns: *columns1}, - expectTargetUK: &UniqueKey{Name: "PRIMARY", Columns: *columns1}, - }, - { - name: "single identical non pk", - sourceUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - renameMap: map[string]string{}, - expectSourceUK: &UniqueKey{Name: "uidx", Columns: *columns1}, - expectTargetUK: &UniqueKey{Name: "uidx", Columns: *columns1}, - }, - { - name: "single identical, source is nullable", - sourceUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1, HasNullable: true}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - renameMap: map[string]string{}, - expectSourceUK: nil, - expectTargetUK: nil, - }, - { - name: "single identical, target is nullable", - sourceUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1, HasNullable: true}, - }, - renameMap: map[string]string{}, - expectSourceUK: nil, - expectTargetUK: nil, - }, - { - name: "single no shared", - sourceUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns12}, - }, - renameMap: map[string]string{}, - expectSourceUK: nil, - expectTargetUK: nil, - }, - { - name: "single no shared different order", - sourceUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns12}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns21}, - }, - renameMap: map[string]string{}, - expectSourceUK: nil, - expectTargetUK: nil, - }, - { - name: "single identical, source has FLOAT", - sourceUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1, HasFloat: true}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - renameMap: map[string]string{}, - expectSourceUK: nil, - expectTargetUK: nil, - }, - { - name: "exact match", - sourceUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - {Name: "uidx123", Columns: *columns123}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns12}, - {Name: "uidx123", Columns: *columns123}, - }, - renameMap: map[string]string{}, - expectSourceUK: &UniqueKey{Name: "uidx123", Columns: *columns123}, - expectTargetUK: &UniqueKey{Name: "uidx123", Columns: *columns123}, - }, - { - name: "exact match from multiple options", - sourceUKs: []*UniqueKey{ - {Name: "uidx", 
Columns: *columns1}, - {Name: "uidx123", Columns: *columns123}, - {Name: "uidx12", Columns: *columns12}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx12", Columns: *columns12}, - {Name: "uidx", Columns: *columns12}, - {Name: "uidx123", Columns: *columns123}, - }, - renameMap: map[string]string{}, - expectSourceUK: &UniqueKey{Name: "uidx123", Columns: *columns123}, - expectTargetUK: &UniqueKey{Name: "uidx123", Columns: *columns123}, - }, - { - name: "exact match from multiple options reorder", - sourceUKs: []*UniqueKey{ - {Name: "uidx12", Columns: *columns12}, - {Name: "uidx", Columns: *columns1}, - {Name: "uidx123", Columns: *columns123}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns21}, - {Name: "uidx123", Columns: *columns123}, - {Name: "uidx12", Columns: *columns12}, - }, - renameMap: map[string]string{}, - expectSourceUK: &UniqueKey{Name: "uidx12", Columns: *columns12}, - expectTargetUK: &UniqueKey{Name: "uidx12", Columns: *columns12}, - }, - { - name: "match different names", - sourceUKs: []*UniqueKey{ - {Name: "uidx1", Columns: *columns1}, - {Name: "uidx12", Columns: *columns12}, - {Name: "uidx123", Columns: *columns123}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx21", Columns: *columns21}, - {Name: "uidx123", Columns: *columns123}, - {Name: "uidxother", Columns: *columns12}, - }, - renameMap: map[string]string{}, - expectSourceUK: &UniqueKey{Name: "uidx12", Columns: *columns12}, - expectTargetUK: &UniqueKey{Name: "uidxother", Columns: *columns12}, - }, - { - name: "match different names, nullable", - sourceUKs: []*UniqueKey{ - {Name: "uidx1", Columns: *columns1}, - {Name: "uidx12", Columns: *columns12}, - {Name: "uidx123", Columns: *columns123}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx21", Columns: *columns21}, - {Name: "uidx123other", Columns: *columns123}, - {Name: "uidx12", Columns: *columns12, HasNullable: true}, - }, - renameMap: map[string]string{}, - expectSourceUK: &UniqueKey{Name: "uidx123", Columns: *columns123}, - expectTargetUK: &UniqueKey{Name: "uidx123other", Columns: *columns123}, - }, - { - name: "match different column names", - sourceUKs: []*UniqueKey{ - {Name: "uidx1", Columns: *columns1}, - {Name: "uidx12", Columns: *columns12}, - {Name: "uidx123", Columns: *columns123}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx21", Columns: *columns21}, - {Name: "uidx12A", Columns: *columns12A}, - }, - renameMap: map[string]string{"c3": "ca"}, - expectSourceUK: &UniqueKey{Name: "uidx123", Columns: *columns123}, - expectTargetUK: &UniqueKey{Name: "uidx12A", Columns: *columns12A}, - }, - { - // enforce mapping from c3 to ca; will not match c3<->c3 - name: "no match identical column names", - sourceUKs: []*UniqueKey{ - {Name: "uidx1", Columns: *columns1}, - {Name: "uidx12", Columns: *columns12}, - {Name: "uidx123", Columns: *columns123}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx21", Columns: *columns21}, - {Name: "uidx123", Columns: *columns123}, - }, - renameMap: map[string]string{"c3": "ca"}, - expectSourceUK: nil, - expectTargetUK: nil, - }, - { - name: "no match different column names", - sourceUKs: []*UniqueKey{ - {Name: "uidx1", Columns: *columns1}, - {Name: "uidx12", Columns: *columns12}, - {Name: "uidx123", Columns: *columns123}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx21", Columns: *columns21}, - {Name: "uidx12A", Columns: *columns12A}, - }, - renameMap: map[string]string{"c3": "cx"}, - expectSourceUK: nil, - expectTargetUK: nil, - }, - } - - for _, tc := range tt { - t.Run(tc.name, func(t *testing.T) { - sourceUK, targetUK := 
GetSharedUniqueKeys(tc.sourceUKs, tc.targetUKs, tc.renameMap) - assert.Equal(t, tc.expectSourceUK, sourceUK) - assert.Equal(t, tc.expectTargetUK, targetUK) - }) - } -} - -func TestAddedUniqueKeys(t *testing.T) { - emptyUniqueKeys := []*UniqueKey{} - tt := []struct { - name string - sourceUKs, targetUKs [](*UniqueKey) - renameMap map[string]string - expectAddedUKs [](*UniqueKey) - expectRemovedUKs [](*UniqueKey) - }{ - { - name: "empty", - sourceUKs: emptyUniqueKeys, - targetUKs: emptyUniqueKeys, - renameMap: map[string]string{}, - expectAddedUKs: emptyUniqueKeys, - expectRemovedUKs: emptyUniqueKeys, - }, - { - name: "UK removed", - sourceUKs: []*UniqueKey{ - {Name: "PRIMARY", Columns: *columns1}, - }, - targetUKs: emptyUniqueKeys, - renameMap: map[string]string{}, - expectAddedUKs: emptyUniqueKeys, - expectRemovedUKs: []*UniqueKey{ - {Name: "PRIMARY", Columns: *columns1}, - }, - }, - { - name: "NULLable UK removed", - sourceUKs: []*UniqueKey{ - {Name: "PRIMARY", Columns: *columns1, HasNullable: true}, - }, - targetUKs: emptyUniqueKeys, - renameMap: map[string]string{}, - expectAddedUKs: emptyUniqueKeys, - expectRemovedUKs: []*UniqueKey{ - {Name: "PRIMARY", Columns: *columns1, HasNullable: true}, - }, - }, - { - name: "UK added", - sourceUKs: emptyUniqueKeys, - targetUKs: []*UniqueKey{ - {Name: "PRIMARY", Columns: *columns1}, - }, - renameMap: map[string]string{}, - expectAddedUKs: []*UniqueKey{ - {Name: "PRIMARY", Columns: *columns1}, - }, - expectRemovedUKs: emptyUniqueKeys, - }, - { - name: "single identical", - sourceUKs: []*UniqueKey{ - {Name: "PRIMARY", Columns: *columns1}, - }, - targetUKs: []*UniqueKey{ - {Name: "PRIMARY", Columns: *columns1}, - }, - renameMap: map[string]string{}, - expectAddedUKs: emptyUniqueKeys, - expectRemovedUKs: emptyUniqueKeys, - }, - { - name: "single identical non pk", - sourceUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - renameMap: map[string]string{}, - expectAddedUKs: emptyUniqueKeys, - expectRemovedUKs: emptyUniqueKeys, - }, - { - name: "single identical, source is nullable", - sourceUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1, HasNullable: true}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - renameMap: map[string]string{}, - expectAddedUKs: emptyUniqueKeys, - expectRemovedUKs: emptyUniqueKeys, - }, - { - name: "single identical, target is nullable", - sourceUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1, HasNullable: true}, - }, - renameMap: map[string]string{}, - expectAddedUKs: emptyUniqueKeys, - expectRemovedUKs: emptyUniqueKeys, - }, - { - name: "expand columns: not considered added", - sourceUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns12}, - }, - renameMap: map[string]string{}, - expectAddedUKs: emptyUniqueKeys, - expectRemovedUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - }, - { - name: "expand columns, different order: not considered added", - sourceUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns21}, - }, - renameMap: map[string]string{}, - expectAddedUKs: emptyUniqueKeys, - expectRemovedUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - }, - { - name: "reduced columns: considered added", - sourceUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns12}, - }, 
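- // source key (c1,c2) is reduced to target key (c1): uniqueness over fewer columns is a stricter constraint, hence considered added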
- targetUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - renameMap: map[string]string{}, - expectAddedUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - expectRemovedUKs: emptyUniqueKeys, - }, - { - name: "reduced columns, multiple: considered added", - sourceUKs: []*UniqueKey{ - {Name: "uidx12", Columns: *columns12}, - {Name: "uidx123", Columns: *columns123}, - {Name: "uidx2", Columns: *columns2}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - renameMap: map[string]string{}, - expectAddedUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - expectRemovedUKs: []*UniqueKey{ - {Name: "uidx2", Columns: *columns2}, - }, - }, - { - name: "different order: not considered added", - sourceUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns12}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns21}, - }, - renameMap: map[string]string{}, - expectAddedUKs: emptyUniqueKeys, - expectRemovedUKs: emptyUniqueKeys, - }, - { - name: "no match, different columns", - sourceUKs: []*UniqueKey{ - {Name: "uidx1", Columns: *columns1}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx2", Columns: *columns2}, - }, - renameMap: map[string]string{}, - expectAddedUKs: []*UniqueKey{ - {Name: "uidx2", Columns: *columns2}, - }, - expectRemovedUKs: []*UniqueKey{ - {Name: "uidx1", Columns: *columns1}, - }, - }, - { - name: "one match, one expand", - sourceUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - {Name: "uidx123", Columns: *columns123}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns12}, - {Name: "uidx123", Columns: *columns123}, - }, - renameMap: map[string]string{}, - expectAddedUKs: emptyUniqueKeys, - expectRemovedUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - }, - { - name: "exact match from multiple options", - sourceUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - {Name: "uidx123", Columns: *columns123}, - {Name: "uidx12", Columns: *columns12}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx12", Columns: *columns12}, - {Name: "uidx", Columns: *columns12}, - {Name: "uidx123", Columns: *columns123}, - }, - renameMap: map[string]string{}, - expectAddedUKs: emptyUniqueKeys, - expectRemovedUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - }, - { - name: "exact match from multiple options reorder", - sourceUKs: []*UniqueKey{ - {Name: "uidx12", Columns: *columns12}, - {Name: "uidx", Columns: *columns1}, - {Name: "uidx123", Columns: *columns123}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns21}, - {Name: "uidx123", Columns: *columns123}, - {Name: "uidx12", Columns: *columns12}, - }, - renameMap: map[string]string{}, - expectAddedUKs: emptyUniqueKeys, - expectRemovedUKs: []*UniqueKey{ - {Name: "uidx", Columns: *columns1}, - }, - }, - { - name: "match different names", - sourceUKs: []*UniqueKey{ - {Name: "uidx1", Columns: *columns1}, - {Name: "uidx12", Columns: *columns12}, - {Name: "uidx123", Columns: *columns123}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx21", Columns: *columns21}, - {Name: "uidx123", Columns: *columns123}, - {Name: "uidxother", Columns: *columns12}, - }, - renameMap: map[string]string{}, - expectAddedUKs: emptyUniqueKeys, - expectRemovedUKs: []*UniqueKey{ - {Name: "uidx1", Columns: *columns1}, - }, - }, - { - name: "match different names, nullable", - sourceUKs: []*UniqueKey{ - {Name: "uidx1", Columns: *columns1}, - {Name: "uidx12", Columns: *columns12}, - {Name: "uidx123", Columns: *columns123}, - }, - targetUKs: []*UniqueKey{ 
- {Name: "uidx21", Columns: *columns21}, - {Name: "uidx123other", Columns: *columns123}, - {Name: "uidx12", Columns: *columns12, HasNullable: true}, - }, - renameMap: map[string]string{}, - expectAddedUKs: emptyUniqueKeys, - expectRemovedUKs: []*UniqueKey{ - {Name: "uidx1", Columns: *columns1}, - }, - }, - { - name: "match different column names, expand", - sourceUKs: []*UniqueKey{ - {Name: "uidx1", Columns: *columns1}, - {Name: "uidx12", Columns: *columns12}, - {Name: "uidx123", Columns: *columns123}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx21", Columns: *columns21}, - {Name: "uidx12A", Columns: *columns12A}, - }, - renameMap: map[string]string{"c3": "ca"}, - expectAddedUKs: emptyUniqueKeys, - expectRemovedUKs: []*UniqueKey{ - {Name: "uidx1", Columns: *columns1}, - }, - }, - { - name: "match different column names, no expand", - sourceUKs: []*UniqueKey{ - {Name: "uidx123", Columns: *columns123}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx12A", Columns: *columns12A}, - }, - renameMap: map[string]string{"c3": "ca"}, - expectAddedUKs: emptyUniqueKeys, - expectRemovedUKs: emptyUniqueKeys, - }, - { - // enforce mapping from c3 to ca; will not match c3<->c3 - name: "no match identical column names, expand", - sourceUKs: []*UniqueKey{ - {Name: "uidx1", Columns: *columns1}, - {Name: "uidx12", Columns: *columns12}, - {Name: "uidx123", Columns: *columns123}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx21", Columns: *columns21}, - {Name: "uidx123", Columns: *columns123}, - }, - renameMap: map[string]string{"c3": "ca"}, - // 123 expands 12, so even though 3 is mapped to A, 123 is still not more constrained. - expectAddedUKs: emptyUniqueKeys, - expectRemovedUKs: []*UniqueKey{ - {Name: "uidx1", Columns: *columns1}, - }, - }, - { - // enforce mapping from c3 to ca; will not match c3<->c3 - name: "no match identical column names, no expand", - sourceUKs: []*UniqueKey{ - {Name: "uidx123", Columns: *columns123}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx123", Columns: *columns123}, - }, - renameMap: map[string]string{"c3": "ca"}, - expectAddedUKs: []*UniqueKey{ - {Name: "uidx123", Columns: *columns123}, - }, - expectRemovedUKs: emptyUniqueKeys, - }, - { - name: "no match for different column names, expand", - sourceUKs: []*UniqueKey{ - {Name: "uidx1", Columns: *columns1}, - {Name: "uidx12", Columns: *columns12}, - {Name: "uidx123", Columns: *columns123}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx21", Columns: *columns21}, - {Name: "uidx12A", Columns: *columns12A}, - }, - renameMap: map[string]string{"c3": "cx"}, - // 123 expands 12, so even though 3 is mapped to x, 123 is still not more constrained. 
- expectAddedUKs: emptyUniqueKeys, - expectRemovedUKs: []*UniqueKey{ - {Name: "uidx1", Columns: *columns1}, - }, - }, - { - name: "no match for different column names, no expand", - sourceUKs: []*UniqueKey{ - {Name: "uidx123", Columns: *columns123}, - }, - targetUKs: []*UniqueKey{ - {Name: "uidx12A", Columns: *columns12A}, - }, - renameMap: map[string]string{"c3": "cx"}, - expectAddedUKs: []*UniqueKey{ - {Name: "uidx12A", Columns: *columns12A}, - }, - expectRemovedUKs: []*UniqueKey{ - {Name: "uidx123", Columns: *columns123}, - }, - }, - } - - for _, tc := range tt { - t.Run(tc.name, func(t *testing.T) { - addedUKs := AddedUniqueKeys(tc.sourceUKs, tc.targetUKs, tc.renameMap) - assert.Equal(t, tc.expectAddedUKs, addedUKs) - removedUKs := RemovedUniqueKeys(tc.sourceUKs, tc.targetUKs, tc.renameMap) - assert.Equal(t, tc.expectRemovedUKs, removedUKs) - }) - } -} diff --git a/go/vt/vttablet/onlineddl/vrepl_test.go b/go/vt/vttablet/onlineddl/vrepl_test.go index ddb723ed7b7..b9875c3f6d2 100644 --- a/go/vt/vttablet/onlineddl/vrepl_test.go +++ b/go/vt/vttablet/onlineddl/vrepl_test.go @@ -1,6 +1,238 @@ /* - Copyright 2016 GitHub Inc. - See https://github.com/github/gh-ost/blob/master/LICENSE +Copyright 2024 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ package onlineddl + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "vitess.io/vitess/go/vt/schemadiff" + "vitess.io/vitess/go/vt/vtenv" +) + +func TestRevertible(t *testing.T) { + + type revertibleTestCase struct { + name string + fromSchema string + toSchema string + // expectProblems bool + removedForeignKeyNames string + removedUniqueKeyNames string + droppedNoDefaultColumnNames string + expandedColumnNames string + } + + var testCases = []revertibleTestCase{ + { + name: "identical schemas", + fromSchema: `id int primary key, i1 int not null default 0`, + toSchema: `id int primary key, i2 int not null default 0`, + }, + { + name: "different schemas, nothing to note", + fromSchema: `id int primary key, i1 int not null default 0, unique key i1_uidx(i1)`, + toSchema: `id int primary key, i1 int not null default 0, i2 int not null default 0, unique key i1_uidx(i1)`, + }, + { + name: "removed non-nullable unique key", + fromSchema: `id int primary key, i1 int not null default 0, unique key i1_uidx(i1)`, + toSchema: `id int primary key, i2 int not null default 0`, + removedUniqueKeyNames: `i1_uidx`, + }, + { + name: "removed nullable unique key", + fromSchema: `id int primary key, i1 int default null, unique key i1_uidx(i1)`, + toSchema: `id int primary key, i2 int default null`, + removedUniqueKeyNames: `i1_uidx`, + }, + { + name: "expanding unique key removes unique constraint", + fromSchema: `id int primary key, i1 int default null, unique key i1_uidx(i1)`, + toSchema: `id int primary key, i1 int default null, unique key i1_uidx(i1, id)`, + removedUniqueKeyNames: `i1_uidx`, + }, + { + name: "expanding unique key prefix removes unique constraint", + fromSchema: `id int 
primary key, v varchar(100) default null, unique key v_uidx(v(20))`, + toSchema: `id int primary key, v varchar(100) default null, unique key v_uidx(v(21))`, + removedUniqueKeyNames: `v_uidx`, + }, + { + name: "reducing unique key does not remove unique constraint", + fromSchema: `id int primary key, i1 int default null, unique key i1_uidx(i1, id)`, + toSchema: `id int primary key, i1 int default null, unique key i1_uidx(i1)`, + removedUniqueKeyNames: ``, + }, + { + name: "reducing unique key does not remove unique constraint", + fromSchema: `id int primary key, v varchar(100) default null, unique key v_uidx(v(21))`, + toSchema: `id int primary key, v varchar(100) default null, unique key v_uidx(v(20))`, + }, + { + name: "removed foreign key", + fromSchema: "id int primary key, i int, constraint some_fk_1 foreign key (i) references parent (id) on delete cascade", + toSchema: "id int primary key, i int", + removedForeignKeyNames: "some_fk_1", + }, + + { + name: "renamed foreign key", + fromSchema: "id int primary key, i int, constraint f1 foreign key (i) references parent (id) on delete cascade", + toSchema: "id int primary key, i int, constraint f2 foreign key (i) references parent (id) on delete cascade", + }, + { + name: "remove column without default", + fromSchema: `id int primary key, i1 int not null, i2 int not null default 0, i3 int default null`, + toSchema: `id int primary key, i4 int not null default 0`, + droppedNoDefaultColumnNames: `i1`, + }, + { + name: "expanded: nullable", + fromSchema: `id int primary key, i1 int not null, i2 int default null`, + toSchema: `id int primary key, i1 int default null, i2 int not null`, + expandedColumnNames: `i1`, + }, + { + name: "expanded: longer text", + fromSchema: `id int primary key, i1 int default null, v1 varchar(40) not null, v2 varchar(5), v3 varchar(3)`, + toSchema: `id int primary key, i1 int not null, v1 varchar(100) not null, v2 char(3), v3 char(5)`, + expandedColumnNames: `v1,v3`, + }, + { + name: "expanded: int numeric precision and scale", + fromSchema: `id int primary key, i1 int, i2 tinyint, i3 mediumint, i4 bigint`, + toSchema: `id int primary key, i1 int, i2 mediumint, i3 int, i4 tinyint`, + expandedColumnNames: `i2,i3`, + }, + { + name: "expanded: floating point", + fromSchema: `id int primary key, i1 int, n2 bigint, n3 bigint, n4 float, n5 double`, + toSchema: `id int primary key, i1 int, n2 float, n3 double, n4 double, n5 float`, + expandedColumnNames: `n2,n3,n4`, + }, + { + name: "expanded: decimal numeric precision and scale", + fromSchema: `id int primary key, i1 int, d1 decimal(10,2), d2 decimal (10,2), d3 decimal (10,2)`, + toSchema: `id int primary key, i1 int, d1 decimal(11,2), d2 decimal (9,1), d3 decimal (10,3)`, + expandedColumnNames: `d1,d3`, + }, + { + name: "expanded: signed, unsigned", + fromSchema: `id int primary key, i1 bigint signed, i2 int unsigned, i3 bigint unsigned`, + toSchema: `id int primary key, i1 int signed, i2 int signed, i3 int signed`, + expandedColumnNames: `i2,i3`, + }, + { + name: "expanded: signed, unsigned: range", + fromSchema: `id int primary key, i1 int signed, i2 bigint signed, i3 int signed`, + toSchema: `id int primary key, i1 int unsigned, i2 int unsigned, i3 bigint unsigned`, + expandedColumnNames: `i1,i3`, + }, + { + name: "expanded: datetime precision", + fromSchema: `id int primary key, dt1 datetime, ts1 timestamp, ti1 time, dt2 datetime(3), dt3 datetime(6), ts2 timestamp(3)`, + toSchema: `id int primary key, dt1 datetime(3), ts1 timestamp(6), ti1 time(3), dt2 datetime(6), 
dt3 datetime(3), ts2 timestamp`, + expandedColumnNames: `dt1,ts1,ti1,dt2`, + }, + { + name: "expanded: strange data type changes", + fromSchema: `id int primary key, dt1 datetime, ts1 timestamp, i1 int, d1 date, e1 enum('a', 'b')`, + toSchema: `id int primary key, dt1 char(32), ts1 varchar(32), i1 tinytext, d1 char(2), e1 varchar(2)`, + expandedColumnNames: `dt1,ts1,i1,d1,e1`, + }, + { + name: "expanded: temporal types", + fromSchema: `id int primary key, t1 time, t2 timestamp, t3 date, t4 datetime, t5 time, t6 date`, + toSchema: `id int primary key, t1 datetime, t2 datetime, t3 timestamp, t4 timestamp, t5 timestamp, t6 datetime`, + expandedColumnNames: `t1,t2,t3,t5,t6`, + }, + { + name: "expanded: character sets", + fromSchema: `id int primary key, c1 char(3) charset utf8, c2 char(3) charset utf8mb4, c3 char(3) charset ascii, c4 char(3) charset utf8mb4, c5 char(3) charset utf8, c6 char(3) charset latin1`, + toSchema: `id int primary key, c1 char(3) charset utf8mb4, c2 char(3) charset utf8, c3 char(3) charset utf8, c4 char(3) charset ascii, c5 char(3) charset utf8, c6 char(3) charset utf8mb4`, + expandedColumnNames: `c1,c3,c6`, + }, + { + name: "expanded: enum", + fromSchema: `id int primary key, e1 enum('a', 'b'), e2 enum('a', 'b'), e3 enum('a', 'b'), e4 enum('a', 'b'), e5 enum('a', 'b'), e6 enum('a', 'b'), e7 enum('a', 'b'), e8 enum('a', 'b')`, + toSchema: `id int primary key, e1 enum('a', 'b'), e2 enum('a'), e3 enum('a', 'b', 'c'), e4 enum('a', 'x'), e5 enum('a', 'x', 'b'), e6 enum('b'), e7 varchar(1), e8 tinyint`, + expandedColumnNames: `e3,e4,e5,e6,e7,e8`, + }, + { + name: "expanded: set", + fromSchema: `id int primary key, e1 set('a', 'b'), e2 set('a', 'b'), e3 set('a', 'b'), e4 set('a', 'b'), e5 set('a', 'b'), e6 set('a', 'b'), e7 set('a', 'b'), e8 set('a', 'b')`, + toSchema: `id int primary key, e1 set('a', 'b'), e2 set('a'), e3 set('a', 'b', 'c'), e4 set('a', 'x'), e5 set('a', 'x', 'b'), e6 set('b'), e7 varchar(1), e8 tinyint`, + expandedColumnNames: `e3,e4,e5,e6,e7,e8`, + }, + } + + var ( + createTableWrapper = `CREATE TABLE t (%s)` + ) + + senv := schemadiff.NewTestEnv() + venv := vtenv.NewTestEnv() + diffHints := &schemadiff.DiffHints{} + for _, tcase := range testCases { + t.Run(tcase.name, func(t *testing.T) { + tcase.fromSchema = fmt.Sprintf(createTableWrapper, tcase.fromSchema) + sourceTableEntity, err := schemadiff.NewCreateTableEntityFromSQL(senv, tcase.fromSchema) + require.NoError(t, err) + + tcase.toSchema = fmt.Sprintf(createTableWrapper, tcase.toSchema) + targetTableEntity, err := schemadiff.NewCreateTableEntityFromSQL(senv, tcase.toSchema) + require.NoError(t, err) + + diff, err := sourceTableEntity.TableDiff(targetTableEntity, diffHints) + require.NoError(t, err) + + v, err := NewVRepl( + venv, + "7cee19dd_354b_11eb_82cd_f875a4d24e90", + "ks", + "0", + "mydb", + sourceTableEntity.CreateTable, + targetTableEntity.CreateTable, + diff.AlterTable(), + false, + ) + require.NoError(t, err) + + err = v.analyzeAlter() + require.NoError(t, err) + err = v.analyzeTables() + require.NoError(t, err) + + toStringSlice := func(s string) []string { + if s == "" { + return []string{} + } + return strings.Split(s, ",") + } + assert.Equal(t, toStringSlice(tcase.removedForeignKeyNames), v.analysis.RemovedForeignKeyNames) + assert.Equal(t, toStringSlice(tcase.removedUniqueKeyNames), v.analysis.RemovedUniqueKeys.Names()) + assert.Equal(t, toStringSlice(tcase.droppedNoDefaultColumnNames), v.analysis.DroppedNoDefaultColumns.Names()) + assert.Equal(t, 
toStringSlice(tcase.expandedColumnNames), v.analysis.ExpandedColumns.Names()) + }) + } +}
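
A note on the toStringSlice helper in the new test above: it special-cases the empty string because strings.Split("", ",") returns a one-element slice containing "", not an empty slice, while the test expects an empty []string{} whenever no names are listed. A minimal, self-contained sketch (a hypothetical standalone program, not part of this patch) illustrating the pitfall:

	package main

	import (
		"fmt"
		"strings"
	)

	func main() {
		// Splitting an empty string does NOT yield an empty slice:
		fmt.Println(len(strings.Split("", ","))) // 1, i.e. [""]

		// The test's guard converts "" into a genuinely empty slice,
		// so it compares equal to an empty list of analysis names:
		toStringSlice := func(s string) []string {
			if s == "" {
				return []string{}
			}
			return strings.Split(s, ",")
		}
		fmt.Println(len(toStringSlice("")))  // 0
		fmt.Println(toStringSlice("a,b,c")) // [a b c]
	}

Without the guard, a test case expecting no removed keys or expanded columns would compare [""] against [] and fail.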