diff --git a/composer/algorithms/seq_length_warmup/README.md b/composer/algorithms/seq_length_warmup/README.md
index e26724225c..84aee2c697 100644
--- a/composer/algorithms/seq_length_warmup/README.md
+++ b/composer/algorithms/seq_length_warmup/README.md
@@ -44,6 +44,9 @@ def training_loop(model, train_loader):
+
 ### Implementation Details
diff --git a/composer/callbacks/memory_monitor.py b/composer/callbacks/memory_monitor.py
index bd577c80d3..cc9341ef0d 100644
--- a/composer/callbacks/memory_monitor.py
+++ b/composer/callbacks/memory_monitor.py
@@ -90,7 +90,7 @@ def init(self, state: State, logger: Logger) -> None:
         # Not relying on `torch.cuda.is_available()` since the model could be on CPU.
         model_device = next(state.model.parameters()).device

-        if model_device.type != 'cuda':
+        if model_device.type not in ('cuda', 'meta'):
             warnings.warn(f'The memory monitor only works on CUDA devices, but the model is on {model_device.type}.')

     def after_train_batch(self, state: State, logger: Logger):
diff --git a/composer/callbacks/nan_monitor.py b/composer/callbacks/nan_monitor.py
index 16d3c47bf6..ce90cb147a 100644
--- a/composer/callbacks/nan_monitor.py
+++ b/composer/callbacks/nan_monitor.py
@@ -3,7 +3,7 @@

 """Callback for catching loss NaNs."""

-from typing import Sequence
+from typing import Dict, Sequence

 import torch

@@ -24,5 +24,9 @@ def after_loss(self, state: State, logger: Logger):
             for loss in state.loss:
                 if torch.isnan(loss).any():
                     raise RuntimeError('Train loss contains a NaN.')
+        elif isinstance(state.loss, Dict):
+            for k, v in state.loss.items():
+                if torch.isnan(v).any():
+                    raise RuntimeError(f'Train loss {k} contains a NaN.')
         else:
             raise TypeError(f'Loss is of type {type(state.loss)}, but should be a tensor or a sequence of tensors')
diff --git a/composer/core/state.py b/composer/core/state.py
index dbdba40170..1ba5a193db 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -349,7 +349,7 @@ class State(Serializable):
             before the dataloader is evaluated. The :attr:`~Timestamp.epoch` attribute for this timestamp is always
             ``0``.
         device_train_microbatch_size (int): The size of each train microbatch per device.
-        loss (torch.Tensor | Sequence[torch.Tensor]): The most recently computed loss.
+        loss (torch.Tensor | Sequence[torch.Tensor] | Dict[Any, torch.Tensor]): The most recently computed loss.
         model (torch.nn.Module): The training model.

             .. note::
@@ -547,7 +547,7 @@ def __init__(

         # Set defaults for transient variables (to make pyright happy)
         self.batch: Any = None
-        self.loss: Union[torch.Tensor, Sequence[torch.Tensor]] = torch.Tensor()
+        self.loss: Union[torch.Tensor, Sequence[torch.Tensor], Dict[Any, torch.Tensor]] = torch.Tensor()
         self.outputs: Union[torch.Tensor, Sequence[torch.Tensor]] = torch.Tensor()

         # These attributes will be serialized using .state_dict(), and loaded with .load_state_dict()
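Taken together, the `nan_monitor.py` and `state.py` changes above allow `state.loss` to be a dict of named loss tensors. A standalone sketch (not Composer code; the tensor values are made up) of the per-key check the new `NaNMonitor` branch performs:

```python
import torch

# Hypothetical dict-valued loss, mirroring the new Dict[Any, torch.Tensor] form of state.loss.
loss = {'loss/cross_entropy': torch.tensor(0.73), 'loss/auxiliary': torch.tensor(float('nan'))}

# Same per-key check as the branch added to NaNMonitor.after_loss: each named
# loss is inspected independently so the error message can name the bad key.
for k, v in loss.items():
    if torch.isnan(v).any():
        raise RuntimeError(f'Train loss {k} contains a NaN.')
```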
""" def __init__( self, log_interval: int = 60, ignore_keys: Optional[List[str]] = None, + ignore_exceptions: bool = False, ) -> None: self.log_interval = log_interval self.ignore_keys = ignore_keys + self.ignore_exceptions = ignore_exceptions self._enabled = dist.get_global_rank() == 0 if self._enabled: - self.allowed_fails_left = 3 self.time_last_logged = 0 self.train_dataloader_len = None - self.time_failed_count_adjusted = 0 self.buffered_metadata: Dict[str, Any] = {} + self._futures = [] + self.run_name = os.environ.get(RUN_NAME_ENV_VAR) if self.run_name is not None: log.info(f'Logging to mosaic run {self.run_name}') @@ -140,20 +144,25 @@ def _flush_metadata(self, force_flush: bool = False) -> None: """Flush buffered metadata to MosaicML if enough time has passed since last flush.""" if self._enabled and (time.time() - self.time_last_logged > self.log_interval or force_flush): try: - mcli.update_run_metadata(self.run_name, self.buffered_metadata) + f = mcli.update_run_metadata(self.run_name, self.buffered_metadata, future=True, protect=True) self.buffered_metadata = {} self.time_last_logged = time.time() - # If we have not failed in the last hour, increase the allowed fails. This increases - # robustness to transient network issues. - if time.time() - self.time_failed_count_adjusted > 3600 and self.allowed_fails_left < 3: - self.allowed_fails_left += 1 - self.time_failed_count_adjusted = time.time() - except Exception as e: - log.error(f'Failed to log metadata to Mosaic with error: {e}') - self.allowed_fails_left -= 1 - self.time_failed_count_adjusted = time.time() - if self.allowed_fails_left <= 0: + self._futures.append(f) + done, incomplete = wait(self._futures, timeout=0.01) + log.info(f'Logged {len(done)} metadata to MosaicML, waiting on {len(incomplete)}') + # Raise any exceptions + for f in done: + if f.exception() is not None: + raise f.exception() # type: ignore + self._futures = list(incomplete) + except Exception: + log.exception('Failed to log metadata to Mosaic') # Prints out full traceback + if self.ignore_exceptions: + log.info('Ignoring exception and disabling MosaicMLLogger.') self._enabled = False + else: + log.info('Raising exception. To ignore exceptions, set ignore_exceptions=True.') + raise def _get_training_progress_metrics(self, state: State) -> Dict[str, Any]: """Calculates training progress metrics. 
diff --git a/composer/trainer/dist_strategy.py b/composer/trainer/dist_strategy.py
index 0ac6e3ccee..768089ec0c 100644
--- a/composer/trainer/dist_strategy.py
+++ b/composer/trainer/dist_strategy.py
@@ -133,9 +133,11 @@ def set_fsdp_default(fsdp_config: Dict[str, Any]):
     fsdp_config.setdefault('activation_checkpointing_reentrant', True)
     fsdp_config.setdefault('activation_cpu_offload', False)
     fsdp_config.setdefault('backward_prefetch', 'BACKWARD_POST')
+    fsdp_config.setdefault('backward_prefetch_limit', 1)
     fsdp_config.setdefault('cpu_offload', False)
     fsdp_config.setdefault('flatten_parameters', True)
     fsdp_config.setdefault('forward_prefetch', False)
+    fsdp_config.setdefault('forward_prefetch_limit', 1)
     fsdp_config.setdefault('ignored_modules', None)
     fsdp_config.setdefault('keep_low_precision_grads', False)
     fsdp_config.setdefault('limit_all_gathers', True)
@@ -508,6 +510,24 @@ def _auto_wrap_policy_old(module: torch.nn.Module, recurse: bool, unwrapped_para
                 **kwargs,
             )

+            if hasattr(fsdp_obj, '_exec_order_data'):
+                if hasattr(fsdp_obj._exec_order_data, '_forward_prefetch_limit'):
+                    fsdp_obj._exec_order_data._forward_prefetch_limit = fsdp_config['forward_prefetch_limit']
+                else:
+                    warnings.warn('FSDP._exec_order_data does not have attribute _forward_prefetch_limit '
+                                  'which is unexpected and will result in `forward_prefetch_limit` from FSDP '
+                                  'config being ignored. Please open an issue to Composer to report this.')
+                if hasattr(fsdp_obj._exec_order_data, '_backward_prefetch_limit'):
+                    fsdp_obj._exec_order_data._backward_prefetch_limit = fsdp_config['backward_prefetch_limit']
+                else:
+                    warnings.warn('FSDP._exec_order_data does not have attribute _backward_prefetch_limit '
+                                  'which is unexpected and will result in `backward_prefetch_limit` from FSDP '
+                                  'config being ignored. Please open an issue to Composer to report this.')
+            else:
+                warnings.warn('FSDP does not have attribute _exec_order_data which is unexpected and will '
+                              'result in `forward_prefetch_limit` and `backward_prefetch_limit` from FSDP '
+                              'config being ignored. Please open an issue to Composer to report this.')
+
             # Activation Checkpointing
             if activation_checkpointing or activation_cpu_offload:
                 if not activation_checkpointing_reentrant:
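For reference, a sketch of how the two new keys sit in a user-supplied `fsdp_config` (the values shown are the defaults registered in `set_fsdp_default` above; the new test at the bottom of this diff exercises the same keys through the Trainer):

```python
# The prefetch limits are forwarded to FSDP's execution-order bookkeeping
# (fsdp_obj._exec_order_data) as patched above; larger values are intended to
# let FSDP prefetch parameters for more modules ahead of time.
fsdp_config = {
    'backward_prefetch': 'BACKWARD_POST',
    'backward_prefetch_limit': 1,  # default
    'forward_prefetch': False,
    'forward_prefetch_limit': 1,  # default
}
# Passed to the Trainer as Trainer(..., fsdp_config=fsdp_config).
```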
diff --git a/docs/source/trainer/logging.rst b/docs/source/trainer/logging.rst
index 48ca8c47b2..7458f89ff6 100644
--- a/docs/source/trainer/logging.rst
+++ b/docs/source/trainer/logging.rst
@@ -35,6 +35,7 @@ and also saves them to the file

     os.environ["WANDB_MODE"] = "disabled"
     os.environ["COMET_API_KEY"] = ""
+    os.environ["MLFLOW_TRACKING_URI"] = ""

 .. testcode::
     :skipif: not _WANDB_INSTALLED or not _COMETML_INSTALLED
diff --git a/setup.py b/setup.py
index 78f8af7746..065104c75b 100644
--- a/setup.py
+++ b/setup.py
@@ -87,7 +87,7 @@ def package_files(prefix: str, directory: str, extension: str):
     'py-cpuinfo>=8.0.0,<10',
     'packaging>=21.3.0,<23',
     'importlib-metadata>=5.0.0,<7',
-    'mosaicml-cli>=0.5.8,<0.6',
+    'mosaicml-cli>=0.5.25,<0.6',
 ]

 extra_deps = {}
diff --git a/tests/fixtures/autouse_fixtures.py b/tests/fixtures/autouse_fixtures.py
index fc836be679..f44a74c363 100644
--- a/tests/fixtures/autouse_fixtures.py
+++ b/tests/fixtures/autouse_fixtures.py
@@ -5,6 +5,7 @@
 import logging
 import os
 import pathlib
+from concurrent.futures import Future

 import mcli
 import pytest
@@ -118,7 +119,9 @@ def seed_all(rank_zero_seed: int, monkeypatch: pytest.MonkeyPatch):

 @pytest.fixture(autouse=True)
 def mapi_fixture(monkeypatch):
     # Composer auto-adds mosaicml logger when running on platform. Disable logging for tests.
-    mock_update = lambda *args, **kwargs: None
+    future_obj = Future()
+    future_obj.set_result(None)
+    mock_update = lambda *args, **kwargs: future_obj
     monkeypatch.setattr(mcli, 'update_run_metadata', mock_update)

diff --git a/tests/loggers/test_mosaicml_logger.py b/tests/loggers/test_mosaicml_logger.py
index 1258e47e21..5461829ab5 100644
--- a/tests/loggers/test_mosaicml_logger.py
+++ b/tests/loggers/test_mosaicml_logger.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import json
+from concurrent.futures import Future
 from typing import Type
 from unittest.mock import MagicMock

@@ -23,10 +24,26 @@
 class MockMAPI:

-    def __init__(self):
+    def __init__(self, simulate_exception: bool = False):
         self.run_metadata = {}
-
-    def update_run_metadata(self, run_name, new_metadata):
+        self.simulate_exception = simulate_exception
+
+    def update_run_metadata(self, run_name, new_metadata, future=False, protect=True):
+        if future:
+            # Simulate asynchronous behavior using Future
+            future_obj = Future()
+            try:
+                self._update_metadata(run_name, new_metadata)
+                future_obj.set_result(None)  # Set a result to indicate completion
+            except Exception as e:
+                future_obj.set_exception(e)  # Set an exception if something goes wrong
+            return future_obj
+        else:
+            self._update_metadata(run_name, new_metadata)
+
+    def _update_metadata(self, run_name, new_metadata):
+        if self.simulate_exception:
+            raise RuntimeError('Simulated exception')
         if run_name not in self.run_metadata:
             self.run_metadata[run_name] = {}
         for k, v in new_metadata.items():
@@ -94,6 +111,30 @@ def test_logged_data_is_json_serializable(monkeypatch, callback_cls: Type[Callba

     assert len(mock_mapi.run_metadata.keys()) == 0


+@world_size(1, 2)
+@pytest.mark.parametrize('ignore_exceptions', [True, False])
+def test_logged_data_exception_handling(monkeypatch, world_size: int, ignore_exceptions: bool):
+    """Test that exceptions in MAPI are raised properly."""
+    mock_mapi = MockMAPI(simulate_exception=True)
+    monkeypatch.setattr(mcli, 'update_run_metadata', mock_mapi.update_run_metadata)
+    run_name = 'small_chungus'
+    monkeypatch.setenv('RUN_NAME', run_name)
+
+    logger = MosaicMLLogger(ignore_exceptions=ignore_exceptions)
+    if dist.get_global_rank() != 0:
+        assert logger._enabled is False
+        logger._flush_metadata(force_flush=True)
+        assert logger._enabled is False
+    elif ignore_exceptions:
+        assert logger._enabled is True
+        logger._flush_metadata(force_flush=True)
+        assert logger._enabled is False
+    else:
+        with pytest.raises(RuntimeError, match='Simulated exception'):
+            assert logger._enabled is True
+            logger._flush_metadata(force_flush=True)
+
+
 def test_metric_partial_filtering(monkeypatch):
     mock_mapi = MockMAPI()
     monkeypatch.setattr(mcli, 'update_run_metadata', mock_mapi.update_run_metadata)
diff --git a/tests/trainer/test_fsdp.py b/tests/trainer/test_fsdp.py
index ac0559b9b3..ad16905a0c 100644
--- a/tests/trainer/test_fsdp.py
+++ b/tests/trainer/test_fsdp.py
@@ -9,18 +9,21 @@
 from composer.models import ComposerClassifier
 from composer.trainer.trainer import Trainer
 from composer.utils import dist
-from tests.common import EmbeddedWeightTiedModel, RandomClassificationDataset, SimpleModel, SimpleWeightTiedModel
+from tests.common import (EmbeddedWeightTiedModel, RandomClassificationDataset, SimpleModel, SimpleWeightTiedModel,
+                          world_size)


 @pytest.mark.parametrize('model', [SimpleWeightTiedModel, EmbeddedWeightTiedModel])
 @pytest.mark.parametrize('mixed_precision', ['FULL', 'DEFAULT', 'PURE'])
 @pytest.mark.parametrize('device', ['cpu', 'meta'])
 @pytest.mark.parametrize('reentrant', [True, False])
-@pytest.mark.filterwarnings('ignore::UserWarning')
+@world_size(2)
 @pytest.mark.gpu
+@pytest.mark.filterwarnings('ignore:The passed in model appears to have tied weights.*:UserWarning')
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
                     reason='FSDP requires PyTorch 1.13 or higher')
-def test_fsdp_device_initialization(model: ComposerClassifier, mixed_precision: str, device: str, reentrant: bool):
+def test_fsdp_device_initialization(model: ComposerClassifier, mixed_precision: str, reentrant: bool, world_size: int,
+                                    device: str):
     """test FSDP device initialization for a simple model with weight tying and a model where two modules
     from separate submodules have weight tying applied. This test also covers both 'cpu' and 'meta' devices. This is
     because 'meta' will result in deferred initialization until FSDP is initialized
@@ -62,15 +65,16 @@ def test_fsdp_device_initialization(model: ComposerClassifier, mixed_precision:
 @pytest.mark.parametrize('model', [SimpleModel])
 @pytest.mark.parametrize('mixed_precision', ['FULL', 'DEFAULT', 'PURE'])
 @pytest.mark.gpu
+@world_size(2)
 @pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
                     reason='FSDP requires PyTorch 1.13 or higher')
-def test_fsdp_meta_initialization_none(model: ComposerClassifier, mixed_precision: 'str', device: str = 'meta'):
+def test_fsdp_meta_initialization_none(model: ComposerClassifier, mixed_precision: 'str', world_size: int):
     """
     This test is intended to test FSDP for meta initialization when there are attributes
     that are `None` and ensure we don't raise nasty UserWarnings.
     """
     num_classes = 2
-    model = model(num_features=1, num_classes=num_classes, device=device, bias=False)
+    model = model(num_features=1, num_classes=num_classes, device='meta', bias=False)
     dataset = RandomClassificationDataset(shape=(num_classes,), size=2, num_classes=num_classes)
     dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
@@ -85,3 +89,31 @@ def test_fsdp_meta_initialization_none(model: ComposerClassifier, mixed_precisio
         },
         max_duration='3ba',
     )
+
+
+@pytest.mark.parametrize('forward_prefetch_limit', [1, 2])
+@pytest.mark.parametrize('backward_prefetch_limit', [1, 2])
+@pytest.mark.gpu
+@world_size(2)
+@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.13.0'),
+                    reason='FSDP requires PyTorch 1.13 or higher')
+def test_fsdp_prefetch_limit(forward_prefetch_limit: int, backward_prefetch_limit: int, world_size: int):
+    model = SimpleModel()
+    model.fc1._fsdp_wrap = True
+    model.fc2._fsdp_wrap = True
+    dataset = RandomClassificationDataset(size=10)
+    dataloader = DataLoader(dataset, sampler=dist.get_sampler(dataset))
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
+
+    trainer = Trainer(
+        model=model,
+        optimizers=optimizer,
+        train_dataloader=dataloader,
+        fsdp_config={
+            'forward_prefetch_limit': forward_prefetch_limit,
+            'backward_prefetch_limit': backward_prefetch_limit,
+        },
+        max_duration='3ba',
+    )
+
+    trainer.fit()