From e7e9ec4297259ca28b8186efab6274f19753bc0d Mon Sep 17 00:00:00 2001 From: Hassan Shah Date: Tue, 29 Oct 2024 11:16:15 +0100 Subject: [PATCH] added mandatory name field to key constraint to be able to properly add and remove constraints using terraform --- catalog/resource_sql_table.go | 101 ++++++++++++++++++++++++-- catalog/resource_sql_table_test.go | 44 ++++++++++- internal/acceptance/sql_table_test.go | 4 + 3 files changed, 141 insertions(+), 8 deletions(-) diff --git a/catalog/resource_sql_table.go b/catalog/resource_sql_table.go index 89865c301..a9586f5d6 100644 --- a/catalog/resource_sql_table.go +++ b/catalog/resource_sql_table.go @@ -46,15 +46,20 @@ type SqlKeyConstraintInfo struct { } type SqlKeyConstraint interface { - getConstraint() string + getConstraintCreateTableStatement() string + getConstraintAlterTableCreateStatement() string + getConstraintAlterTableDropStatement() string + getConstraintName() string } type SqlPrimaryKeyConstraint struct { + Name string `json:"name"` PrimaryKey string `json:"primary_key"` Rely bool `json:"rely,omitempty" tf:"default:false"` } type SqlForeignKeyConstraint struct { + Name string `json:"name"` ReferencedKey string `json:"referenced_key"` ReferencedCatalog string `json:"referenced_catalog"` ReferencedSchema string `json:"referenced_schema"` @@ -62,17 +67,51 @@ type SqlForeignKeyConstraint struct { ReferencedForeignKey string `json:"referenced_foreign_key"` } -func (sqlKeyConstraint SqlPrimaryKeyConstraint) getConstraint() string { - var constraint = fmt.Sprintf("PRIMARY KEY (%s)", sqlKeyConstraint.PrimaryKey) +func (sqlKeyConstraint SqlPrimaryKeyConstraint) getConstraintName() string { + return fmt.Sprintf("`%s`", sqlKeyConstraint.Name) +} + +func (sqlKeyConstraint SqlForeignKeyConstraint) getConstraintName() string { + return fmt.Sprintf("`%s`", sqlKeyConstraint.Name) +} + +func (sqlKeyConstraint SqlPrimaryKeyConstraint) getConstraintCreateTableStatement() string { + var constraint = fmt.Sprintf( + "CONSTRAINT %s PRIMARY KEY (%s)", + sqlKeyConstraint.getConstraintName(), + sqlKeyConstraint.PrimaryKey) + if sqlKeyConstraint.Rely { + constraint += " RELY" + } + return constraint +} + +func (sqlKeyConstraint SqlForeignKeyConstraint) getConstraintCreateTableStatement() string { + return fmt.Sprintf( + "CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s.%s.%s(%s)", + sqlKeyConstraint.getConstraintName(), + sqlKeyConstraint.ReferencedKey, + sqlKeyConstraint.ReferencedCatalog, + sqlKeyConstraint.ReferencedSchema, + sqlKeyConstraint.ReferencedTable, + sqlKeyConstraint.ReferencedForeignKey) +} + +func (sqlKeyConstraint SqlPrimaryKeyConstraint) getConstraintAlterTableCreateStatement() string { + var constraint = fmt.Sprintf( + "ADD CONSTRAINT %s PRIMARY KEY (%s)", + sqlKeyConstraint.getConstraintName(), + sqlKeyConstraint.PrimaryKey) if sqlKeyConstraint.Rely { constraint += " RELY" } return constraint } -func (sqlKeyConstraint SqlForeignKeyConstraint) getConstraint() string { +func (sqlKeyConstraint SqlForeignKeyConstraint) getConstraintAlterTableCreateStatement() string { return fmt.Sprintf( - "FOREIGN KEY (%s) REFERENCES %s.%s.%s(%s)", + "ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s.%s.%s(%s)", + sqlKeyConstraint.getConstraintName(), sqlKeyConstraint.ReferencedKey, sqlKeyConstraint.ReferencedCatalog, sqlKeyConstraint.ReferencedSchema, @@ -80,8 +119,20 @@ func (sqlKeyConstraint SqlForeignKeyConstraint) getConstraint() string { sqlKeyConstraint.ReferencedForeignKey) } +func (sqlKeyConstraint SqlPrimaryKeyConstraint) getConstraintAlterTableDropStatement() string { + return fmt.Sprintf( + "DROP CONSTRAINT %s", + sqlKeyConstraint.getConstraintName()) +} + +func (sqlKeyConstraint SqlForeignKeyConstraint) getConstraintAlterTableDropStatement() string { + return fmt.Sprintf( + "DROP CONSTRAINT %s", + sqlKeyConstraint.getConstraintName()) +} + func (ti *SqlTableInfo) serializeSqlKeyConstraintInfo(keyConstraint SqlKeyConstraintInfo) string { - return keyConstraint.SqlKeyConstraint.getConstraint() + return keyConstraint.SqlKeyConstraint.getConstraintCreateTableStatement() } func (ti *SqlTableInfo) serializeSqlKeyConstraintInfos() string { @@ -407,6 +458,43 @@ func (ti *SqlTableInfo) getStatementsForColumnDiffs(oldti *SqlTableInfo, stateme return statements } +func (ti *SqlTableInfo) addOrRemoveKeyConstraintStatements( + oldti *SqlTableInfo, + statements []string, + typestring string) []string { + nameToOldKeyConstraint := make(map[string]SqlKeyConstraintInfo) + nameToNewKeyConstraint := make(map[string]SqlKeyConstraintInfo) + for _, kci := range oldti.KeyConstraintInfos { + nameToOldKeyConstraint[kci.SqlKeyConstraint.getConstraintName()] = kci + } + for _, newKci := range ti.KeyConstraintInfos { + nameToNewKeyConstraint[newKci.SqlKeyConstraint.getConstraintName()] = newKci + } + + removeKeyConstraintStatements := make([]string, 0) + + for name, oldKci := range nameToOldKeyConstraint { + if _, exists := nameToNewKeyConstraint[name]; !exists { + // Remove old column if old column is no longer found in the config. + var oldKciDropStatement = oldKci.SqlKeyConstraint.getConstraintAlterTableDropStatement() + removeKeyConstraintStatements = append(statements, fmt.Sprintf("ALTER %s %s %s", typestring, ti.SQLFullName(), oldKciDropStatement)) + } + } + if len(removeKeyConstraintStatements) > 0 { + removeKeyConstraintStatementsStr := strings.Join(removeKeyConstraintStatements, ", ") + statements = append(statements, removeKeyConstraintStatementsStr) + } + + for _, newKci := range ti.KeyConstraintInfos { + if _, exists := nameToOldKeyConstraint[newKci.SqlKeyConstraint.getConstraintName()]; !exists { + // Add new column if new column is detected. + newKciStatement := newKci.SqlKeyConstraint.getConstraintAlterTableCreateStatement() + statements = append(statements, fmt.Sprintf("ALTER %s %s %s", typestring, ti.SQLFullName(), newKciStatement)) + } + } + return statements +} + func (ti *SqlTableInfo) addOrRemoveColumnStatements(oldti *SqlTableInfo, statements []string, typestring string) []string { nameToOldColumn := make(map[string]SqlColumnInfo) nameToNewColumn := make(map[string]SqlColumnInfo) @@ -510,6 +598,7 @@ func (ti *SqlTableInfo) diff(oldti *SqlTableInfo) ([]string, error) { } statements = ti.getStatementsForColumnDiffs(oldti, statements, typestring) + statements = ti.addOrRemoveKeyConstraintStatements(oldti, statements, typestring) return statements, nil } diff --git a/catalog/resource_sql_table_test.go b/catalog/resource_sql_table_test.go index e6f4b0b1f..cea8679de 100644 --- a/catalog/resource_sql_table_test.go +++ b/catalog/resource_sql_table_test.go @@ -87,6 +87,7 @@ func TestResourceSqlTableCreateStatement_PrimaryKeyConstraint(t *testing.T) { KeyConstraintInfos: []SqlKeyConstraintInfo{ { SqlKeyConstraint: SqlPrimaryKeyConstraint{ + Name: "id_pk", PrimaryKey: "id", Rely: true, }, @@ -97,7 +98,45 @@ func TestResourceSqlTableCreateStatement_PrimaryKeyConstraint(t *testing.T) { assert.Contains(t, stmt, "CREATE EXTERNAL TABLE `main`.`foo`.`bar`") assert.Contains(t, stmt, "USING DELTA") assert.Contains(t, stmt, "(`id` NOT NULL, `name` NOT NULL COMMENT 'a comment')") - assert.Contains(t, stmt, "(PRIMARY KEY (id) RELY)") + assert.Contains(t, stmt, "(CONSTRAINT `id_pk` PRIMARY KEY (id) RELY)") + assert.Contains(t, stmt, "LOCATION 's3://ext-main/foo/bar1' WITH (CREDENTIAL `somecred`)") + assert.Contains(t, stmt, "COMMENT 'terraform managed'") +} + +func TestResourceSqlTableAlterStatement_DropPrimaryKeyConstraint(t *testing.T) { + ti := &SqlTableInfo{ + Name: "bar", + CatalogName: "main", + SchemaName: "foo", + TableType: "EXTERNAL", + DataSourceFormat: "DELTA", + StorageLocation: "s3://ext-main/foo/bar1", + StorageCredentialName: "somecred", + Comment: "terraform managed", + ColumnInfos: []SqlColumnInfo{ + { + Name: "id", + }, + { + Name: "name", + Comment: "a comment", + }, + }, + KeyConstraintInfos: []SqlKeyConstraintInfo{ + { + SqlKeyConstraint: SqlPrimaryKeyConstraint{ + Name: "id_pk", + PrimaryKey: "id", + Rely: true, + }, + }, + }, + } + stmt := ti.buildTableCreateStatement() + assert.Contains(t, stmt, "CREATE EXTERNAL TABLE `main`.`foo`.`bar`") + assert.Contains(t, stmt, "USING DELTA") + assert.Contains(t, stmt, "(`id` NOT NULL, `name` NOT NULL COMMENT 'a comment')") + assert.Contains(t, stmt, "(CONSTRAINT `id_pk` PRIMARY KEY (id) RELY)") assert.Contains(t, stmt, "LOCATION 's3://ext-main/foo/bar1' WITH (CREDENTIAL `somecred`)") assert.Contains(t, stmt, "COMMENT 'terraform managed'") } @@ -124,6 +163,7 @@ func TestResourceSqlTableCreateStatement_ForeignKeyConstraint(t *testing.T) { KeyConstraintInfos: []SqlKeyConstraintInfo{ { SqlKeyConstraint: SqlForeignKeyConstraint{ + Name: "id_fk", ReferencedKey: "id", ReferencedCatalog: "bronze", ReferencedSchema: "biz", @@ -137,7 +177,7 @@ func TestResourceSqlTableCreateStatement_ForeignKeyConstraint(t *testing.T) { assert.Contains(t, stmt, "CREATE EXTERNAL TABLE `main`.`foo`.`bar`") assert.Contains(t, stmt, "USING DELTA") assert.Contains(t, stmt, "(`id` NOT NULL, `name` NOT NULL COMMENT 'a comment')") - assert.Contains(t, stmt, "(FOREIGN KEY (id) REFERENCES bronze.biz.transactions(transactionId)") + assert.Contains(t, stmt, "(CONSTRAINT `id_fk` FOREIGN KEY (id) REFERENCES bronze.biz.transactions(transactionId)") assert.Contains(t, stmt, "LOCATION 's3://ext-main/foo/bar1' WITH (CREDENTIAL `somecred`)") assert.Contains(t, stmt, "COMMENT 'terraform managed'") } diff --git a/internal/acceptance/sql_table_test.go b/internal/acceptance/sql_table_test.go index 69edccf6b..157b26c20 100644 --- a/internal/acceptance/sql_table_test.go +++ b/internal/acceptance/sql_table_test.go @@ -170,10 +170,12 @@ func TestUcAccResourceSqlTableWithPrimaryAndForeignKeyConstraints_Managed(t *tes type = "string" } key_constraint { + name = "table_pk" primary_key = "id" rely = "true" } key_constraint { + name = "external_id_fk" referenced_key = "external_id" referenced_catalog = "bronze" referenced_schema = "biz" @@ -213,10 +215,12 @@ func TestUcAccResourceSqlTableWithPrimaryAndForeignKeyConstraints_Managed(t *tes type = "string" } key_constraint { + name = "table_pk" primary_key = "id" rely = "true" } key_constraint { + name = "external_id_fk" referenced_key = "external_id" referenced_catalog = "bronze" referenced_schema = "biz"