diff --git a/composer/utils/file_helpers.py b/composer/utils/file_helpers.py index 2d737a7fce..2c8fa25591 100644 --- a/composer/utils/file_helpers.py +++ b/composer/utils/file_helpers.py @@ -12,7 +12,7 @@ import tempfile import uuid import warnings -from typing import TYPE_CHECKING, Any, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Optional, Union from urllib.parse import urlparse import requests @@ -393,13 +393,31 @@ def parse_uri(uri: str) -> tuple[str, str, str]: return backend, bucket_name, path.lstrip('/') +# Dictionary mapping backend names to ObjectStore factory functions +BACKEND_TO_OBJECT_STORE_FACTORY: dict[str, Callable[[str, str], ObjectStore]] = { + 's3': + lambda bucket, path: S3ObjectStore(bucket=bucket), + 'gs': + lambda bucket, path: GCSObjectStore(bucket=bucket), + 'oci': + lambda bucket, path: OCIObjectStore(bucket=bucket), + 'azure': + lambda bucket, path: LibcloudObjectStore( + provider='AZURE_BLOBS', + container=bucket, + key_environ='AZURE_ACCOUNT_NAME', + secret_environ='AZURE_ACCOUNT_ACCESS_KEY', + ), +} + + def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]: """Automatically creates an :class:`composer.utils.ObjectStore` from supported URI formats. Currently supported backends are ``s3://``, ``oci://``, and local paths (in which case ``None`` will be returned) Args: - uri (str): The path to (maybe) create an :class:`composer.utils.ObjectStore` from + uri (str): The path to (maybe) create an :class:`composer.utils.ObjectStore` from. Raises: NotImplementedError: Raises when the URI format is not supported. @@ -408,25 +426,15 @@ def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]: Optional[ObjectStore]: Returns an :class:`composer.utils.ObjectStore` if the URI is of a supported format, otherwise None """ backend, bucket_name, path = parse_uri(uri) + + # If backend is empty, assume local path and return None if backend == '': return None - if backend == 's3': - return S3ObjectStore(bucket=bucket_name) + + # Handle special cases like WandB, MLFlow, etc. elif backend == 'wandb': raise NotImplementedError( - f'There is no implementation for WandB load_object_store via URI. Please use ' - 'WandBLogger', - ) - elif backend == 'gs': - return GCSObjectStore(bucket=bucket_name) - elif backend == 'oci': - return OCIObjectStore(bucket=bucket_name) - elif backend == 'azure': - return LibcloudObjectStore( - provider='AZURE_BLOBS', - container=bucket_name, - key_environ='AZURE_ACCOUNT_NAME', - secret_environ='AZURE_ACCOUNT_ACCESS_KEY', + f'There is no implementation for WandB load_object_store via URI. Please use WandBLogger', ) elif backend == 'dbfs': if path.startswith(MLFLOW_DBFS_PATH_PREFIX): @@ -445,17 +453,21 @@ def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]: # Create the object store for all other ranks if dist.get_global_rank() != 0: store = MLFlowObjectStore(path) - return store else: # validate if the path conforms to the requirements for UC volume paths UCObjectStore.validate_path(path) return UCObjectStore(path=path) - else: - raise NotImplementedError( - f'There is no implementation for the cloud backend {backend} via URI. Please use ' - 'one of the supported object stores', - ) + + # Check if backend is registered + elif backend in BACKEND_TO_OBJECT_STORE_FACTORY: + return BACKEND_TO_OBJECT_STORE_FACTORY[backend](bucket_name, path) + + # If backend is unknown, raise NotImplementedError + raise NotImplementedError( + f'There is no implementation for the cloud backend {backend} via URI. Please use ' + 'one of the supported object stores', + ) def maybe_create_remote_uploader_downloader_from_uri(