Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Composer object store download retry #3140

Merged
merged 17 commits into from
Mar 22, 2024
30 changes: 20 additions & 10 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from composer.utils.misc import is_model_deepspeed, partial_format
from composer.utils.object_store import ObjectStore
from composer.utils.retrying import retry

if TYPE_CHECKING:
from composer.core import AlgorithmPass, State
Expand Down Expand Up @@ -188,6 +189,24 @@ def read_metadata(self) -> Metadata:
return super().read_metadata()


@retry(num_attempts=5)
def download_object_or_file(
    object_name: str,
    file_destination: Union[str, Path],
    object_store: Union[ObjectStore, LoggerDestination],
):
    """Download ``object_name`` from ``object_store`` to ``file_destination``.

    The backend may be either an ``ObjectStore`` or a ``LoggerDestination``;
    the two expose differently named download methods with different keyword
    arguments, so this helper dispatches on the backend's type. The whole call
    is wrapped in ``retry`` with ``num_attempts=5``.

    Args:
        object_name: Name of the remote object to download.
        file_destination: Local path the object should be written to.
        object_store: Backend to download from.
    """
    # Anything that is not an ObjectStore is handled through the
    # LoggerDestination interface, which is given the destination as a str.
    if not isinstance(object_store, ObjectStore):
        object_store.download_file(
            remote_file_name=object_name,
            destination=str(file_destination),
        )
        return

    # ObjectStore receives the destination unchanged (str or Path).
    object_store.download_object(
        object_name=object_name,
        filename=file_destination,
    )
bigning marked this conversation as resolved.
Show resolved Hide resolved


# A subclass of FileSystemReaderWithValidation that downloads files from the object store before reading them from the local filesystem.
class DistCPObjectStoreReader(FileSystemReaderWithValidation):

Expand Down Expand Up @@ -262,16 +281,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
if not is_downloaded and not os.path.exists(file_destination):
log.debug(f'Downloading {relative_file_path} to {file_destination}.')
object_name = str(Path(self.source_path) / Path(relative_file_path))
if isinstance(self.object_store, ObjectStore):
self.object_store.download_object(
object_name=object_name,
filename=file_destination,
)
else:
self.object_store.download_file(
remote_file_name=object_name,
destination=file_destination,
)
download_object_or_file(object_name, file_destination, self.object_store)
log.debug(f'Finished downloading {relative_file_path} to {file_destination}.')
except Exception as e:
# PyTorch will capture any exception of this function,
Expand Down
4 changes: 4 additions & 0 deletions composer/utils/retrying.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import collections.abc
import functools
import logging
import random
import time
from typing import Any, Callable, Sequence, Type, TypeVar, Union, cast, overload
Expand All @@ -15,6 +16,8 @@

__all__ = ['retry']

log = logging.getLogger(__name__)


@overload
def retry(
Expand Down Expand Up @@ -86,6 +89,7 @@ def new_func(*args: Any, **kwargs: Any):
try:
return func(*args, **kwargs)
except exc_class as e:
log.debug(f'Attempt {i} failed. Exception type: {type(e)}, message: {str(e)}.')
if i + 1 == num_attempts:
raise e
else:
Expand Down
Loading