Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor maybe_create_object_store_from_uri #3679

Merged
merged 7 commits into from
Oct 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 35 additions & 23 deletions composer/utils/file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import tempfile
import uuid
import warnings
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from urllib.parse import urlparse

import requests
Expand Down Expand Up @@ -393,13 +393,31 @@ def parse_uri(uri: str) -> tuple[str, str, str]:
return backend, bucket_name, path.lstrip('/')


# Registry of URI backend names to callables that build the matching ObjectStore.
# Each factory is invoked as ``factory(bucket, path)``; ``path`` is accepted so all
# factories share one signature, but none of the factories registered here use it.
BACKEND_TO_OBJECT_STORE_FACTORY: dict[str, Callable[[str, str], ObjectStore]] = {
    's3': lambda bucket, path: S3ObjectStore(bucket=bucket),
    'gs': lambda bucket, path: GCSObjectStore(bucket=bucket),
    'oci': lambda bucket, path: OCIObjectStore(bucket=bucket),
    # Azure goes through libcloud; credentials are looked up via the named environment variables.
    'azure': lambda bucket, path: LibcloudObjectStore(
        provider='AZURE_BLOBS',
        container=bucket,
        key_environ='AZURE_ACCOUNT_NAME',
        secret_environ='AZURE_ACCOUNT_ACCESS_KEY',
    ),
}


def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
"""Automatically creates an :class:`composer.utils.ObjectStore` from supported URI formats.

Currently supported backends are ``s3://``, ``gs://``, ``oci://``, ``azure://``, ``dbfs://``, and local paths (in which case ``None`` will be returned)

Args:
uri (str): The path to (maybe) create an :class:`composer.utils.ObjectStore` from
uri (str): The path to (maybe) create an :class:`composer.utils.ObjectStore` from.

Raises:
NotImplementedError: Raises when the URI format is not supported.
Expand All @@ -408,25 +426,15 @@ def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
Optional[ObjectStore]: Returns an :class:`composer.utils.ObjectStore` if the URI is of a supported format, or None if the URI is a local path
"""
backend, bucket_name, path = parse_uri(uri)

# If backend is empty, assume local path and return None
if backend == '':
return None
if backend == 's3':
return S3ObjectStore(bucket=bucket_name)

# Handle special cases like WandB, MLFlow, etc.
elif backend == 'wandb':
raise NotImplementedError(
f'There is no implementation for WandB load_object_store via URI. Please use '
'WandBLogger',
)
elif backend == 'gs':
return GCSObjectStore(bucket=bucket_name)
elif backend == 'oci':
return OCIObjectStore(bucket=bucket_name)
elif backend == 'azure':
return LibcloudObjectStore(
provider='AZURE_BLOBS',
container=bucket_name,
key_environ='AZURE_ACCOUNT_NAME',
secret_environ='AZURE_ACCOUNT_ACCESS_KEY',
f'There is no implementation for WandB load_object_store via URI. Please use WandBLogger',
)
elif backend == 'dbfs':
if path.startswith(MLFLOW_DBFS_PATH_PREFIX):
Expand All @@ -445,17 +453,21 @@ def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
# Create the object store for all other ranks
if dist.get_global_rank() != 0:
store = MLFlowObjectStore(path)

return store
else:
# validate if the path conforms to the requirements for UC volume paths
UCObjectStore.validate_path(path)
return UCObjectStore(path=path)
else:
raise NotImplementedError(
f'There is no implementation for the cloud backend {backend} via URI. Please use '
'one of the supported object stores',
)

# Check if backend is registered
elif backend in BACKEND_TO_OBJECT_STORE_FACTORY:
return BACKEND_TO_OBJECT_STORE_FACTORY[backend](bucket_name, path)

# If backend is unknown, raise NotImplementedError
raise NotImplementedError(
f'There is no implementation for the cloud backend {backend} via URI. Please use '
'one of the supported object stores',
)


def maybe_create_remote_uploader_downloader_from_uri(
Expand Down
Loading