diff --git a/go/test/endtoend/utils/mysql.go b/go/test/endtoend/utils/mysql.go index a522af2472e..e6eb693acab 100644 --- a/go/test/endtoend/utils/mysql.go +++ b/go/test/endtoend/utils/mysql.go @@ -22,6 +22,8 @@ import ( "fmt" "os" "path" + "regexp" + "strconv" "time" "github.com/stretchr/testify/assert" @@ -240,12 +242,44 @@ func compareVitessAndMySQLResults(t TestingT, query string, vtConn *mysql.Conn, return errors.New(errStr) } +// Parse the string representation of a type (i.e. "INT64") into a three elements slice. +// First element of the slice will contain the full expression, second element contains the +// type "INT" and the third element contains the size if there is any "64" or empty if we use +// "TIMESTAMP" for instance. +var checkFieldsRegExpr = regexp.MustCompile(`([a-zA-Z]*)(\d*)`) + func checkFields(t TestingT, columnName string, vtField, myField *querypb.Field) { t.Helper() - if vtField.Type != myField.Type { + + fail := func() { t.Errorf("for column %s field types do not match\nNot equal: \nMySQL: %v\nVitess: %v\n", columnName, myField.Type.String(), vtField.Type.String()) } + if vtField.Type != myField.Type { + vtMatches := checkFieldsRegExpr.FindStringSubmatch(vtField.Type.String()) + myMatches := checkFieldsRegExpr.FindStringSubmatch(myField.Type.String()) + + // Here we want to fail if we have totally different types for instance: "INT64" vs "TIMESTAMP" + // We do this by checking the length of the regexp slices and checking the second item of the slices (the real type i.e. "INT") + if len(vtMatches) != 3 || len(vtMatches) != len(myMatches) || vtMatches[1] != myMatches[1] { + fail() + return + } + vtVal, vtErr := strconv.Atoi(vtMatches[2]) + myVal, myErr := strconv.Atoi(myMatches[2]) + if vtErr != nil || myErr != nil { + fail() + return + } + + // Types the same now, however, if the size of the type is smaller on Vitess compared to MySQL + // we need to fail. We can allow superset but not the opposite. + if vtVal < myVal { + fail() + return + } + } + // starting in Vitess 20, decimal types are properly sized in their field information if BinaryIsAtLeastAtVersion(20, "vtgate") && vtField.Type == sqltypes.Decimal { if vtField.Decimals != myField.Decimals { diff --git a/go/test/endtoend/utils/mysql_test.go b/go/test/endtoend/utils/mysql_test.go index 59a5ea255ef..c29f4cef5b7 100644 --- a/go/test/endtoend/utils/mysql_test.go +++ b/go/test/endtoend/utils/mysql_test.go @@ -30,6 +30,7 @@ import ( "vitess.io/vitess/go/mysql/replication" "vitess.io/vitess/go/test/endtoend/cluster" "vitess.io/vitess/go/vt/mysqlctl" + querypb "vitess.io/vitess/go/vt/proto/query" ) var ( @@ -67,6 +68,47 @@ func TestMain(m *testing.M) { os.Exit(exitCode) } +func TestCheckFields(t *testing.T) { + createField := func(typ querypb.Type) *querypb.Field { + return &querypb.Field{ + Type: typ, + } + } + + cases := []struct { + fail bool + vtField querypb.Type + myField querypb.Type + }{ + { + vtField: querypb.Type_INT32, + myField: querypb.Type_INT32, + }, + { + vtField: querypb.Type_INT64, + myField: querypb.Type_INT32, + }, + { + fail: true, + vtField: querypb.Type_FLOAT32, + myField: querypb.Type_INT32, + }, + { + fail: true, + vtField: querypb.Type_TIMESTAMP, + myField: querypb.Type_TUPLE, + }, + } + + for _, c := range cases { + t.Run(fmt.Sprintf("%s_%s", c.vtField.String(), c.myField.String()), func(t *testing.T) { + tt := &testing.T{} + checkFields(tt, "col", createField(c.vtField), createField(c.myField)) + require.Equal(t, c.fail, tt.Failed()) + }) + } +} + func TestCreateMySQL(t *testing.T) { ctx := context.Background() conn, err := mysql.Connect(ctx, &mysqlParams)