diff --git a/brickflow/cli/projects.py b/brickflow/cli/projects.py index 47efb59c..49da37c7 100644 --- a/brickflow/cli/projects.py +++ b/brickflow/cli/projects.py @@ -91,7 +91,7 @@ def __init__( file_type: ConfigFileType = ConfigFileType.YAML, ) -> None: self.file_type = file_type - self._config_file: Path = Path(f"{config_file_name}.{file_type.value}") + self._config_file: Path = Path(config_file_name) self._brickflow_multi_project_config: BrickflowMultiRootProjectConfig self._brickflow_multi_project_config = ( self._load_config() @@ -122,10 +122,14 @@ def _load_config(self) -> BrickflowMultiRootProjectConfig: return config return BrickflowMultiRootProjectConfig(project_roots={}) - def _root_config_path(self, root: str) -> Path: + def _root_config_path( + self, + root: str, + config_file_type: ConfigFileType = BrickflowProjectConstants.DEFAULT_CONFIG_FILE_TYPE, + ) -> Path: root_file = ( f"{BrickflowProjectConstants.DEFAULT_MULTI_PROJECT_ROOT_FILE_NAME.value}." - f"{BrickflowProjectConstants.DEFAULT_CONFIG_FILE_TYPE.value}" + f"{config_file_type.value}" ) return self._config_file.parent / root / root_file @@ -140,7 +144,9 @@ def _load_roots(self) -> Dict[str, BrickflowRootProjectConfig]: ) root_dict = {} for root in roots: - with self._root_config_path(root).open("r", encoding="utf-8") as f: + with self._root_config_path(root, config_file_type=self.file_type).open( + "r", encoding="utf-8" + ) as f: root_dict[root] = BrickflowRootProjectConfig.parse_obj( yaml.safe_load(f.read()) ) @@ -234,7 +240,7 @@ def get_brickflow_root(current_path: Optional[Path] = None) -> Path: current_dir = Path(current_path or get_notebook_ws_path(ctx.dbutils) or os.getcwd()) potential_config_files = [ - f"{BrickflowProjectConstants.DEFAULT_MULTI_PROJECT_ROOT_FILE_NAME.value}.{cfg_type.value}" + f"{BrickflowProjectConstants.DEFAULT_MULTI_PROJECT_CONFIG_FILE_NAME.value}.{cfg_type.value}" for cfg_type in ConfigFileType ] potential_config_file_paths = [current_dir / p for p in potential_config_files] @@ -254,9 +260,9 @@ def get_brickflow_root(current_path: Optional[Path] = None) -> Path: brickflow_root_path = get_brickflow_root() -config_file_type = get_config_file_type(str(brickflow_root_path)) +cfg_file_type = get_config_file_type(str(brickflow_root_path)) multi_project_manager = MultiProjectManager( - config_file_name=str(brickflow_root_path), file_type=config_file_type + config_file_name=str(brickflow_root_path), file_type=cfg_file_type ) diff --git a/tests/cli/sample_yaml_project/.brickflow-project-root.yaml b/tests/cli/sample_yaml_project/.brickflow-project-root.yaml new file mode 100644 index 00000000..6c4de5ff --- /dev/null +++ b/tests/cli/sample_yaml_project/.brickflow-project-root.yaml @@ -0,0 +1,9 @@ +version: v1 +projects: + test_cli_project: + name: test_cli_project + brickflow_version: 1.2.1 + deployment_mode: bundle + enable_plugins: false + path_from_repo_root_to_project_root: some/test/path + path_project_root_to_workflows_dir: path/to/workflows \ No newline at end of file diff --git a/tests/cli/sample_yaml_project/brickflow-multi-project.yaml b/tests/cli/sample_yaml_project/brickflow-multi-project.yaml new file mode 100644 index 00000000..db955471 --- /dev/null +++ b/tests/cli/sample_yaml_project/brickflow-multi-project.yaml @@ -0,0 +1,4 @@ +version: v1 +project_roots: + test_cli_project: + root_yaml_rel_path: . diff --git a/tests/cli/sample_yml_project/.brickflow-project-root.yml b/tests/cli/sample_yml_project/.brickflow-project-root.yml new file mode 100644 index 00000000..6c4de5ff --- /dev/null +++ b/tests/cli/sample_yml_project/.brickflow-project-root.yml @@ -0,0 +1,9 @@ +version: v1 +projects: + test_cli_project: + name: test_cli_project + brickflow_version: 1.2.1 + deployment_mode: bundle + enable_plugins: false + path_from_repo_root_to_project_root: some/test/path + path_project_root_to_workflows_dir: path/to/workflows \ No newline at end of file diff --git a/tests/cli/sample_yml_project/brickflow-multi-project.yml b/tests/cli/sample_yml_project/brickflow-multi-project.yml new file mode 100644 index 00000000..db955471 --- /dev/null +++ b/tests/cli/sample_yml_project/brickflow-multi-project.yml @@ -0,0 +1,4 @@ +version: v1 +project_roots: + test_cli_project: + root_yaml_rel_path: . diff --git a/tests/test_brickflow.py b/tests/test_brickflow.py index 8a9f7fc4..b8ce8a81 100644 --- a/tests/test_brickflow.py +++ b/tests/test_brickflow.py @@ -1,4 +1,5 @@ # pylint: disable=unused-import +from brickflow import get_config_file_type, ConfigFileType def test_imports(): @@ -38,3 +39,18 @@ def test_imports(): print("All imports Succeeded") except ImportError as e: print(f"Import failed: {e}") + + +def test_get_config_type_yaml(): + actual = get_config_file_type("some/brickflow/root/.brickflow-project-root.yaml") + assert actual == ConfigFileType.YAML + + +def test_get_config_type_yml(): + actual = get_config_file_type("some/brickflow/root/.brickflow-project-root.yml") + assert actual == ConfigFileType.YML + + +def test_get_config_type_default(): + actual = get_config_file_type("some/brickflow/root/.brickflow-project-root.json") + assert actual == ConfigFileType.YAML