Skip to content

Commit

Permalink
Revert "Improvements to NeptuneLogger (#3085)" (#3111)
Browse files Browse the repository at this point in the history
This reverts commit b63b263.
  • Loading branch information
mvpatel2000 authored Mar 13, 2024
1 parent b63b263 commit 46ffa82
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 228 deletions.
56 changes: 20 additions & 36 deletions composer/callbacks/oom_observer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@
# SPDX-License-Identifier: Apache-2.0

"""Generate a memory snapshot during an OutOfMemory exception."""
import dataclasses

import logging
import os
import pickle
import warnings
from dataclasses import dataclass
from typing import List, Optional
from typing import Optional

import torch.cuda
from packaging import version

from composer import State
from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import ensure_folder_is_empty, format_name_with_dist, format_name_with_dist_and_time, parse_uri
Expand All @@ -22,29 +22,6 @@
__all__ = ['OOMObserver']


@dataclass(frozen=True)
class SnapshotFileNameConfig:
"""Configuration for the file names of the memory snapshot visualizations."""
snapshot_file: str
trace_plot_file: str
segment_plot_file: str
segment_flamegraph_file: str
memory_flamegraph_file: str

@classmethod
def from_file_name(cls, filename: str) -> 'SnapshotFileNameConfig':
return cls(
snapshot_file=filename + '_snapshot.pickle',
trace_plot_file=filename + '_trace_plot.html',
segment_plot_file=filename + '_segment_plot.html',
segment_flamegraph_file=filename + '_segment_flamegraph.svg',
memory_flamegraph_file=filename + '_memory_flamegraph.svg',
)

def list_filenames(self) -> List[str]:
return [getattr(self, field.name) for field in dataclasses.fields(self)]


class OOMObserver(Callback):
"""Generate visualizations of the state of allocated memory during an OutOfMemory exception.
Expand Down Expand Up @@ -117,8 +94,6 @@ def __init__(
self._enabled = False
warnings.warn('OOMObserver is supported after PyTorch 2.1.0. Disabling OOMObserver callback.')

self.filename_config: Optional[SnapshotFileNameConfig] = None

def init(self, state: State, logger: Logger) -> None:
if not self._enabled:
return
Expand Down Expand Up @@ -148,7 +123,11 @@ def oom_observer(device: int, alloc: int, device_alloc: int, device_free: int):
)

try:
self.filename_config = SnapshotFileNameConfig.from_file_name(filename)
snapshot_file = filename + '_snapshot.pickle'
trace_plot_file = filename + '_trace_plot.html'
segment_plot_file = filename + '_segment_plot.html'
segment_flamegraph_file = filename + '_segment_flamegraph.svg'
memory_flamegraph_file = filename + '_memory_flamegraph.svg'
log.info(f'Dumping OOMObserver visualizations')

snapshot = torch.cuda.memory._snapshot()
Expand All @@ -157,26 +136,31 @@ def oom_observer(device: int, alloc: int, device_alloc: int, device_free: int):
log.info(f'No allocation is recorded in memory snapshot)')
return

with open(self.filename_config.snapshot_file, 'wb') as fd:
with open(snapshot_file, 'wb') as fd:
pickle.dump(snapshot, fd)

with open(self.filename_config.trace_plot_file, 'w+') as fd:
with open(trace_plot_file, 'w+') as fd:
fd.write(torch.cuda._memory_viz.trace_plot(snapshot)) # type: ignore

with open(self.filename_config.segment_plot_file, 'w+') as fd:
with open(segment_plot_file, 'w+') as fd:
fd.write(torch.cuda._memory_viz.segment_plot(snapshot)) # type: ignore

with open(self.filename_config.segment_flamegraph_file, 'w+') as fd:
with open(segment_flamegraph_file, 'w+') as fd:
fd.write(torch.cuda._memory_viz.segments(snapshot)) # type: ignore

with open(self.filename_config.memory_flamegraph_file, 'w+') as fd:
with open(memory_flamegraph_file, 'w+') as fd:
fd.write(torch.cuda._memory_viz.memory(snapshot)) # type: ignore

log.info(f'Saved memory visualizations to local files with prefix = {filename} during OOM')

if self.remote_path_in_bucket is not None:

for f in self.filename_config.list_filenames():
for f in [
snapshot_file,
trace_plot_file,
segment_plot_file,
segment_flamegraph_file,
memory_flamegraph_file,
]:
base_file_name = os.path.basename(f)
remote_file_name = os.path.join(self.remote_path_in_bucket, base_file_name)
remote_file_name = remote_file_name.lstrip('/') # remove leading slashes
Expand Down
142 changes: 43 additions & 99 deletions composer/loggers/neptune_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pathlib
import warnings
from functools import partial
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Set, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Set, Union

import numpy as np
import torch
Expand All @@ -19,18 +19,14 @@
from composer.utils import MissingConditionalImportError, dist

if TYPE_CHECKING:
from composer import Callback, Logger
from composer.callbacks import OOMObserver
from composer import Logger
from composer.core import State

NEPTUNE_MODE_TYPE = Literal['async', 'sync', 'offline', 'read-only', 'debug']


class NeptuneLogger(LoggerDestination):
"""Log to `neptune.ai <https://neptune.ai/>`_.
For instructions, see the
`integration guide <https://docs.neptune.ai/integrations/mosaicml_composer/>`_.
For more, see the [Neptune-Composer integration guide](https://docs.neptune.ai/integrations/composer/).
Args:
project (str, optional): The name of your Neptune project,
Expand All @@ -40,20 +36,20 @@ class NeptuneLogger(LoggerDestination):
You can leave out this argument if you save your token to the
``NEPTUNE_API_TOKEN`` environment variable (recommended).
You can find your API token in the user menu of the Neptune web app.
rank_zero_only (bool): Whether to log only on the rank-zero process (default: ``True``).
upload_artifacts (bool, optional): Deprecated. See ``upload_checkpoints``.
upload_checkpoints (bool): Whether the logger should upload checkpoints to Neptune
rank_zero_only (bool, optional): Whether to log only on the rank-zero process.
(default: ``True``).
upload_artifacts (bool, optional): Whether the logger should upload artifacts to Neptune.
(default: ``False``).
base_namespace (str, optional): The name of the base namespace where the metadata
is logged (default: "training").
base_namespace (str, optional): The name of the base namespace to log the metadata to.
(default: "training").
neptune_kwargs (Dict[str, Any], optional): Any additional keyword arguments to the
``neptune.init_run()`` function. For options, see the
`Run API reference <https://docs.neptune.ai/api/neptune/#init_run>`_.
`Run API reference <https://docs.neptune.ai/api/neptune/#init_run>`_ in the
Neptune docs.
"""
metric_namespace = 'metrics'
hyperparam_namespace = 'hyperparameters'
trace_namespace = 'traces'
oom_snaphot_namespace = 'oom_snapshots'
integration_version_key = 'source_code/integrations/neptune-MosaicML'

def __init__(
Expand All @@ -62,10 +58,8 @@ def __init__(
project: Optional[str] = None,
api_token: Optional[str] = None,
rank_zero_only: bool = True,
upload_artifacts: Optional[bool] = None,
upload_checkpoints: bool = False,
upload_artifacts: bool = False,
base_namespace: str = 'training',
mode: Optional[NEPTUNE_MODE_TYPE] = None,
**neptune_kwargs,
) -> None:
try:
Expand All @@ -80,8 +74,7 @@ def __init__(
verify_type('project', project, (str, type(None)))
verify_type('api_token', api_token, (str, type(None)))
verify_type('rank_zero_only', rank_zero_only, bool)
verify_type('upload_artifacts', upload_artifacts, (bool, type(None)))
verify_type('upload_checkpoints', upload_checkpoints, bool)
verify_type('upload_artifacts', upload_artifacts, bool)
verify_type('base_namespace', base_namespace, str)

if not base_namespace:
Expand All @@ -90,35 +83,38 @@ def __init__(
self._project = project
self._api_token = api_token
self._rank_zero_only = rank_zero_only

if upload_artifacts is not None:
_warn_about_deprecated_upload_artifacts()
self._upload_checkpoints = upload_artifacts
else:
self._upload_checkpoints = upload_checkpoints

self._upload_artifacts = upload_artifacts
self._base_namespace = base_namespace
self._neptune_kwargs = neptune_kwargs

mode = self._neptune_kwargs.pop('mode', 'async')

self._enabled = (not rank_zero_only) or dist.get_global_rank() == 0

self._mode: Optional[NEPTUNE_MODE_TYPE] = mode if self._enabled else 'debug'
self._mode = mode if self._enabled else 'debug'

self._neptune_run = None
self._base_handler = None

self._metrics_dict: Dict[str, int] = {} # used to prevent duplicate step logging

self._oom_observer: Optional['OOMObserver'] = None

super().__init__()

@property
def neptune_run(self):
"""Gets the Neptune run object from a NeptuneLogger instance.
To log additional metadata to the run, access a path inside the run and assign metadata
with ``=`` or other `Neptune logging methods <https://docs.neptune.ai/logging/methods/>`_.
You can log additional metadata to the run by accessing a path inside the run and assigning metadata to it
with "=" or [Neptune logging methods](https://docs.neptune.ai/logging/methods/).
Example:
from composer import Trainer
from composer.loggers import NeptuneLogger
neptune_logger = NeptuneLogger()
trainer = Trainer(loggers=neptune_logger, ...)
trainer.fit()
neptune_logger.neptune_run["some_metric"] = 1
trainer.close()
"""
from neptune import Run

Expand All @@ -135,10 +131,19 @@ def neptune_run(self):
def base_handler(self):
"""Gets a handler for the base logging namespace.
Use the handler to log extra metadata to the run and organize it under the base namespace
(default: "training"). You can operate on it like a run object: Access a path inside the
handler and assign metadata to it with ``=`` or other
`Neptune logging methods <https://docs.neptune.ai/logging/methods/>`_.
Use the handler to log extra metadata to the run and organize it under the base namespace (default: "training").
You can operate on it like a run object: Access a path inside the handler and assign metadata to it with "=" or
other [Neptune logging methods](https://docs.neptune.ai/logging/methods/).
Example:
from composer import Trainer
from composer.loggers import NeptuneLogger
neptune_logger = NeptuneLogger()
trainer = Trainer(loggers=neptune_logger, ...)
trainer.fit()
neptune_logger.base_handler["some_metric"] = 1
trainer.close()
Result: The value `1` is organized under "training/some_metric" inside the run.
"""
return self.neptune_run[self._base_namespace]

Expand All @@ -151,8 +156,6 @@ def init(self, state: 'State', logger: 'Logger') -> None:
self.neptune_run['sys/name'] = state.run_name
self.neptune_run[self.integration_version_key] = __version__

self._oom_observer = _find_oom_callback(state.callbacks)

def _sanitize_metrics(self, metrics: Dict[str, float], step: Optional[int]) -> Dict[str, float]:
"""Sanitize metrics to prevent duplicate step logging.
Expand Down Expand Up @@ -210,7 +213,7 @@ def log_traces(self, traces: Dict[str, Any]):

def can_upload_files(self) -> bool:
"""Whether the logger supports uploading files."""
return self._enabled and self._upload_checkpoints
return self._enabled and self._upload_artifacts

def upload_file(
self,
Expand All @@ -223,9 +226,6 @@ def upload_file(
if not self.can_upload_files():
return

if file_path.is_symlink():
return # skip symlinks

neptune_path = f'{self._base_namespace}/{remote_file_name}'
if self.neptune_run.exists(neptune_path) and not overwrite:

Expand All @@ -236,11 +236,7 @@ def upload_file(
return

del state # unused

from neptune.types import File

with open(str(file_path), 'rb') as fp:
self.base_handler[remote_file_name] = File.from_stream(fp, extension=file_path.suffix)
self.base_handler[remote_file_name].upload(str(file_path))

def download_file(
self,
Expand Down Expand Up @@ -303,16 +299,6 @@ def post_close(self) -> None:
self._neptune_run.stop()
self._neptune_run = None

if self._oom_observer:
self._log_oom_snapshots()

def _log_oom_snapshots(self) -> None:
if self._oom_observer is None or self._oom_observer.filename_config is None:
return

for file_name in self._oom_observer.filename_config.list_filenames():
self.base_handler[f'{NeptuneLogger.oom_snaphot_namespace}/{file_name}'].upload(file_name)


def _validate_image(img: Union[np.ndarray, torch.Tensor], channels_last: bool) -> np.ndarray:
img_numpy = img.data.cpu().numpy() if isinstance(img, torch.Tensor) else img
Expand All @@ -326,46 +312,4 @@ def _validate_image(img: Union[np.ndarray, torch.Tensor], channels_last: bool) -
if not channels_last:
img_numpy = np.moveaxis(img_numpy, 0, -1)

return _validate_image_value_range(img_numpy)


def _validate_image_value_range(img: np.ndarray) -> np.ndarray:
array_min = img.min()
array_max = img.max()

if (array_min >= 0 and 1 < array_max <= 255) or (array_min >= 0 and array_max <= 1):
return img

from neptune.common.warnings import NeptuneWarning, warn_once

warn_once(
'Image value range is not in the expected range of [0.0, 1.0] or [0, 255]. '
'This might be due to the presence of `transforms.Normalize` in the data pipeline. '
'Logged images may not display correctly in Neptune.',
exception=NeptuneWarning,
)

return _scale_image_to_0_255(img, array_min, array_max)


def _scale_image_to_0_255(img: np.ndarray, array_min: Union[int, float], array_max: Union[int, float]) -> np.ndarray:
scaled_image = 255 * (img - array_min) / (array_max - array_min)
return scaled_image.astype(np.uint8)


def _find_oom_callback(callbacks: List['Callback']) -> Optional['OOMObserver']:
from composer.callbacks import OOMObserver

for callback in callbacks:
if isinstance(callback, OOMObserver):
return callback
return None


def _warn_about_deprecated_upload_artifacts() -> None:
from neptune.common.warnings import NeptuneDeprecationWarning, warn_once
warn_once(
'The \'upload_artifacts\' parameter has been deprecated and will be removed in the next version. '
'Please use the \'upload_checkpoints\' parameter instead.',
exception=NeptuneDeprecationWarning,
)
return img_numpy
3 changes: 0 additions & 3 deletions docs/source/doctest_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,6 @@
# Disable wandb
os.environ['WANDB_MODE'] = 'disabled'

# Disable neptune
os.environ['NEPTUNE_MODE'] = 'debug'

# Change the cwd to be the tempfile, so we don't pollute the documentation source folder
tmpdir = tempfile.mkdtemp()
cwd = os.path.abspath('.')
Expand Down
Loading

0 comments on commit 46ffa82

Please sign in to comment.