Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor maybe_create_object_store_from_uri #3679

Merged
merged 7 commits into from
Oct 24, 2024
Merged
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 33 additions & 33 deletions composer/utils/file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import tempfile
import uuid
import warnings
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from urllib.parse import urlparse

import requests
Expand Down Expand Up @@ -393,13 +393,29 @@ def parse_uri(uri: str) -> tuple[str, str, str]:
return backend, bucket_name, path.lstrip('/')


def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
"""Automatically creates an :class:`composer.utils.ObjectStore` from supported URI formats.
# Dictionary mapping backend names to ObjectStore factory functions
BACKEND_TO_OBJECT_STORE_FACTORY: dict[str, Callable[[str, str], ObjectStore]] = {
's3':
lambda bucket, path: S3ObjectStore(bucket=bucket),
'gs':
lambda bucket, path: GCSObjectStore(bucket=bucket),
'oci':
lambda bucket, path: OCIObjectStore(bucket=bucket),
'azure':
lambda bucket, path: LibcloudObjectStore(
provider='AZURE_BLOBS',
container=bucket,
key_environ='AZURE_ACCOUNT_NAME',
secret_environ='AZURE_ACCOUNT_ACCESS_KEY',
),
}

Currently supported backends are ``s3://``, ``oci://``, and local paths (in which case ``None`` will be returned)

def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
"""Automatically creates an ObjectStore from supported URI formats.

Args:
uri (str): The path to (maybe) create an :class:`composer.utils.ObjectStore` from
uri (str): The path to (maybe) create an :class:`composer.utils.ObjectStore` from.

Raises:
NotImplementedError: Raises when the URI format is not supported.
Expand All @@ -408,54 +424,38 @@ def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
Optional[ObjectStore]: Returns an :class:`composer.utils.ObjectStore` if the URI is of a supported format, otherwise None
"""
backend, bucket_name, path = parse_uri(uri)

# If backend is empty, assume local path and return None
if backend == '':
return None
if backend == 's3':
return S3ObjectStore(bucket=bucket_name)

# Handle special cases like WandB, MLFlow, etc.
elif backend == 'wandb':
raise NotImplementedError(
f'There is no implementation for WandB load_object_store via URI. Please use '
'WandBLogger',
)
elif backend == 'gs':
return GCSObjectStore(bucket=bucket_name)
elif backend == 'oci':
return OCIObjectStore(bucket=bucket_name)
elif backend == 'azure':
return LibcloudObjectStore(
provider='AZURE_BLOBS',
container=bucket_name,
key_environ='AZURE_ACCOUNT_NAME',
secret_environ='AZURE_ACCOUNT_ACCESS_KEY',
f'There is no implementation for WandB load_object_store via URI. Please use WandBLogger',
)
elif backend == 'dbfs':
if path.startswith(MLFLOW_DBFS_PATH_PREFIX):
store = None
if dist.get_global_rank() == 0:
store = MLFlowObjectStore(path)

# The path may have had placeholders, so update it with the experiment/run IDs initialized by the store
irenedea marked this conversation as resolved.
Show resolved Hide resolved
path = store.get_dbfs_path(path)

# Broadcast the rank 0 updated path to all ranks for their own object stores
path_list = [path]
dist.broadcast_object_list(path_list, src=0)
path = path_list[0]

# Create the object store for all other ranks
if dist.get_global_rank() != 0:
store = MLFlowObjectStore(path)

return store
else:
# validate if the path conforms to the requirements for UC volume paths
UCObjectStore.validate_path(path)
return UCObjectStore(path=path)
else:
raise NotImplementedError(
f'There is no implementation for the cloud backend {backend} via URI. Please use '
'one of the supported object stores',
)

# Check if backend is registered
elif backend in BACKEND_TO_OBJECT_STORE_FACTORY:
return BACKEND_TO_OBJECT_STORE_FACTORY[backend](bucket_name, path)

# If backend is unknown, raise NotImplementedError
raise NotImplementedError(f'There is no implementation for the cloud backend {backend} via URI.')


def maybe_create_remote_uploader_downloader_from_uri(
Expand Down
Loading