diff --git a/accessors/spanner/spannermetadataaccessor/clients/spanner_metadata_client.go b/accessors/spanner/spannermetadataaccessor/clients/spanner_metadata_client.go new file mode 100644 index 000000000..fbf5069d4 --- /dev/null +++ b/accessors/spanner/spannermetadataaccessor/clients/spanner_metadata_client.go @@ -0,0 +1,41 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package spannermetadataclient + +import ( + "context" + "fmt" + "sync" + + sp "cloud.google.com/go/spanner" +) + +var once sync.Once +var spannermetadataClient *sp.Client + +var newClient = sp.NewClient + +func GetOrCreateClient(ctx context.Context, dbURI string) (*sp.Client, error) { + var err error + if spannermetadataClient == nil { + once.Do(func() { + spannermetadataClient, err = newClient(ctx, dbURI) + }) + if err != nil { + return nil, fmt.Errorf("failed to create spanner metadata database client: %v", err) + } + return spannermetadataClient, nil + } + return spannermetadataClient, nil +} diff --git a/accessors/spanner/spannermetadataaccessor/clients/spanner_metadata_client_test.go b/accessors/spanner/spannermetadataaccessor/clients/spanner_metadata_client_test.go new file mode 100644 index 000000000..8129a641f --- /dev/null +++ b/accessors/spanner/spannermetadataaccessor/clients/spanner_metadata_client_test.go @@ -0,0 +1,116 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package spannermetadataclient + +import ( + "context" + "fmt" + "os" + "sync" + "testing" + + sp "cloud.google.com/go/spanner" + "github.com/GoogleCloudPlatform/spanner-migration-tool/logger" + "github.com/stretchr/testify/assert" + "go.uber.org/zap" + "google.golang.org/api/option" +) + +func init() { + logger.Log = zap.NewNop() +} + +func TestMain(m *testing.M) { + res := m.Run() + os.Exit(res) +} + +func resetTest() { + spannermetadataClient = nil + once = sync.Once{} +} + +func TestGetOrCreateClient_Basic(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return &sp.Client{}, nil + } + client, err := GetOrCreateClient(ctx, "testURI") + assert.NotNil(t, client) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaSync(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return &sp.Client{}, nil + } + client, err := GetOrCreateClient(ctx, "testURI") + assert.NotNil(t, client) + assert.Nil(t, err) + // Explicitly set the client to nil. Running GetOrCreateClient should not create a + // new client since sync would already be executed. + spannermetadataClient = nil + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return nil, fmt.Errorf("test error") + } + client, err = GetOrCreateClient(ctx, "testURI") + assert.Nil(t, client) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_OnlyOnceViaIf(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return &sp.Client{}, nil + } + oldC, err := GetOrCreateClient(ctx, "testURI") + assert.NotNil(t, oldC) + assert.Nil(t, err) + + // Explicitly reset once. Running GetOrCreateClient should not create a + // new client the if condition should prevent it. + once = sync.Once{} + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return nil, fmt.Errorf("test error") + } + newC, err := GetOrCreateClient(ctx, "testURI") + assert.Equal(t, oldC, newC) + assert.Nil(t, err) +} + +func TestGetOrCreateClient_Error(t *testing.T) { + resetTest() + ctx := context.Background() + oldFunc := newClient + defer func() { newClient = oldFunc }() + + newClient = func(ctx context.Context, database string, opts ...option.ClientOption) (*sp.Client, error) { + return nil, fmt.Errorf("test error") + } + client, err := GetOrCreateClient(ctx, "testURI") + assert.Nil(t, client) + assert.NotNil(t, err) +} diff --git a/accessors/spanner/spannermetadataaccessor/spanner_metadata_accessor.go b/accessors/spanner/spannermetadataaccessor/spanner_metadata_accessor.go new file mode 100644 index 000000000..b562ca057 --- /dev/null +++ b/accessors/spanner/spannermetadataaccessor/spanner_metadata_accessor.go @@ -0,0 +1,68 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package spannermetadataaccessor + +import ( + "context" + "fmt" + + "cloud.google.com/go/spanner" + spannermetadataclient "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner/spannermetadataaccessor/clients" + "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" + "google.golang.org/api/iterator" +) + +type SpannerMetadataAccessor interface { + // IsSpannerSupportedDefaultStatement checks if the given statement is supported by Spanner. + IsSpannerSupportedDefaultStatement(SpProjectId string, SpInstanceId string, statement string, coldatatype string) bool +} + +type SpannerMetadataAccessorImpl struct{} + +func (spm *SpannerMetadataAccessorImpl) IsSpannerSupportedDefaultStatement(SpProjectId string, SpInstanceId string, statement string, coldatatype string) bool { + db := getSpannerMetadataDbUri(SpProjectId, SpInstanceId) + if SpProjectId == "" || SpInstanceId == "" { + return false + } + + ctx := context.Background() + spmClient, err := spannermetadataclient.GetOrCreateClient(ctx, db) + if err != nil { + return false + } + + if spmClient == nil { + return false + } + stmt := spanner.Statement{ + SQL: "SELECT CAST(" + statement + " AS " + coldatatype + ") AS statementValue", + } + iter := spmClient.Single().Query(ctx, stmt) + defer iter.Stop() + for { + _, err := iter.Next() + if err == iterator.Done { + return true + } + if err != nil { + return false + } + + } + +} + +func getSpannerMetadataDbUri(projectId string, instanceId string) string { + return fmt.Sprintf("projects/%s/instances/%s/databases/%s", projectId, instanceId, constants.METADATA_DB) +} diff --git a/cmd/data.go b/cmd/data.go index 0e8f036dc..b881b6585 100644 --- a/cmd/data.go +++ b/cmd/data.go @@ -204,7 +204,7 @@ func (cmd *DataCmd) Execute(ctx context.Context, f *flag.FlagSet, _ ...interface } // validateExistingDb validates that the existing spanner schema is in accordance with the one specified in the session file. -func validateExistingDb(ctx context.Context, spDialect, dbURI string, adminClient *database.DatabaseAdminClient, client *sp.Client, conv *internal.Conv) error { +func validateExistingDb(SpProjectId string, SpInstanceId string, ctx context.Context, spDialect, dbURI string, adminClient *database.DatabaseAdminClient, client *sp.Client, conv *internal.Conv) error { adminClientImpl, err := spanneradmin.NewAdminClientImpl(ctx) if err != nil { return err @@ -230,6 +230,8 @@ func validateExistingDb(ctx context.Context, spDialect, dbURI string, adminClien } spannerConv := internal.MakeConv() spannerConv.SpDialect = spDialect + spannerConv.SpProjectId = SpProjectId + spannerConv.SpInstanceId = SpInstanceId err = utils.ReadSpannerSchema(ctx, spannerConv, client) if err != nil { err = fmt.Errorf("can't read spanner schema: %v", err) diff --git a/cmd/utils.go b/cmd/utils.go index 58e80afce..4b8e16442 100644 --- a/cmd/utils.go +++ b/cmd/utils.go @@ -183,7 +183,7 @@ func migrateData(ctx context.Context, migrationProjectId string, targetProfile p err error ) if !sourceProfile.UseTargetSchema() { - err = validateExistingDb(ctx, conv.SpDialect, dbURI, adminClient, client, conv) + err = validateExistingDb(conv.SpProjectId, conv.SpInstanceId, ctx, conv.SpDialect, dbURI, adminClient, client, conv) if err != nil { err = fmt.Errorf("error while validating existing database: %v", err) return nil, err diff --git a/conversion/conversion.go b/conversion/conversion.go index fb3459813..c2bdb6fa0 100644 --- a/conversion/conversion.go +++ b/conversion/conversion.go @@ -78,7 +78,7 @@ func (ci *ConvImpl) SchemaConv(migrationProjectId string, sourceProfile profiles case constants.POSTGRES, constants.MYSQL, constants.DYNAMODB, constants.SQLSERVER, constants.ORACLE: return schemaFromSource.schemaFromDatabase(migrationProjectId, sourceProfile, targetProfile, &GetInfoImpl{}, &common.ProcessSchemaImpl{}) case constants.PGDUMP, constants.MYSQLDUMP: - return schemaFromSource.SchemaFromDump(sourceProfile.Driver, targetProfile.Conn.Sp.Dialect, ioHelper, &ProcessDumpByDialectImpl{}) + return schemaFromSource.SchemaFromDump(targetProfile.Conn.Sp.Project, targetProfile.Conn.Sp.Instance, sourceProfile.Driver, targetProfile.Conn.Sp.Dialect, ioHelper, &ProcessDumpByDialectImpl{}) default: return nil, fmt.Errorf("schema conversion for driver %s not supported", sourceProfile.Driver) } diff --git a/conversion/conversion_from_source.go b/conversion/conversion_from_source.go index 612af033b..4e200dbe2 100644 --- a/conversion/conversion_from_source.go +++ b/conversion/conversion_from_source.go @@ -38,7 +38,7 @@ import ( type SchemaFromSourceInterface interface { schemaFromDatabase(migrationProjectId string, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, getInfo GetInfoInterface, processSchema common.ProcessSchemaInterface) (*internal.Conv, error) - SchemaFromDump(driver string, spDialect string, ioHelper *utils.IOStreams, processDump ProcessDumpByDialectInterface) (*internal.Conv, error) + SchemaFromDump(SpProjectId string, SpInstanceId string, driver string, spDialect string, ioHelper *utils.IOStreams, processDump ProcessDumpByDialectInterface) (*internal.Conv, error) } type SchemaFromSourceImpl struct{} @@ -54,6 +54,8 @@ type DataFromSourceImpl struct{} func (sads *SchemaFromSourceImpl) schemaFromDatabase(migrationProjectId string, sourceProfile profiles.SourceProfile, targetProfile profiles.TargetProfile, getInfo GetInfoInterface, processSchema common.ProcessSchemaInterface) (*internal.Conv, error) { conv := internal.MakeConv() conv.SpDialect = targetProfile.Conn.Sp.Dialect + conv.SpProjectId = targetProfile.Conn.Sp.Project + conv.SpInstanceId = targetProfile.Conn.Sp.Instance //handle fetching schema differently for sharded migrations, we only connect to the primary shard to //fetch the schema. We reuse the SourceProfileConnection object for this purpose. var infoSchema common.InfoSchema @@ -99,7 +101,7 @@ func (sads *SchemaFromSourceImpl) schemaFromDatabase(migrationProjectId string, return conv, processSchema.ProcessSchema(conv, infoSchema, common.DefaultWorkers, additionalSchemaAttributes, &common.SchemaToSpannerImpl{}, &common.UtilsOrderImpl{}, &common.InfoSchemaImpl{}) } -func (sads *SchemaFromSourceImpl) SchemaFromDump(driver string, spDialect string, ioHelper *utils.IOStreams, processDump ProcessDumpByDialectInterface) (*internal.Conv, error) { +func (sads *SchemaFromSourceImpl) SchemaFromDump(SpProjectId string, SpInstanceId string, driver string, spDialect string, ioHelper *utils.IOStreams, processDump ProcessDumpByDialectInterface) (*internal.Conv, error) { f, n, err := getSeekable(ioHelper.In) if err != nil { utils.PrintSeekError(driver, err, ioHelper.Out) @@ -109,6 +111,8 @@ func (sads *SchemaFromSourceImpl) SchemaFromDump(driver string, spDialect string ioHelper.BytesRead = n conv := internal.MakeConv() conv.SpDialect = spDialect + conv.SpProjectId = SpProjectId + conv.SpInstanceId = SpInstanceId p := internal.NewProgress(n, "Generating schema", internal.Verbose(), false, int(internal.SchemaCreationInProgress)) r := internal.NewReader(bufio.NewReader(f), p) conv.SetSchemaMode() // Build schema and ignore data in dump. @@ -159,6 +163,8 @@ func (sads *DataFromSourceImpl) dataFromCSV(ctx context.Context, sourceProfile p return nil, fmt.Errorf("dbName is mandatory in target-profile for csv source") } conv.SpDialect = targetProfile.Conn.Sp.Dialect + conv.SpProjectId = targetProfile.Conn.Sp.Project + conv.SpInstanceId = targetProfile.Conn.Sp.Instance dialect, err := targetProfile.FetchTargetDialect(ctx) if err != nil { return nil, fmt.Errorf("could not fetch dialect: %v", err) diff --git a/conversion/mocks.go b/conversion/mocks.go index 64164678b..cfea860f2 100644 --- a/conversion/mocks.go +++ b/conversion/mocks.go @@ -53,8 +53,8 @@ func (msads *MockSchemaFromSource) schemaFromDatabase(migrationProjectId string, args := msads.Called(migrationProjectId, sourceProfile, targetProfile, getInfo, processSchema) return args.Get(0).(*internal.Conv), args.Error(1) } -func (msads *MockSchemaFromSource) SchemaFromDump(driver string, spDialect string, ioHelper *utils.IOStreams, processDump ProcessDumpByDialectInterface) (*internal.Conv, error) { - args := msads.Called(driver, spDialect, ioHelper, processDump) +func (msads *MockSchemaFromSource) SchemaFromDump(SpProjectId string, SpInstanceId string, driver string, spDialect string, ioHelper *utils.IOStreams, processDump ProcessDumpByDialectInterface) (*internal.Conv, error) { + args := msads.Called(SpProjectId, SpInstanceId, driver, spDialect, ioHelper, processDump) return args.Get(0).(*internal.Conv), args.Error(1) } diff --git a/internal/convert.go b/internal/convert.go index b1cd1ba8d..a0086bc11 100644 --- a/internal/convert.go +++ b/internal/convert.go @@ -44,6 +44,8 @@ type Conv struct { Stats stats `json:"-"` TimezoneOffset string // Timezone offset for timestamp conversion. SpDialect string // The dialect of the spanner database to which Spanner migration tool is writing. + SpProjectId string // The projectId of the spanner database to which Spanner migration tool is writing. + SpInstanceId string // The instanceId of the spanner database to which Spanner migration tool is writing. UniquePKey map[string][]string // Maps Spanner table name to unique column name being used as primary key (if needed). Audit Audit `json:"-"` // Stores the audit information for the database conversion Rules []Rule // Stores applied rules during schema conversion diff --git a/sources/common/toddl.go b/sources/common/toddl.go index 16f4cf3ce..36105717f 100644 --- a/sources/common/toddl.go +++ b/sources/common/toddl.go @@ -35,6 +35,7 @@ import ( "strconv" "unicode" + spannermetadataaccessor "github.com/GoogleCloudPlatform/spanner-migration-tool/accessors/spanner/spannermetadataaccessor" "github.com/GoogleCloudPlatform/spanner-migration-tool/common/constants" "github.com/GoogleCloudPlatform/spanner-migration-tool/internal" "github.com/GoogleCloudPlatform/spanner-migration-tool/schema" @@ -146,13 +147,27 @@ func (ss *SchemaToSpannerImpl) SchemaToSpannerDDLHelper(conv *internal.Conv, tod columnLevelIssues[srcColId] = issues } + defaultVal := ddl.DefaultValue{ + IsPresent: false, + Value: "", + } + + if srcCol.DefaultValue.IsPresent { + spM := spannermetadataaccessor.SpannerMetadataAccessorImpl{} + defaultVal.IsPresent = spM.IsSpannerSupportedDefaultStatement(conv.SpProjectId, conv.SpInstanceId, srcCol.DefaultValue.Value, ty.Name) + } + if defaultVal.IsPresent { + defaultVal.Value = srcCol.DefaultValue.Value + } + spColDef[srcColId] = ddl.ColumnDef{ - Name: colName, - T: ty, - NotNull: isNotNull, - Comment: "From: " + quoteIfNeeded(srcCol.Name) + " " + srcCol.Type.Print(), - Id: srcColId, - AutoGen: *autoGenCol, + Name: colName, + T: ty, + NotNull: isNotNull, + Comment: "From: " + quoteIfNeeded(srcCol.Name) + " " + srcCol.Type.Print(), + Id: srcColId, + AutoGen: *autoGenCol, + DefaultValue: defaultVal, } if !checkIfColumnIsPartOfPK(srcColId, srcTable.PrimaryKeys) { totalNonKeyColumnSize += getColumnSize(ty.Name, ty.Len) diff --git a/spanner/ddl/ast.go b/spanner/ddl/ast.go index 9490a89e3..70226ffc5 100644 --- a/spanner/ddl/ast.go +++ b/spanner/ddl/ast.go @@ -180,12 +180,19 @@ func (ty Type) PGPrintColumnDefType() string { // column_def: // column_name type [NOT NULL] [options_def] type ColumnDef struct { - Name string - T Type - NotNull bool - Comment string - Id string - AutoGen AutoGenCol + Name string + T Type + NotNull bool + Comment string + Id string + AutoGen AutoGenCol + DefaultValue DefaultValue +} + +// DefaultValue represents Defaultvalue. +type DefaultValue struct { + IsPresent bool + Value string } // Config controls how AST nodes are printed (aka unparsed). @@ -247,6 +254,9 @@ func (cd ColumnDef) PrintColumnDef(c Config) (string, string) { if cd.NotNull { s += " NOT NULL " } + if cd.DefaultValue.IsPresent { + s += " DEFAULT (CAST(" + (cd.DefaultValue.Value) + " as " +(cd.T.Name)+ ")) " + } s += cd.AutoGen.PrintAutoGenCol() } return s, cd.Comment diff --git a/spanner/ddl/ast_test.go b/spanner/ddl/ast_test.go index d6e53f96b..bb805f039 100644 --- a/spanner/ddl/ast_test.go +++ b/spanner/ddl/ast_test.go @@ -78,6 +78,7 @@ func TestPrintColumnDef(t *testing.T) { {in: ColumnDef{Name: "col1", T: Type{Name: Int64, IsArray: true}}, expected: "col1 ARRAY"}, {in: ColumnDef{Name: "col1", T: Type{Name: Int64}, NotNull: true}, expected: "col1 INT64 NOT NULL "}, {in: ColumnDef{Name: "col1", T: Type{Name: Int64, IsArray: true}, NotNull: true}, expected: "col1 ARRAY NOT NULL "}, + {in: ColumnDef{Name: "col1", T: Type{Name: Int64}, DefaultValue: DefaultValue{IsPresent: true, Value: "4"}}, expected: "col1 INT64 DEFAULT (CAST(4 as INT64)) "}, {in: ColumnDef{Name: "col1", T: Type{Name: Int64}}, protectIds: true, expected: "`col1` INT64"}, } for _, tc := range tests { diff --git a/webv2/api/schema.go b/webv2/api/schema.go index f8b8ca3d3..ca27652ad 100644 --- a/webv2/api/schema.go +++ b/webv2/api/schema.go @@ -63,6 +63,8 @@ func ConvertSchemaSQL(w http.ResponseWriter, r *http.Request) { conv := internal.MakeConv() conv.SpDialect = sessionState.Dialect + conv.SpProjectId = sessionState.SpannerProjectId + conv.SpInstanceId = sessionState.SpannerInstanceID conv.IsSharded = sessionState.IsSharded var err error additionalSchemaAttributes := internal.AdditionalSchemaAttributes{ @@ -156,7 +158,10 @@ func ConvertSchemaDump(w http.ResponseWriter, r *http.Request) { sourceProfile, _ := profiles.NewSourceProfile("", dc.Config.Driver, &n) sourceProfile.Driver = dc.Config.Driver schemaFromSource := conversion.SchemaFromSourceImpl{} - conv, err := schemaFromSource.SchemaFromDump(sourceProfile.Driver, dc.SpannerDetails.Dialect, &utils.IOStreams{In: f, Out: os.Stdout}, &conversion.ProcessDumpByDialectImpl{}) + sessionState := session.GetSessionState() + SpProjectId := sessionState.SpannerProjectId + SpInstanceId := sessionState.SpannerInstanceID + conv, err := schemaFromSource.SchemaFromDump(SpProjectId, SpInstanceId, sourceProfile.Driver, dc.SpannerDetails.Dialect, &utils.IOStreams{In: f, Out: os.Stdout}, &conversion.ProcessDumpByDialectImpl{}) if err != nil { http.Error(w, fmt.Sprintf("Schema Conversion Error : %v", err), http.StatusNotFound) return @@ -169,7 +174,6 @@ func ConvertSchemaDump(w http.ResponseWriter, r *http.Request) { Dialect: dc.SpannerDetails.Dialect, } - sessionState := session.GetSessionState() sessionState.Conv.ConvLock.Lock() defer sessionState.Conv.ConvLock.Unlock() sessionState.Conv = conv