Skip to content

Commit

Permalink
Composer object store download retry (#3140)
Browse files Browse the repository at this point in the history
* retry

* up

* up

* up

* fix

* up

* a

* up

* up

* up

* up

* lint

* up

* up

* lint
  • Loading branch information
bigning authored Mar 22, 2024
1 parent f925ef0 commit 1740040
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
30 changes: 20 additions & 10 deletions composer/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
)
from composer.utils.misc import is_model_deepspeed, partial_format
from composer.utils.object_store import ObjectStore
from composer.utils.retrying import retry

if TYPE_CHECKING:
from composer.core import AlgorithmPass, State
Expand Down Expand Up @@ -188,6 +189,24 @@ def read_metadata(self) -> Metadata:
return super().read_metadata()


@retry(num_attempts=5)
def download_object_or_file(
object_name: str,
file_destination: Union[str, Path],
object_store: Union[ObjectStore, LoggerDestination],
):
if isinstance(object_store, ObjectStore):
object_store.download_object(
object_name=object_name,
filename=file_destination,
)
else:
object_store.download_file(
remote_file_name=object_name,
destination=str(file_destination),
)


# A subclass of FileSystemReaderWithValidation that downloads files from the object store before reading them from the local filesystem.
class DistCPObjectStoreReader(FileSystemReaderWithValidation):

Expand Down Expand Up @@ -262,16 +281,7 @@ def read_data(self, plan: LoadPlan, planner: LoadPlanner):
if not is_downloaded and not os.path.exists(file_destination):
log.debug(f'Downloading {relative_file_path} to {file_destination}.')
object_name = str(Path(self.source_path) / Path(relative_file_path))
if isinstance(self.object_store, ObjectStore):
self.object_store.download_object(
object_name=object_name,
filename=file_destination,
)
else:
self.object_store.download_file(
remote_file_name=object_name,
destination=file_destination,
)
download_object_or_file(object_name, file_destination, self.object_store)
log.debug(f'Finished downloading {relative_file_path} to {file_destination}.')
except Exception as e:
# PyTorch will capture any exception of this function,
Expand Down
4 changes: 4 additions & 0 deletions composer/utils/retrying.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import collections.abc
import functools
import logging
import random
import time
from typing import Any, Callable, Sequence, Type, TypeVar, Union, cast, overload
Expand All @@ -15,6 +16,8 @@

__all__ = ['retry']

log = logging.getLogger(__name__)


@overload
def retry(
Expand Down Expand Up @@ -86,6 +89,7 @@ def new_func(*args: Any, **kwargs: Any):
try:
return func(*args, **kwargs)
except exc_class as e:
log.debug(f'Attempt {i} failed. Exception type: {type(e)}, message: {str(e)}.')
if i + 1 == num_attempts:
raise e
else:
Expand Down

0 comments on commit 1740040

Please sign in to comment.