Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor maybe_create_object_store_from_uri #3679

Merged
merged 7 commits into from
Oct 24, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 46 additions & 34 deletions composer/utils/file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import tempfile
import uuid
import warnings
from typing import TYPE_CHECKING, Any, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
from urllib.parse import urlparse

import requests
Expand Down Expand Up @@ -393,13 +393,41 @@ def parse_uri(uri: str) -> tuple[str, str, str]:
return backend, bucket_name, path.lstrip('/')


def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
"""Automatically creates an :class:`composer.utils.ObjectStore` from supported URI formats.
# Registry for object store creation functions.
# Maps a backend name (e.g. 's3', 'oci') to a factory that is called with
# ``(bucket_name, path)`` and returns the ObjectStore for that backend.
object_store_registry: dict[str, Callable[[str, str], ObjectStore]] = {}
irenedea marked this conversation as resolved.
Show resolved Hide resolved

Currently supported backends are ``s3://``, ``oci://``, and local paths (in which case ``None`` will be returned)

def register_object_store(backend: str, factory_func: Callable[[str, str], ObjectStore]) -> None:
    """Registers a new object store backend to the registry.

    Registering a backend name that is already present replaces the factory
    previously stored under that name.

    Args:
        backend (str): The backend name (e.g., 's3', 'oci').
        factory_func (Callable): A function that accepts bucket_name and path and returns an ObjectStore instance.
    """
    object_store_registry[backend] = factory_func


# Register default object stores.
# Every factory is invoked as ``factory(bucket, path)``; the built-in stores below
# are constructed from the bucket/container only, so ``path`` is accepted but unused.
register_object_store(backend='s3', factory_func=lambda bucket, path: S3ObjectStore(bucket=bucket))
register_object_store(backend='gs', factory_func=lambda bucket, path: GCSObjectStore(bucket=bucket))
register_object_store(backend='oci', factory_func=lambda bucket, path: OCIObjectStore(bucket=bucket))
register_object_store(
    backend='azure',
    factory_func=lambda bucket, path: LibcloudObjectStore(
        provider='AZURE_BLOBS',
        container=bucket,
        key_environ='AZURE_ACCOUNT_NAME',
        secret_environ='AZURE_ACCOUNT_ACCESS_KEY',
    ),
)


def maybe_create_object_store_from_uri(uri: str) -> Optional[ObjectStore]:
    """Automatically creates an ObjectStore from supported URI formats.

    Backends present in ``object_store_registry`` are built by their registered
    factory. ``dbfs`` URIs are handled separately: paths starting with
    ``MLFLOW_DBFS_PATH_PREFIX`` produce an ``MLFlowObjectStore``; every other
    ``dbfs`` path is validated and wrapped in a ``UCObjectStore``.

    Args:
        uri (str): The path to (maybe) create an ObjectStore from.

    Raises:
        NotImplementedError: Raises when the URI format is not supported.

    Returns:
        Optional[ObjectStore]: Returns an :class:`composer.utils.ObjectStore` if the URI is of a supported format,
            or ``None`` when the URI has no backend (i.e. it is treated as a local path).
    """
    backend, bucket_name, path = parse_uri(uri)

    # If backend is empty, assume local path and return None
    if backend == '':
        return None

    # Check if backend is registered
    if backend in object_store_registry:
        return object_store_registry[backend](bucket_name, path)

    # Handle special cases like WandB, MLFlow, etc.
    if backend == 'wandb':
        raise NotImplementedError(
            'There is no implementation for WandB load_object_store via URI. Please use WandBLogger',
        )
    elif backend == 'dbfs':
        if path.startswith(MLFLOW_DBFS_PATH_PREFIX):
            store = None
            if dist.get_global_rank() == 0:
                store = MLFlowObjectStore(path)

                # The path may have had placeholders, so update it with the experiment/run IDs initialized by the store
                path = store.get_dbfs_path(path)

            # Broadcast the rank 0 updated path to all ranks for their own object stores
            path_list = [path]
            dist.broadcast_object_list(path_list, src=0)
            path = path_list[0]

            # Create the object store for all other ranks
            if dist.get_global_rank() != 0:
                store = MLFlowObjectStore(path)

            return store
        else:
            # validate if the path conforms to the requirements for UC volume paths
            UCObjectStore.validate_path(path)
            return UCObjectStore(path=path)

    # If backend is unknown, raise NotImplementedError
    raise NotImplementedError(f'There is no implementation for the cloud backend {backend} via URI.')


def maybe_create_remote_uploader_downloader_from_uri(
Expand Down
Loading