diff --git a/composer/_version.py b/composer/_version.py
index c95d48d28a..c78074d7f4 100644
--- a/composer/_version.py
+++ b/composer/_version.py
@@ -3,4 +3,4 @@
"""The Composer Version."""
-__version__ = '0.21.2'
+__version__ = '0.22.0.dev0'
diff --git a/composer/callbacks/oom_observer.py b/composer/callbacks/oom_observer.py
index d9250c37e4..7d6292e079 100644
--- a/composer/callbacks/oom_observer.py
+++ b/composer/callbacks/oom_observer.py
@@ -2,17 +2,20 @@
# SPDX-License-Identifier: Apache-2.0
"""Generate a memory snapshot during an OutOfMemory exception."""
+from __future__ import annotations
+import dataclasses
import logging
import os
import pickle
import warnings
-from typing import Optional
+from dataclasses import dataclass
+from pathlib import Path
+from typing import List, Optional
import torch.cuda
from packaging import version
-from composer import State
from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import ensure_folder_is_empty, format_name_with_dist, format_name_with_dist_and_time, parse_uri
@@ -22,6 +25,29 @@
__all__ = ['OOMObserver']
+@dataclass(frozen=True)
+class SnapshotFileNameConfig:
+ """Configuration for the file names of the memory snapshot visualizations."""
+ snapshot_file: str
+ trace_plot_file: str
+ segment_plot_file: str
+ segment_flamegraph_file: str
+ memory_flamegraph_file: str
+
+ @classmethod
+ def from_file_name(cls, filename: str) -> 'SnapshotFileNameConfig':
+ return cls(
+ snapshot_file=filename + '_snapshot.pickle',
+ trace_plot_file=filename + '_trace_plot.html',
+ segment_plot_file=filename + '_segment_plot.html',
+ segment_flamegraph_file=filename + '_segment_flamegraph.svg',
+ memory_flamegraph_file=filename + '_memory_flamegraph.svg',
+ )
+
+ def list_filenames(self) -> List[str]:
+ return [getattr(self, field.name) for field in dataclasses.fields(self)]
+
+
class OOMObserver(Callback):
"""Generate visualizations of the state of allocated memory during an OutOfMemory exception.
@@ -94,6 +120,8 @@ def __init__(
self._enabled = False
warnings.warn('OOMObserver is supported after PyTorch 2.1.0. Disabling OOMObserver callback.')
+ self.filename_config: Optional[SnapshotFileNameConfig] = None
+
def init(self, state: State, logger: Logger) -> None:
if not self._enabled:
return
@@ -117,17 +145,12 @@ def oom_observer(device: int, alloc: int, device_alloc: int, device_free: int):
assert self.filename
assert self.folder_name, 'folder_name must be set in init'
- filename = os.path.join(
- self.folder_name,
+ filename = Path(self.folder_name) / Path(
format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=state.timestamp),
)
try:
- snapshot_file = filename + '_snapshot.pickle'
- trace_plot_file = filename + '_trace_plot.html'
- segment_plot_file = filename + '_segment_plot.html'
- segment_flamegraph_file = filename + '_segment_flamegraph.svg'
- memory_flamegraph_file = filename + '_memory_flamegraph.svg'
+ self.filename_config = SnapshotFileNameConfig.from_file_name(str(filename))
log.info(f'Dumping OOMObserver visualizations')
snapshot = torch.cuda.memory._snapshot()
@@ -136,31 +159,25 @@ def oom_observer(device: int, alloc: int, device_alloc: int, device_free: int):
log.info(f'No allocation is recorded in memory snapshot)')
return
- with open(snapshot_file, 'wb') as fd:
+ with open(self.filename_config.snapshot_file, 'wb') as fd:
pickle.dump(snapshot, fd)
- with open(trace_plot_file, 'w+') as fd:
+ with open(self.filename_config.trace_plot_file, 'w+') as fd:
fd.write(torch.cuda._memory_viz.trace_plot(snapshot)) # type: ignore
- with open(segment_plot_file, 'w+') as fd:
+ with open(self.filename_config.segment_plot_file, 'w+') as fd:
fd.write(torch.cuda._memory_viz.segment_plot(snapshot)) # type: ignore
- with open(segment_flamegraph_file, 'w+') as fd:
+ with open(self.filename_config.segment_flamegraph_file, 'w+') as fd:
fd.write(torch.cuda._memory_viz.segments(snapshot)) # type: ignore
- with open(memory_flamegraph_file, 'w+') as fd:
+ with open(self.filename_config.memory_flamegraph_file, 'w+') as fd:
fd.write(torch.cuda._memory_viz.memory(snapshot)) # type: ignore
log.info(f'Saved memory visualizations to local files with prefix = {filename} during OOM')
if self.remote_path_in_bucket is not None:
- for f in [
- snapshot_file,
- trace_plot_file,
- segment_plot_file,
- segment_flamegraph_file,
- memory_flamegraph_file,
- ]:
+ for f in self.filename_config.list_filenames():
base_file_name = os.path.basename(f)
remote_file_name = os.path.join(self.remote_path_in_bucket, base_file_name)
remote_file_name = remote_file_name.lstrip('/') # remove leading slashes
diff --git a/composer/loggers/mlflow_logger.py b/composer/loggers/mlflow_logger.py
index 9ed6415dce..dcbed33d96 100644
--- a/composer/loggers/mlflow_logger.py
+++ b/composer/loggers/mlflow_logger.py
@@ -148,7 +148,7 @@ def init(self, state: State, logger: Logger) -> None:
# Store the Composer run name in the MLFlow run tags so it can be retrieved for autoresume.
self.tags = self.tags or {}
- self.tags['run_name'] = state.run_name
+ self.tags['run_name'] = os.environ.get('RUN_NAME', state.run_name)
# Adjust name and group based on `rank_zero_only`.
if not self._rank_zero_only:
@@ -171,16 +171,6 @@ def init(self, state: State, logger: Logger) -> None:
output_format='list',
)
- # Check for the old tag (`composer_run_name`) For backwards compatibility in case a run using the old
- # tag fails and the run is resumed with a newer version of Composer that uses `run_name` instead of
- # `composer_run_name`.
- if len(existing_runs) == 0:
- existing_runs = mlflow.search_runs(
- experiment_ids=[self._experiment_id],
- filter_string=f'tags.composer_run_name = "{state.run_name}"',
- output_format='list',
- )
-
if len(existing_runs) > 0:
self._run_id = existing_runs[0].info.run_id
else:
diff --git a/composer/loggers/neptune_logger.py b/composer/loggers/neptune_logger.py
index bcabc7999d..abb354a9a9 100644
--- a/composer/loggers/neptune_logger.py
+++ b/composer/loggers/neptune_logger.py
@@ -9,24 +9,30 @@
import pathlib
import warnings
from functools import partial
-from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Set, Union
+from importlib.metadata import version
+from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence, Set, Union
import numpy as np
import torch
+from packaging.version import Version
from composer._version import __version__
from composer.loggers import LoggerDestination
-from composer.utils import MissingConditionalImportError, dist
+from composer.utils import MissingConditionalImportError, VersionedDeprecationWarning, dist
if TYPE_CHECKING:
from composer import Logger
from composer.core import State
+NEPTUNE_MODE_TYPE = Literal['async', 'sync', 'offline', 'read-only', 'debug']
+NEPTUNE_VERSION_WITH_PROGRESS_BAR = Version('1.9.0')
+
class NeptuneLogger(LoggerDestination):
"""Log to `neptune.ai `_.
- For more, see the [Neptune-Composer integration guide](https://docs.neptune.ai/integrations/composer/).
+ For instructions, see the
+ `integration guide `_.
Args:
project (str, optional): The name of your Neptune project,
@@ -36,16 +42,15 @@ class NeptuneLogger(LoggerDestination):
You can leave out this argument if you save your token to the
``NEPTUNE_API_TOKEN`` environment variable (recommended).
You can find your API token in the user menu of the Neptune web app.
- rank_zero_only (bool, optional): Whether to log only on the rank-zero process.
- (default: ``True``).
- upload_artifacts (bool, optional): Whether the logger should upload artifacts to Neptune.
+ rank_zero_only (bool): Whether to log only on the rank-zero process (default: ``True``).
+ upload_artifacts (bool, optional): Deprecated. See ``upload_checkpoints``.
+ upload_checkpoints (bool): Whether the logger should upload checkpoints to Neptune
(default: ``False``).
- base_namespace (str, optional): The name of the base namespace to log the metadata to.
- (default: "training").
+ base_namespace (str, optional): The name of the base namespace where the metadata
+ is logged (default: "training").
neptune_kwargs (Dict[str, Any], optional): Any additional keyword arguments to the
``neptune.init_run()`` function. For options, see the
- `Run API reference `_ in the
- Neptune docs.
+ `Run API reference `_.
"""
metric_namespace = 'metrics'
hyperparam_namespace = 'hyperparameters'
@@ -58,8 +63,10 @@ def __init__(
project: Optional[str] = None,
api_token: Optional[str] = None,
rank_zero_only: bool = True,
- upload_artifacts: bool = False,
+ upload_artifacts: Optional[bool] = None,
+ upload_checkpoints: bool = False,
base_namespace: str = 'training',
+ mode: Optional[NEPTUNE_MODE_TYPE] = None,
**neptune_kwargs,
) -> None:
try:
@@ -74,7 +81,8 @@ def __init__(
verify_type('project', project, (str, type(None)))
verify_type('api_token', api_token, (str, type(None)))
verify_type('rank_zero_only', rank_zero_only, bool)
- verify_type('upload_artifacts', upload_artifacts, bool)
+ verify_type('upload_artifacts', upload_artifacts, (bool, type(None)))
+ verify_type('upload_checkpoints', upload_checkpoints, bool)
verify_type('base_namespace', base_namespace, str)
if not base_namespace:
@@ -83,15 +91,19 @@ def __init__(
self._project = project
self._api_token = api_token
self._rank_zero_only = rank_zero_only
- self._upload_artifacts = upload_artifacts
+
+ if upload_artifacts is not None:
+ _warn_about_deprecated_upload_artifacts()
+ self._upload_checkpoints = upload_artifacts
+ else:
+ self._upload_checkpoints = upload_checkpoints
+
self._base_namespace = base_namespace
self._neptune_kwargs = neptune_kwargs
- mode = self._neptune_kwargs.pop('mode', 'async')
-
self._enabled = (not rank_zero_only) or dist.get_global_rank() == 0
- self._mode = mode if self._enabled else 'debug'
+ self._mode: Optional[NEPTUNE_MODE_TYPE] = mode if self._enabled else 'debug'
self._neptune_run = None
self._base_handler = None
@@ -104,17 +116,8 @@ def __init__(
def neptune_run(self):
"""Gets the Neptune run object from a NeptuneLogger instance.
- You can log additional metadata to the run by accessing a path inside the run and assigning metadata to it
- with "=" or [Neptune logging methods](https://docs.neptune.ai/logging/methods/).
-
- Example:
- from composer import Trainer
- from composer.loggers import NeptuneLogger
- neptune_logger = NeptuneLogger()
- trainer = Trainer(loggers=neptune_logger, ...)
- trainer.fit()
- neptune_logger.neptune_run["some_metric"] = 1
- trainer.close()
+ To log additional metadata to the run, access a path inside the run and assign metadata
+ with ``=`` or other `Neptune logging methods `_.
"""
from neptune import Run
@@ -131,19 +134,10 @@ def neptune_run(self):
def base_handler(self):
"""Gets a handler for the base logging namespace.
- Use the handler to log extra metadata to the run and organize it under the base namespace (default: "training").
- You can operate on it like a run object: Access a path inside the handler and assign metadata to it with "=" or
- other [Neptune logging methods](https://docs.neptune.ai/logging/methods/).
-
- Example:
- from composer import Trainer
- from composer.loggers import NeptuneLogger
- neptune_logger = NeptuneLogger()
- trainer = Trainer(loggers=neptune_logger, ...)
- trainer.fit()
- neptune_logger.base_handler["some_metric"] = 1
- trainer.close()
- Result: The value `1` is organized under "training/some_metric" inside the run.
+ Use the handler to log extra metadata to the run and organize it under the base namespace
+ (default: "training"). You can operate on it like a run object: Access a path inside the
+ handler and assign metadata to it with ``=`` or other
+ `Neptune logging methods `_.
"""
return self.neptune_run[self._base_namespace]
@@ -213,7 +207,7 @@ def log_traces(self, traces: Dict[str, Any]):
def can_upload_files(self) -> bool:
"""Whether the logger supports uploading files."""
- return self._enabled and self._upload_artifacts
+ return self._enabled and self._upload_checkpoints
def upload_file(
self,
@@ -226,6 +220,9 @@ def upload_file(
if not self.can_upload_files():
return
+ if file_path.is_symlink() or file_path.suffix.lower() == '.symlink':
+ return # skip symlinks
+
neptune_path = f'{self._base_namespace}/{remote_file_name}'
if self.neptune_run.exists(neptune_path) and not overwrite:
@@ -236,7 +233,11 @@ def upload_file(
return
del state # unused
- self.base_handler[remote_file_name].upload(str(file_path))
+
+ from neptune.types import File
+
+ with open(str(file_path), 'rb') as fp:
+ self.base_handler[remote_file_name] = File.from_stream(fp, extension=file_path.suffix)
def download_file(
self,
@@ -245,7 +246,6 @@ def download_file(
overwrite: bool = False,
progress_bar: bool = True,
):
- del progress_bar # not supported
if not self._enabled:
return
@@ -266,7 +266,11 @@ def download_file(
if not self.neptune_run.exists(file_path):
raise FileNotFoundError(f'File {file_path} not found')
- self.base_handler[remote_file_name].download(destination=destination)
+ if _is_progress_bar_enabled():
+ self.base_handler[remote_file_name].download(destination=destination, progress_bar=progress_bar)
+ else:
+ del progress_bar
+ self.base_handler[remote_file_name].download(destination=destination)
def log_images(
self,
@@ -312,4 +316,42 @@ def _validate_image(img: Union[np.ndarray, torch.Tensor], channels_last: bool) -
if not channels_last:
img_numpy = np.moveaxis(img_numpy, 0, -1)
- return img_numpy
+ return _validate_image_value_range(img_numpy)
+
+
+def _validate_image_value_range(img: np.ndarray) -> np.ndarray:
+ array_min = img.min()
+ array_max = img.max()
+
+ if (array_min >= 0 and 1 < array_max <= 255) or (array_min >= 0 and array_max <= 1):
+ return img
+
+ from neptune.common.warnings import NeptuneWarning, warn_once
+
+ warn_once(
+ 'Image value range is not in the expected range of [0.0, 1.0] or [0, 255]. '
+ 'This might be due to the presence of `transforms.Normalize` in the data pipeline. '
+ 'Logged images may not display correctly in Neptune.',
+ exception=NeptuneWarning,
+ )
+
+ return _scale_image_to_0_255(img, array_min, array_max)
+
+
+def _scale_image_to_0_255(img: np.ndarray, array_min: Union[int, float], array_max: Union[int, float]) -> np.ndarray:
+ scaled_image = 255 * (img - array_min) / (array_max - array_min)
+ return scaled_image.astype(np.uint8)
+
+
+def _warn_about_deprecated_upload_artifacts() -> None:
+ warnings.warn(
+ VersionedDeprecationWarning(
+ 'The \'upload_artifacts\' parameter is deprecated and will be removed in the next version. '
+ 'Use the \'upload_checkpoints\' parameter instead.',
+ remove_version='0.23',
+ ),
+ )
+
+
+def _is_progress_bar_enabled() -> bool:
+ return Version(version('neptune')) >= NEPTUNE_VERSION_WITH_PROGRESS_BAR
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 3827533e51..bf40bcddde 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -2247,14 +2247,12 @@ def _ensure_metrics_device_and_dtype(
return metrics
def _compute_and_log_metrics(self, dataloader_label: str, metrics: Dict[str, Metric]):
- """Computes metrics, logs the results, and updates the state with the deep-copied metrics.
+ """Computes metrics, logs the results, and updates the state with the metrics.
Args:
dataloader_label (str): The dataloader label.
metrics (Dict[str, Metric]): The metrics to compute.
"""
- metrics = deepcopy(metrics)
-
# log computed metrics
computed_metrics = {}
for metric_name, metric in metrics.items():
diff --git a/docs/source/doctest_fixtures.py b/docs/source/doctest_fixtures.py
index 86378a9f74..553d8d9b60 100644
--- a/docs/source/doctest_fixtures.py
+++ b/docs/source/doctest_fixtures.py
@@ -101,6 +101,9 @@
# Disable wandb
os.environ['WANDB_MODE'] = 'disabled'
+# Disable neptune
+os.environ['NEPTUNE_MODE'] = 'debug'
+
# Change the cwd to be the tempfile, so we don't pollute the documentation source folder
tmpdir = tempfile.mkdtemp()
cwd = os.path.abspath('.')
diff --git a/docs/source/trainer/file_uploading.rst b/docs/source/trainer/file_uploading.rst
index b6224a9bd7..d6ad89d476 100644
--- a/docs/source/trainer/file_uploading.rst
+++ b/docs/source/trainer/file_uploading.rst
@@ -112,7 +112,7 @@ Composer includes three built-in LoggerDestinations to store artifacts:
* The :class:`~composer.logger.neptune_logger.NeptuneLogger` can upload Composer training files
as `Neptune Files `_, which are associated with the corresponding
- Neptune project.
+ Neptune run.
* The :class:`~composer.loggers.remote_uploader_downloader.RemoteUploaderDownloader` can upload Composer training files
to any cloud storage backend or remote filesystem. We include integrations for AWS S3 and SFTP
@@ -161,6 +161,30 @@ Weights & Biases Artifacts
# Train!
trainer.fit()
+Neptune File upload
+^^^^^^^^^^^^^^^^^^^
+
+.. seealso::
+
+ The :class:`~composer.loggers.neptune_logger.NeptuneLogger` API Reference.
+
+.. testcode::
+ :skipif: True
+
+ from composer.loggers import NeptuneLogger
+ from composer import Trainer
+
+ # Configure the Neptune logger
+ logger = NeptuneLogger(
+ upload_checkpoints=True, # enable logging of checkpoint files
+ )
+
+ # Define the trainer
+ trainer = Trainer(..., loggers=logger)
+
+ # Train
+ trainer.fit()
+
S3 Objects
^^^^^^^^^^
diff --git a/docs/source/trainer/logging.rst b/docs/source/trainer/logging.rst
index b44f919016..efb6a1f775 100644
--- a/docs/source/trainer/logging.rst
+++ b/docs/source/trainer/logging.rst
@@ -29,34 +29,40 @@ and also saves them to the file
``log.txt``.
.. testsetup::
- :skipif: not _WANDB_INSTALLED or not _COMETML_INSTALLED
+ :skipif: not _WANDB_INSTALLED or not _COMETML_INSTALLED or not _NEPTUNE_INSTALLED
import os
+ import logging
+ logging.getLogger("neptune").setLevel(logging.CRITICAL)
+
+ os.environ["NEPTUNE_MODE"] = "debug"
os.environ["WANDB_MODE"] = "disabled"
os.environ["COMET_API_KEY"] = ""
os.environ["MLFLOW_TRACKING_URI"] = ""
.. testcode::
- :skipif: not _WANDB_INSTALLED or not _COMETML_INSTALLED
+ :skipif: not _WANDB_INSTALLED or not _COMETML_INSTALLED or not _NEPTUNE_INSTALLED
from composer import Trainer
from composer.loggers import WandBLogger, CometMLLogger, MLFlowLogger, NeptuneLogger, FileLogger
+
wandb_logger = WandBLogger()
cometml_logger = CometMLLogger()
mlflow_logger = MLFlowLogger()
+ neptune_logger = NeptuneLogger()
file_logger = FileLogger(filename="log.txt")
trainer = Trainer(
model=model,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
- loggers=[wandb_logger, cometml_logger, mlflow_logger, file_logger],
+ loggers=[wandb_logger, cometml_logger, mlflow_logger, neptune_logger, file_logger],
)
.. testcleanup::
- :skipif: not _WANDB_INSTALLED or not _COMETML_INSTALLED
+ :skipif: not _WANDB_INSTALLED or not _COMETML_INSTALLED or not _NEPTUNE_INSTALLED
trainer.engine.close()
os.remove("log.txt")
diff --git a/tests/loggers/test_mlflow_logger.py b/tests/loggers/test_mlflow_logger.py
index 22b0445e2d..b97c108633 100644
--- a/tests/loggers/test_mlflow_logger.py
+++ b/tests/loggers/test_mlflow_logger.py
@@ -6,7 +6,7 @@
import os
import time
from pathlib import Path
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
import numpy as np
import pytest
@@ -190,24 +190,52 @@ def test_mlflow_experiment_init_existing_composer_run(monkeypatch):
assert test_logger._run_id == existing_id
-def test_mlflow_experiment_init_existing_composer_run_with_old_tag(monkeypatch):
- """ Test that an existing MLFlow run is used if one exists with the old `composer_run_name` tag.
- """
+@pytest.fixture
+def mock_mlflow_client():
+ with patch('mlflow.tracking.MlflowClient') as MockClient:
+ mock_create_run = MagicMock(return_value=MagicMock(info=MagicMock(run_id='mock-run-id')))
+ MockClient.return_value.create_run = mock_create_run
+ yield MockClient
+
+
+def test_mlflow_logger_uses_env_var_run_name(monkeypatch, mock_mlflow_client):
+ """Test that MLFlowLogger uses the 'RUN_NAME' environment variable if set."""
mlflow = pytest.importorskip('mlflow')
monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock())
monkeypatch.setattr(mlflow, 'start_run', MagicMock())
+ from composer.loggers.mlflow_logger import MLFlowLogger
+ mock_state = MagicMock()
+ mock_state.run_name = 'dummy-run-name'
+ monkeypatch.setenv('RUN_NAME', 'env-run-name')
+
+ test_logger = MLFlowLogger()
+ test_logger.init(state=mock_state, logger=MagicMock())
+
+ assert test_logger.tags is not None
+ assert test_logger.tags['run_name'] == 'env-run-name'
+ monkeypatch.delenv('RUN_NAME')
+
+
+def test_mlflow_logger_uses_state_run_name_if_no_env_var_set(monkeypatch, mock_mlflow_client):
+ """Test that MLFlowLogger uses the state's run name if no 'RUN_NAME' environment variable is set."""
+ mlflow = pytest.importorskip('mlflow')
+
+ monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock())
+ monkeypatch.setattr(mlflow, 'start_run', MagicMock())
mock_state = MagicMock()
- mock_state.composer_run_name = 'dummy-run-name'
+ mock_state.run_name = 'state-run-name'
existing_id = 'dummy-id'
mock_search_runs = MagicMock(return_value=[MagicMock(info=MagicMock(run_id=existing_id))])
monkeypatch.setattr(mlflow, 'search_runs', mock_search_runs)
+ from composer.loggers.mlflow_logger import MLFlowLogger
test_logger = MLFlowLogger()
test_logger.init(state=mock_state, logger=MagicMock())
- assert test_logger._run_id == existing_id
+ assert test_logger.tags is not None
+ assert test_logger.tags['run_name'] == 'state-run-name'
def test_mlflow_experiment_set_up(tmp_path):
diff --git a/tests/loggers/test_neptune_logger.py b/tests/loggers/test_neptune_logger.py
index d6ab2fcdf4..2d6fc150f9 100644
--- a/tests/loggers/test_neptune_logger.py
+++ b/tests/loggers/test_neptune_logger.py
@@ -1,11 +1,13 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
+import contextlib
import os
import uuid
from pathlib import Path
-from typing import Sequence
+from typing import Generator, Sequence
from unittest.mock import MagicMock, patch
+import numpy as np
import pytest
import torch
from torch.utils.data import DataLoader
@@ -28,7 +30,7 @@ def test_neptune_logger() -> NeptuneLogger:
api_token=neptune_api_token,
rank_zero_only=False,
mode='debug',
- upload_artifacts=True,
+ upload_checkpoints=True,
)
return neptune_logger
@@ -153,3 +155,51 @@ def test_neptune_log_image(test_neptune_logger):
test_neptune_logger.post_close()
assert mock_extend.call_count == 2 * len(image_variants) # One set of torch tensors, one set of numpy arrays
+
+
+def test_neptune_logger_doesnt_upload_symlinks(test_neptune_logger, dummy_state):
+ with _manage_symlink_creation('test.txt') as symlink_name:
+ test_neptune_logger.upload_file(
+ state=dummy_state,
+ remote_file_name='test_symlink',
+ file_path=Path(symlink_name),
+ )
+ assert not test_neptune_logger.neptune_run.exists(f'{test_neptune_logger._base_namespace}/test_symlink')
+
+
+@contextlib.contextmanager
+def _manage_symlink_creation(file_name: str) -> Generator[str, None, None]:
+ with open(file_name, 'w') as f:
+ f.write('This is a test file.')
+
+ symlink_name = 'test_symlink.txt'
+
+ os.symlink(file_name, symlink_name)
+
+ assert Path(symlink_name).is_symlink()
+
+ yield symlink_name
+
+ os.remove(symlink_name)
+ os.remove(file_name)
+
+
+def test_neptune_log_image_warns_about_improper_value_range(test_neptune_logger):
+ image = np.ones((4, 4)) * 300
+ with pytest.warns() as record:
+ test_neptune_logger.log_images(images=image)
+ assert 'Image value range is not in the expected range of [0.0, 1.0] or [0, 255].' in str(record[0].message)
+
+
+@patch('composer.loggers.neptune_logger._scale_image_to_0_255', return_value=np.ones((4, 4)))
+def test_neptune_log_image_scales_improper_image(mock_scale_img, test_neptune_logger):
+ image_variants = [
+ np.ones((4, 4)) * 300,
+ np.ones((4, 4)) * -1,
+ np.identity(4) * 300 - 1,
+ ]
+
+ for image in image_variants:
+ test_neptune_logger.log_images(images=image)
+ mock_scale_img.assert_called_once()
+ mock_scale_img.reset_mock()