Skip to content

Commit

Permalink
more mocks
Browse files Browse the repository at this point in the history
  • Loading branch information
KuuCi committed Apr 16, 2024
1 parent eb9896d commit a96840f
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions tests/loggers/test_mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,43 +190,47 @@ def test_mlflow_experiment_init_existing_composer_run(monkeypatch):
assert test_logger._run_id == existing_id


def test_mlflow_logger_uses_env_var_run_name(monkeypatch):
@pytest.fixture
def mock_mlflow_client():
with patch('mlflow.tracking.MlflowClient') as MockClient:
mock_create_run = MagicMock(return_value=MagicMock(info=MagicMock(run_id='mock-run-id')))
MockClient.return_value.create_run = mock_create_run
yield MockClient


def test_mlflow_logger_uses_env_var_run_name(monkeypatch, mock_mlflow_client):
"""Test that MLFlowLogger uses the 'RUN_NAME' environment variable if set."""
mlflow = pytest.importorskip('mlflow')

monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock())
monkeypatch.setattr(mlflow, 'start_run', MagicMock())
mock_create_run = MagicMock(return_value=MagicMock(info=MagicMock(run_id='mock-run-id')))

with patch('mlflow.tracking.MlflowClient') as MockClient:
MockClient.return_value.create_run = mock_create_run

from composer.loggers.mlflow_logger import MLFlowLogger
mock_state = MagicMock()
mock_state.run_name = 'dummy-run-name'
monkeypatch.setenv('RUN_NAME', 'env-run-name')
from composer.loggers.mlflow_logger import MLFlowLogger
mock_state = MagicMock()
mock_state.run_name = 'dummy-run-name'
monkeypatch.setenv('RUN_NAME', 'env-run-name')

test_logger = MLFlowLogger()
test_logger.init(state=mock_state, logger=MagicMock())
test_logger = MLFlowLogger()
test_logger.init(state=mock_state, logger=MagicMock())

assert test_logger.tags['run_name'] == 'env-run-name', "Logger should use the run name from the environment variable."
monkeypatch.delenv('RUN_NAME')
assert test_logger.tags['run_name'] == 'env-run-name', "Logger should use the run name from the environment variable."
monkeypatch.delenv('RUN_NAME')


def test_mlflow_logger_uses_state_run_name_if_no_env_var_set(monkeypatch):
def test_mlflow_logger_uses_state_run_name_if_no_env_var_set(monkeypatch, mock_mlflow_client):
"""Test that MLFlowLogger uses the state's run name if no 'RUN_NAME' environment variable is set."""
mlflow = pytest.importorskip('mlflow')

monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock())
monkeypatch.setattr(mlflow, 'start_run', MagicMock())

mock_state = MagicMock()
mock_state.run_name = 'state-run-name'

existing_id = 'dummy-id'
mock_search_runs = MagicMock(return_value=[MagicMock(info=MagicMock(run_id=existing_id))])
monkeypatch.setattr(mlflow, 'search_runs', mock_search_runs)

from composer.loggers.mlflow_logger import MLFlowLogger
test_logger = MLFlowLogger()
test_logger.init(state=mock_state, logger=MagicMock())
assert test_logger.tags['run_name'] == 'state-run-name'
Expand Down

0 comments on commit a96840f

Please sign in to comment.