Skip to content

Commit

Permalink
Fix MLFlow Tag Name for Resumption (#3194)
Browse files Browse the repository at this point in the history
* quick patch

* pytest

* rm outdated test

* pytest fix

* pytest fix

* pytest all green

* patch

* cleanup

* more mocks

* ling :(

* code quality

* isort

* yapf

* clean
  • Loading branch information
KuuCi authored Apr 16, 2024
1 parent 6f84caa commit b44c083
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 17 deletions.
12 changes: 1 addition & 11 deletions composer/loggers/mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def init(self, state: State, logger: Logger) -> None:

# Store the Composer run name in the MLFlow run tags so it can be retrieved for autoresume.
self.tags = self.tags or {}
self.tags['run_name'] = state.run_name
self.tags['run_name'] = os.environ.get('RUN_NAME', state.run_name)

# Adjust name and group based on `rank_zero_only`.
if not self._rank_zero_only:
Expand All @@ -171,16 +171,6 @@ def init(self, state: State, logger: Logger) -> None:
output_format='list',
)

# Check for the old tag (`composer_run_name`) For backwards compatibility in case a run using the old
# tag fails and the run is resumed with a newer version of Composer that uses `run_name` instead of
# `composer_run_name`.
if len(existing_runs) == 0:
existing_runs = mlflow.search_runs(
experiment_ids=[self._experiment_id],
filter_string=f'tags.composer_run_name = "{state.run_name}"',
output_format='list',
)

if len(existing_runs) > 0:
self._run_id = existing_runs[0].info.run_id
else:
Expand Down
40 changes: 34 additions & 6 deletions tests/loggers/test_mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import time
from pathlib import Path
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
Expand Down Expand Up @@ -190,24 +190,52 @@ def test_mlflow_experiment_init_existing_composer_run(monkeypatch):
assert test_logger._run_id == existing_id


def test_mlflow_experiment_init_existing_composer_run_with_old_tag(monkeypatch):
""" Test that an existing MLFlow run is used if one exists with the old `composer_run_name` tag.
"""
@pytest.fixture
def mock_mlflow_client():
with patch('mlflow.tracking.MlflowClient') as MockClient:
mock_create_run = MagicMock(return_value=MagicMock(info=MagicMock(run_id='mock-run-id')))
MockClient.return_value.create_run = mock_create_run
yield MockClient


def test_mlflow_logger_uses_env_var_run_name(monkeypatch, mock_mlflow_client):
"""Test that MLFlowLogger uses the 'RUN_NAME' environment variable if set."""
mlflow = pytest.importorskip('mlflow')

monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock())
monkeypatch.setattr(mlflow, 'start_run', MagicMock())

from composer.loggers.mlflow_logger import MLFlowLogger
mock_state = MagicMock()
mock_state.run_name = 'dummy-run-name'
monkeypatch.setenv('RUN_NAME', 'env-run-name')

test_logger = MLFlowLogger()
test_logger.init(state=mock_state, logger=MagicMock())

assert test_logger.tags is not None
assert test_logger.tags['run_name'] == 'env-run-name'
monkeypatch.delenv('RUN_NAME')


def test_mlflow_logger_uses_state_run_name_if_no_env_var_set(monkeypatch, mock_mlflow_client):
"""Test that MLFlowLogger uses the state's run name if no 'RUN_NAME' environment variable is set."""
mlflow = pytest.importorskip('mlflow')

monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock())
monkeypatch.setattr(mlflow, 'start_run', MagicMock())
mock_state = MagicMock()
mock_state.composer_run_name = 'dummy-run-name'
mock_state.run_name = 'state-run-name'

existing_id = 'dummy-id'
mock_search_runs = MagicMock(return_value=[MagicMock(info=MagicMock(run_id=existing_id))])
monkeypatch.setattr(mlflow, 'search_runs', mock_search_runs)

from composer.loggers.mlflow_logger import MLFlowLogger
test_logger = MLFlowLogger()
test_logger.init(state=mock_state, logger=MagicMock())
assert test_logger._run_id == existing_id
assert test_logger.tags is not None
assert test_logger.tags['run_name'] == 'state-run-name'


def test_mlflow_experiment_set_up(tmp_path):
Expand Down

0 comments on commit b44c083

Please sign in to comment.