diff --git a/brickflow/codegen/databricks_bundle.py b/brickflow/codegen/databricks_bundle.py index b3d4ac54..82e8a655 100644 --- a/brickflow/codegen/databricks_bundle.py +++ b/brickflow/codegen/databricks_bundle.py @@ -144,7 +144,7 @@ class ResourceReference(BaseModel): class ImportBlock(BaseModel): to: str - id_: str + id_: Union[str, int] class ResourceAlreadyUsedByOtherProjectError(Exception): @@ -346,12 +346,7 @@ def transform(self, mutators: List[DatabricksBundleResourceMutator]) -> Resource class ImportManager: @staticmethod - def create_import_tf(env: str, import_blocks: List[ImportBlock]) -> None: - file_path = f".databricks/bundle/{env}/terraform/extra_imports.tf" - # Ensure directory structure exists - directory = os.path.dirname(file_path) - os.makedirs(directory, exist_ok=True) - # Create file + def create_import_str(import_blocks: List[ImportBlock]) -> str: import_statements = [] for import_block in import_blocks: _ilog.info("Reusing import for %s - %s", import_block.to, import_block.id_) @@ -361,10 +356,20 @@ def create_import_tf(env: str, import_blocks: List[ImportBlock]) -> None: f' id = "{import_block.id_}" \n' f"}}" ) + return "\n\n".join(import_statements) + + @staticmethod + def create_import_tf(env: str, import_blocks: List[ImportBlock]) -> None: + file_path = f".databricks/bundle/{env}/terraform/extra_imports.tf" + # Ensure directory structure exists + directory = os.path.dirname(file_path) + os.makedirs(directory, exist_ok=True) + import_content = ImportManager.create_import_str(import_blocks) + # Create file with open(file_path, "w", encoding="utf-8") as f: f.truncate() f.flush() - f.write("\n\n".join(import_statements)) + f.write(import_content) class DatabricksBundleCodegen(CodegenInterface): diff --git a/tests/codegen/test_databricks_bundle.py b/tests/codegen/test_databricks_bundle.py index fd9bfa0f..92948e37 100644 --- a/tests/codegen/test_databricks_bundle.py +++ b/tests/codegen/test_databricks_bundle.py @@ -24,6 +24,7 @@ DatabricksBundleImportMutator, DatabricksBundleCodegen, ImportBlock, + ImportManager, ) from brickflow.engine.project import Stage, Project from brickflow.engine.task import NotebookTask @@ -410,3 +411,21 @@ def test_mutators(self): assert ( jobs is not None and jobs[job_name].name == f"{fake_user_name}_{job_name}" ) + + def test_import_blocks(self): + # Databricks object ids are either strings or integers + block1 = ImportBlock(to="test", id_=1) + block2 = ImportBlock(to="test_2", id_="test") + blocks = [block1, block2] + expected_output = """import { + to = test + id = "1" +} + +import { + to = test_2 + id = "test" +}""" + assert ( + ImportManager.create_import_str(blocks).strip() == expected_output.strip() + ), "Import blocks are not equal"