diff --git a/composer/models/huggingface.py b/composer/models/huggingface.py index e0ad6fdf6d..9a84cda5c7 100644 --- a/composer/models/huggingface.py +++ b/composer/models/huggingface.py @@ -21,6 +21,7 @@ import torch from torchmetrics import Metric +from composer.devices import DeviceCPU from composer.models.base import ComposerModel from composer.utils import MissingConditionalImportError, dist, get_file, import_object, is_model_fsdp, safe_torch_load from composer.utils.warnings import VersionedDeprecationWarning @@ -590,6 +591,9 @@ def get_metrics(self, is_train: bool = False) -> Dict[str, Metric]: return metrics if metrics else {} def update_metric(self, batch: Any, outputs: Any, metric: Metric) -> Dict: + if metric.device.type == 'cpu': + self.labels = DeviceCPU().batch_to_device(self.labels) + if getattr(metric, 'needs_batch', False): metric_result = metric.update(batch=batch, outputs=outputs, labels=self.labels) else: diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 46843efa50..f4ec8bf8e4 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -2204,12 +2204,21 @@ def close(self): self.engine.close() dist.barrier() - def _ensure_metrics_device_and_dtype(self, metrics: Dict[str, Metric]): + def _ensure_metrics_device_and_dtype( + self, + metrics: Dict[str, Metric], + ensure_cpu: bool = False, + ): for name, metric in metrics.items(): # Safety check to ensure the metric and data are on the same device. Normally not # needed because the metric is automatically on the same device as the model. # See https://torchmetrics.readthedocs.io/en/latest/pages/overview.html for details. - metrics[name] = self.state.device.module_to_device(metric) + + # Force all metrics to go on the CPU + if ensure_cpu: + metrics[name] = DeviceCPU().module_to_device(metric) + else: + metrics[name] = self.state.device.module_to_device(metric) if is_model_deepspeed(self.state.model): # HACK: DeepSpeed somehow manages to convert metric internal states to its own dtype. When # running with FP16, this tends to result in overflows. Let's assume FP32 is good enough. @@ -3202,7 +3211,11 @@ def _eval_loop( self.engine.run_event(Event.EVAL_START) - metrics = self._ensure_metrics_device_and_dtype(metrics) + # On MPS device we ensure the eval metrics are computed on CPU to avoid numerical errors + metrics = self._ensure_metrics_device_and_dtype( + metrics, + ensure_cpu=isinstance(self.state.device, DeviceMPS), + ) for metric in metrics.values(): metric.reset() @@ -3327,12 +3340,14 @@ def _eval_loop( outputs.append(v) else: outputs = self.state.outputs.cpu() + batch = DeviceCPU().batch_to_device(self.state.batch,) else: outputs = self.state.outputs + batch = self.state.batch for metric in metrics.values(): metric_outputs = self._original_model.update_metric( - self.state.batch, + batch, outputs, metric, )