diff --git a/composer/_version.py b/composer/_version.py index c95d48d28a..c78074d7f4 100644 --- a/composer/_version.py +++ b/composer/_version.py @@ -3,4 +3,4 @@ """The Composer Version.""" -__version__ = '0.21.2' +__version__ = '0.22.0.dev0' diff --git a/composer/callbacks/oom_observer.py b/composer/callbacks/oom_observer.py index d9250c37e4..7d6292e079 100644 --- a/composer/callbacks/oom_observer.py +++ b/composer/callbacks/oom_observer.py @@ -2,17 +2,20 @@ # SPDX-License-Identifier: Apache-2.0 """Generate a memory snapshot during an OutOfMemory exception.""" +from __future__ import annotations +import dataclasses import logging import os import pickle import warnings -from typing import Optional +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional import torch.cuda from packaging import version -from composer import State from composer.core import Callback, State from composer.loggers import Logger from composer.utils import ensure_folder_is_empty, format_name_with_dist, format_name_with_dist_and_time, parse_uri @@ -22,6 +25,29 @@ __all__ = ['OOMObserver'] +@dataclass(frozen=True) +class SnapshotFileNameConfig: + """Configuration for the file names of the memory snapshot visualizations.""" + snapshot_file: str + trace_plot_file: str + segment_plot_file: str + segment_flamegraph_file: str + memory_flamegraph_file: str + + @classmethod + def from_file_name(cls, filename: str) -> 'SnapshotFileNameConfig': + return cls( + snapshot_file=filename + '_snapshot.pickle', + trace_plot_file=filename + '_trace_plot.html', + segment_plot_file=filename + '_segment_plot.html', + segment_flamegraph_file=filename + '_segment_flamegraph.svg', + memory_flamegraph_file=filename + '_memory_flamegraph.svg', + ) + + def list_filenames(self) -> List[str]: + return [getattr(self, field.name) for field in dataclasses.fields(self)] + + class OOMObserver(Callback): """Generate visualizations of the state of allocated memory during an OutOfMemory exception. @@ -94,6 +120,8 @@ def __init__( self._enabled = False warnings.warn('OOMObserver is supported after PyTorch 2.1.0. Disabling OOMObserver callback.') + self.filename_config: Optional[SnapshotFileNameConfig] = None + def init(self, state: State, logger: Logger) -> None: if not self._enabled: return @@ -117,17 +145,12 @@ def oom_observer(device: int, alloc: int, device_alloc: int, device_free: int): assert self.filename assert self.folder_name, 'folder_name must be set in init' - filename = os.path.join( - self.folder_name, + filename = Path(self.folder_name) / Path( format_name_with_dist_and_time(self.filename, run_name=state.run_name, timestamp=state.timestamp), ) try: - snapshot_file = filename + '_snapshot.pickle' - trace_plot_file = filename + '_trace_plot.html' - segment_plot_file = filename + '_segment_plot.html' - segment_flamegraph_file = filename + '_segment_flamegraph.svg' - memory_flamegraph_file = filename + '_memory_flamegraph.svg' + self.filename_config = SnapshotFileNameConfig.from_file_name(str(filename)) log.info(f'Dumping OOMObserver visualizations') snapshot = torch.cuda.memory._snapshot() @@ -136,31 +159,25 @@ def oom_observer(device: int, alloc: int, device_alloc: int, device_free: int): log.info(f'No allocation is recorded in memory snapshot)') return - with open(snapshot_file, 'wb') as fd: + with open(self.filename_config.snapshot_file, 'wb') as fd: pickle.dump(snapshot, fd) - with open(trace_plot_file, 'w+') as fd: + with open(self.filename_config.trace_plot_file, 'w+') as fd: fd.write(torch.cuda._memory_viz.trace_plot(snapshot)) # type: ignore - with open(segment_plot_file, 'w+') as fd: + with open(self.filename_config.segment_plot_file, 'w+') as fd: fd.write(torch.cuda._memory_viz.segment_plot(snapshot)) # type: ignore - with open(segment_flamegraph_file, 'w+') as fd: + with open(self.filename_config.segment_flamegraph_file, 'w+') as fd: fd.write(torch.cuda._memory_viz.segments(snapshot)) # type: ignore - with open(memory_flamegraph_file, 'w+') as fd: + with open(self.filename_config.memory_flamegraph_file, 'w+') as fd: fd.write(torch.cuda._memory_viz.memory(snapshot)) # type: ignore log.info(f'Saved memory visualizations to local files with prefix = {filename} during OOM') if self.remote_path_in_bucket is not None: - for f in [ - snapshot_file, - trace_plot_file, - segment_plot_file, - segment_flamegraph_file, - memory_flamegraph_file, - ]: + for f in self.filename_config.list_filenames(): base_file_name = os.path.basename(f) remote_file_name = os.path.join(self.remote_path_in_bucket, base_file_name) remote_file_name = remote_file_name.lstrip('/') # remove leading slashes diff --git a/composer/loggers/mlflow_logger.py b/composer/loggers/mlflow_logger.py index 9ed6415dce..dcbed33d96 100644 --- a/composer/loggers/mlflow_logger.py +++ b/composer/loggers/mlflow_logger.py @@ -148,7 +148,7 @@ def init(self, state: State, logger: Logger) -> None: # Store the Composer run name in the MLFlow run tags so it can be retrieved for autoresume. self.tags = self.tags or {} - self.tags['run_name'] = state.run_name + self.tags['run_name'] = os.environ.get('RUN_NAME', state.run_name) # Adjust name and group based on `rank_zero_only`. if not self._rank_zero_only: @@ -171,16 +171,6 @@ def init(self, state: State, logger: Logger) -> None: output_format='list', ) - # Check for the old tag (`composer_run_name`) For backwards compatibility in case a run using the old - # tag fails and the run is resumed with a newer version of Composer that uses `run_name` instead of - # `composer_run_name`. - if len(existing_runs) == 0: - existing_runs = mlflow.search_runs( - experiment_ids=[self._experiment_id], - filter_string=f'tags.composer_run_name = "{state.run_name}"', - output_format='list', - ) - if len(existing_runs) > 0: self._run_id = existing_runs[0].info.run_id else: diff --git a/composer/loggers/neptune_logger.py b/composer/loggers/neptune_logger.py index bcabc7999d..abb354a9a9 100644 --- a/composer/loggers/neptune_logger.py +++ b/composer/loggers/neptune_logger.py @@ -9,24 +9,30 @@ import pathlib import warnings from functools import partial -from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Set, Union +from importlib.metadata import version +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Sequence, Set, Union import numpy as np import torch +from packaging.version import Version from composer._version import __version__ from composer.loggers import LoggerDestination -from composer.utils import MissingConditionalImportError, dist +from composer.utils import MissingConditionalImportError, VersionedDeprecationWarning, dist if TYPE_CHECKING: from composer import Logger from composer.core import State +NEPTUNE_MODE_TYPE = Literal['async', 'sync', 'offline', 'read-only', 'debug'] +NEPTUNE_VERSION_WITH_PROGRESS_BAR = Version('1.9.0') + class NeptuneLogger(LoggerDestination): """Log to `neptune.ai `_. - For more, see the [Neptune-Composer integration guide](https://docs.neptune.ai/integrations/composer/). + For instructions, see the + `integration guide `_. Args: project (str, optional): The name of your Neptune project, @@ -36,16 +42,15 @@ class NeptuneLogger(LoggerDestination): You can leave out this argument if you save your token to the ``NEPTUNE_API_TOKEN`` environment variable (recommended). You can find your API token in the user menu of the Neptune web app. - rank_zero_only (bool, optional): Whether to log only on the rank-zero process. - (default: ``True``). - upload_artifacts (bool, optional): Whether the logger should upload artifacts to Neptune. + rank_zero_only (bool): Whether to log only on the rank-zero process (default: ``True``). + upload_artifacts (bool, optional): Deprecated. See ``upload_checkpoints``. + upload_checkpoints (bool): Whether the logger should upload checkpoints to Neptune (default: ``False``). - base_namespace (str, optional): The name of the base namespace to log the metadata to. - (default: "training"). + base_namespace (str, optional): The name of the base namespace where the metadata + is logged (default: "training"). neptune_kwargs (Dict[str, Any], optional): Any additional keyword arguments to the ``neptune.init_run()`` function. For options, see the - `Run API reference `_ in the - Neptune docs. + `Run API reference `_. """ metric_namespace = 'metrics' hyperparam_namespace = 'hyperparameters' @@ -58,8 +63,10 @@ def __init__( project: Optional[str] = None, api_token: Optional[str] = None, rank_zero_only: bool = True, - upload_artifacts: bool = False, + upload_artifacts: Optional[bool] = None, + upload_checkpoints: bool = False, base_namespace: str = 'training', + mode: Optional[NEPTUNE_MODE_TYPE] = None, **neptune_kwargs, ) -> None: try: @@ -74,7 +81,8 @@ def __init__( verify_type('project', project, (str, type(None))) verify_type('api_token', api_token, (str, type(None))) verify_type('rank_zero_only', rank_zero_only, bool) - verify_type('upload_artifacts', upload_artifacts, bool) + verify_type('upload_artifacts', upload_artifacts, (bool, type(None))) + verify_type('upload_checkpoints', upload_checkpoints, bool) verify_type('base_namespace', base_namespace, str) if not base_namespace: @@ -83,15 +91,19 @@ def __init__( self._project = project self._api_token = api_token self._rank_zero_only = rank_zero_only - self._upload_artifacts = upload_artifacts + + if upload_artifacts is not None: + _warn_about_deprecated_upload_artifacts() + self._upload_checkpoints = upload_artifacts + else: + self._upload_checkpoints = upload_checkpoints + self._base_namespace = base_namespace self._neptune_kwargs = neptune_kwargs - mode = self._neptune_kwargs.pop('mode', 'async') - self._enabled = (not rank_zero_only) or dist.get_global_rank() == 0 - self._mode = mode if self._enabled else 'debug' + self._mode: Optional[NEPTUNE_MODE_TYPE] = mode if self._enabled else 'debug' self._neptune_run = None self._base_handler = None @@ -104,17 +116,8 @@ def __init__( def neptune_run(self): """Gets the Neptune run object from a NeptuneLogger instance. - You can log additional metadata to the run by accessing a path inside the run and assigning metadata to it - with "=" or [Neptune logging methods](https://docs.neptune.ai/logging/methods/). - - Example: - from composer import Trainer - from composer.loggers import NeptuneLogger - neptune_logger = NeptuneLogger() - trainer = Trainer(loggers=neptune_logger, ...) - trainer.fit() - neptune_logger.neptune_run["some_metric"] = 1 - trainer.close() + To log additional metadata to the run, access a path inside the run and assign metadata + with ``=`` or other `Neptune logging methods `_. """ from neptune import Run @@ -131,19 +134,10 @@ def neptune_run(self): def base_handler(self): """Gets a handler for the base logging namespace. - Use the handler to log extra metadata to the run and organize it under the base namespace (default: "training"). - You can operate on it like a run object: Access a path inside the handler and assign metadata to it with "=" or - other [Neptune logging methods](https://docs.neptune.ai/logging/methods/). - - Example: - from composer import Trainer - from composer.loggers import NeptuneLogger - neptune_logger = NeptuneLogger() - trainer = Trainer(loggers=neptune_logger, ...) - trainer.fit() - neptune_logger.base_handler["some_metric"] = 1 - trainer.close() - Result: The value `1` is organized under "training/some_metric" inside the run. + Use the handler to log extra metadata to the run and organize it under the base namespace + (default: "training"). You can operate on it like a run object: Access a path inside the + handler and assign metadata to it with ``=`` or other + `Neptune logging methods `_. """ return self.neptune_run[self._base_namespace] @@ -213,7 +207,7 @@ def log_traces(self, traces: Dict[str, Any]): def can_upload_files(self) -> bool: """Whether the logger supports uploading files.""" - return self._enabled and self._upload_artifacts + return self._enabled and self._upload_checkpoints def upload_file( self, @@ -226,6 +220,9 @@ def upload_file( if not self.can_upload_files(): return + if file_path.is_symlink() or file_path.suffix.lower() == '.symlink': + return # skip symlinks + neptune_path = f'{self._base_namespace}/{remote_file_name}' if self.neptune_run.exists(neptune_path) and not overwrite: @@ -236,7 +233,11 @@ def upload_file( return del state # unused - self.base_handler[remote_file_name].upload(str(file_path)) + + from neptune.types import File + + with open(str(file_path), 'rb') as fp: + self.base_handler[remote_file_name] = File.from_stream(fp, extension=file_path.suffix) def download_file( self, @@ -245,7 +246,6 @@ def download_file( overwrite: bool = False, progress_bar: bool = True, ): - del progress_bar # not supported if not self._enabled: return @@ -266,7 +266,11 @@ def download_file( if not self.neptune_run.exists(file_path): raise FileNotFoundError(f'File {file_path} not found') - self.base_handler[remote_file_name].download(destination=destination) + if _is_progress_bar_enabled(): + self.base_handler[remote_file_name].download(destination=destination, progress_bar=progress_bar) + else: + del progress_bar + self.base_handler[remote_file_name].download(destination=destination) def log_images( self, @@ -312,4 +316,42 @@ def _validate_image(img: Union[np.ndarray, torch.Tensor], channels_last: bool) - if not channels_last: img_numpy = np.moveaxis(img_numpy, 0, -1) - return img_numpy + return _validate_image_value_range(img_numpy) + + +def _validate_image_value_range(img: np.ndarray) -> np.ndarray: + array_min = img.min() + array_max = img.max() + + if (array_min >= 0 and 1 < array_max <= 255) or (array_min >= 0 and array_max <= 1): + return img + + from neptune.common.warnings import NeptuneWarning, warn_once + + warn_once( + 'Image value range is not in the expected range of [0.0, 1.0] or [0, 255]. ' + 'This might be due to the presence of `transforms.Normalize` in the data pipeline. ' + 'Logged images may not display correctly in Neptune.', + exception=NeptuneWarning, + ) + + return _scale_image_to_0_255(img, array_min, array_max) + + +def _scale_image_to_0_255(img: np.ndarray, array_min: Union[int, float], array_max: Union[int, float]) -> np.ndarray: + scaled_image = 255 * (img - array_min) / (array_max - array_min) + return scaled_image.astype(np.uint8) + + +def _warn_about_deprecated_upload_artifacts() -> None: + warnings.warn( + VersionedDeprecationWarning( + 'The \'upload_artifacts\' parameter is deprecated and will be removed in the next version. ' + 'Use the \'upload_checkpoints\' parameter instead.', + remove_version='0.23', + ), + ) + + +def _is_progress_bar_enabled() -> bool: + return Version(version('neptune')) >= NEPTUNE_VERSION_WITH_PROGRESS_BAR diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 3827533e51..bf40bcddde 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -2247,14 +2247,12 @@ def _ensure_metrics_device_and_dtype( return metrics def _compute_and_log_metrics(self, dataloader_label: str, metrics: Dict[str, Metric]): - """Computes metrics, logs the results, and updates the state with the deep-copied metrics. + """Computes metrics, logs the results, and updates the state with the metrics. Args: dataloader_label (str): The dataloader label. metrics (Dict[str, Metric]): The metrics to compute. """ - metrics = deepcopy(metrics) - # log computed metrics computed_metrics = {} for metric_name, metric in metrics.items(): diff --git a/docs/source/doctest_fixtures.py b/docs/source/doctest_fixtures.py index 86378a9f74..553d8d9b60 100644 --- a/docs/source/doctest_fixtures.py +++ b/docs/source/doctest_fixtures.py @@ -101,6 +101,9 @@ # Disable wandb os.environ['WANDB_MODE'] = 'disabled' +# Disable neptune +os.environ['NEPTUNE_MODE'] = 'debug' + # Change the cwd to be the tempfile, so we don't pollute the documentation source folder tmpdir = tempfile.mkdtemp() cwd = os.path.abspath('.') diff --git a/docs/source/trainer/file_uploading.rst b/docs/source/trainer/file_uploading.rst index b6224a9bd7..d6ad89d476 100644 --- a/docs/source/trainer/file_uploading.rst +++ b/docs/source/trainer/file_uploading.rst @@ -112,7 +112,7 @@ Composer includes three built-in LoggerDestinations to store artifacts: * The :class:`~composer.logger.neptune_logger.NeptuneLogger` can upload Composer training files as `Neptune Files `_, which are associated with the corresponding - Neptune project. + Neptune run. * The :class:`~composer.loggers.remote_uploader_downloader.RemoteUploaderDownloader` can upload Composer training files to any cloud storage backend or remote filesystem. We include integrations for AWS S3 and SFTP @@ -161,6 +161,30 @@ Weights & Biases Artifacts # Train! trainer.fit() +Neptune File upload +^^^^^^^^^^^^^^^^^^^ + +.. seealso:: + + The :class:`~composer.loggers.neptune_logger.NeptuneLogger` API Reference. + +.. testcode:: + :skipif: True + + from composer.loggers import NeptuneLogger + from composer import Trainer + + # Configure the Neptune logger + logger = NeptuneLogger( + upload_checkpoints=True, # enable logging of checkpoint files + ) + + # Define the trainer + trainer = Trainer(..., loggers=logger) + + # Train + trainer.fit() + S3 Objects ^^^^^^^^^^ diff --git a/docs/source/trainer/logging.rst b/docs/source/trainer/logging.rst index b44f919016..efb6a1f775 100644 --- a/docs/source/trainer/logging.rst +++ b/docs/source/trainer/logging.rst @@ -29,34 +29,40 @@ and also saves them to the file ``log.txt``. .. testsetup:: - :skipif: not _WANDB_INSTALLED or not _COMETML_INSTALLED + :skipif: not _WANDB_INSTALLED or not _COMETML_INSTALLED or not _NEPTUNE_INSTALLED import os + import logging + logging.getLogger("neptune").setLevel(logging.CRITICAL) + + os.environ["NEPTUNE_MODE"] = "debug" os.environ["WANDB_MODE"] = "disabled" os.environ["COMET_API_KEY"] = "" os.environ["MLFLOW_TRACKING_URI"] = "" .. testcode:: - :skipif: not _WANDB_INSTALLED or not _COMETML_INSTALLED + :skipif: not _WANDB_INSTALLED or not _COMETML_INSTALLED or not _NEPTUNE_INSTALLED from composer import Trainer from composer.loggers import WandBLogger, CometMLLogger, MLFlowLogger, NeptuneLogger, FileLogger + wandb_logger = WandBLogger() cometml_logger = CometMLLogger() mlflow_logger = MLFlowLogger() + neptune_logger = NeptuneLogger() file_logger = FileLogger(filename="log.txt") trainer = Trainer( model=model, train_dataloader=train_dataloader, eval_dataloader=eval_dataloader, - loggers=[wandb_logger, cometml_logger, mlflow_logger, file_logger], + loggers=[wandb_logger, cometml_logger, mlflow_logger, neptune_logger, file_logger], ) .. testcleanup:: - :skipif: not _WANDB_INSTALLED or not _COMETML_INSTALLED + :skipif: not _WANDB_INSTALLED or not _COMETML_INSTALLED or not _NEPTUNE_INSTALLED trainer.engine.close() os.remove("log.txt") diff --git a/tests/loggers/test_mlflow_logger.py b/tests/loggers/test_mlflow_logger.py index 22b0445e2d..b97c108633 100644 --- a/tests/loggers/test_mlflow_logger.py +++ b/tests/loggers/test_mlflow_logger.py @@ -6,7 +6,7 @@ import os import time from pathlib import Path -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import numpy as np import pytest @@ -190,24 +190,52 @@ def test_mlflow_experiment_init_existing_composer_run(monkeypatch): assert test_logger._run_id == existing_id -def test_mlflow_experiment_init_existing_composer_run_with_old_tag(monkeypatch): - """ Test that an existing MLFlow run is used if one exists with the old `composer_run_name` tag. - """ +@pytest.fixture +def mock_mlflow_client(): + with patch('mlflow.tracking.MlflowClient') as MockClient: + mock_create_run = MagicMock(return_value=MagicMock(info=MagicMock(run_id='mock-run-id'))) + MockClient.return_value.create_run = mock_create_run + yield MockClient + + +def test_mlflow_logger_uses_env_var_run_name(monkeypatch, mock_mlflow_client): + """Test that MLFlowLogger uses the 'RUN_NAME' environment variable if set.""" mlflow = pytest.importorskip('mlflow') monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock()) monkeypatch.setattr(mlflow, 'start_run', MagicMock()) + from composer.loggers.mlflow_logger import MLFlowLogger + mock_state = MagicMock() + mock_state.run_name = 'dummy-run-name' + monkeypatch.setenv('RUN_NAME', 'env-run-name') + + test_logger = MLFlowLogger() + test_logger.init(state=mock_state, logger=MagicMock()) + + assert test_logger.tags is not None + assert test_logger.tags['run_name'] == 'env-run-name' + monkeypatch.delenv('RUN_NAME') + + +def test_mlflow_logger_uses_state_run_name_if_no_env_var_set(monkeypatch, mock_mlflow_client): + """Test that MLFlowLogger uses the state's run name if no 'RUN_NAME' environment variable is set.""" + mlflow = pytest.importorskip('mlflow') + + monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock()) + monkeypatch.setattr(mlflow, 'start_run', MagicMock()) mock_state = MagicMock() - mock_state.composer_run_name = 'dummy-run-name' + mock_state.run_name = 'state-run-name' existing_id = 'dummy-id' mock_search_runs = MagicMock(return_value=[MagicMock(info=MagicMock(run_id=existing_id))]) monkeypatch.setattr(mlflow, 'search_runs', mock_search_runs) + from composer.loggers.mlflow_logger import MLFlowLogger test_logger = MLFlowLogger() test_logger.init(state=mock_state, logger=MagicMock()) - assert test_logger._run_id == existing_id + assert test_logger.tags is not None + assert test_logger.tags['run_name'] == 'state-run-name' def test_mlflow_experiment_set_up(tmp_path): diff --git a/tests/loggers/test_neptune_logger.py b/tests/loggers/test_neptune_logger.py index d6ab2fcdf4..2d6fc150f9 100644 --- a/tests/loggers/test_neptune_logger.py +++ b/tests/loggers/test_neptune_logger.py @@ -1,11 +1,13 @@ # Copyright 2022 MosaicML Composer authors # SPDX-License-Identifier: Apache-2.0 +import contextlib import os import uuid from pathlib import Path -from typing import Sequence +from typing import Generator, Sequence from unittest.mock import MagicMock, patch +import numpy as np import pytest import torch from torch.utils.data import DataLoader @@ -28,7 +30,7 @@ def test_neptune_logger() -> NeptuneLogger: api_token=neptune_api_token, rank_zero_only=False, mode='debug', - upload_artifacts=True, + upload_checkpoints=True, ) return neptune_logger @@ -153,3 +155,51 @@ def test_neptune_log_image(test_neptune_logger): test_neptune_logger.post_close() assert mock_extend.call_count == 2 * len(image_variants) # One set of torch tensors, one set of numpy arrays + + +def test_neptune_logger_doesnt_upload_symlinks(test_neptune_logger, dummy_state): + with _manage_symlink_creation('test.txt') as symlink_name: + test_neptune_logger.upload_file( + state=dummy_state, + remote_file_name='test_symlink', + file_path=Path(symlink_name), + ) + assert not test_neptune_logger.neptune_run.exists(f'{test_neptune_logger._base_namespace}/test_symlink') + + +@contextlib.contextmanager +def _manage_symlink_creation(file_name: str) -> Generator[str, None, None]: + with open(file_name, 'w') as f: + f.write('This is a test file.') + + symlink_name = 'test_symlink.txt' + + os.symlink(file_name, symlink_name) + + assert Path(symlink_name).is_symlink() + + yield symlink_name + + os.remove(symlink_name) + os.remove(file_name) + + +def test_neptune_log_image_warns_about_improper_value_range(test_neptune_logger): + image = np.ones((4, 4)) * 300 + with pytest.warns() as record: + test_neptune_logger.log_images(images=image) + assert 'Image value range is not in the expected range of [0.0, 1.0] or [0, 255].' in str(record[0].message) + + +@patch('composer.loggers.neptune_logger._scale_image_to_0_255', return_value=np.ones((4, 4))) +def test_neptune_log_image_scales_improper_image(mock_scale_img, test_neptune_logger): + image_variants = [ + np.ones((4, 4)) * 300, + np.ones((4, 4)) * -1, + np.identity(4) * 300 - 1, + ] + + for image in image_variants: + test_neptune_logger.log_images(images=image) + mock_scale_img.assert_called_once() + mock_scale_img.reset_mock()