Skip to content

Commit

Permalink
patch
Browse files Browse the repository at this point in the history
  • Loading branch information
KuuCi committed Apr 16, 2024
1 parent 0f77d5e commit a71bdd2
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions tests/loggers/test_mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import time
from pathlib import Path
from unittest.mock import MagicMock
from unittest.mock import patch, MagicMock

import numpy as np
import pytest
Expand Down Expand Up @@ -192,28 +192,45 @@ def test_mlflow_experiment_init_existing_composer_run(monkeypatch):

def test_mlflow_logger_uses_env_var_run_name(monkeypatch):
"""Test that MLFlowLogger uses the 'RUN_NAME' environment variable if set."""
import mlflow
mlflow = pytest.importorskip('mlflow')
MlflowClient = mlflow.tracking.MlflowClient # Import locally after ensuring mlflow is available

# Set up mocks
monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock())
monkeypatch.setattr(mlflow, 'start_run', MagicMock())
mock_create_run = MagicMock(return_value=MagicMock(info=MagicMock(run_id='mock-run-id')))

# Use 'patch' to mock MlflowClient inside the test function where it's used
with patch('mlflow.tracking.MlflowClient') as MockClient:
MockClient.return_value.create_run = mock_create_run

mock_state = MagicMock()
mock_state.run_name = 'dummy-run-name'
monkeypatch.setenv('RUN_NAME', 'env-run-name')
from composer.loggers.mlflow_logger import MLFlowLogger # Import the logger here after the patch
mock_state = MagicMock()
mock_state.run_name = 'dummy-run-name'
monkeypatch.setenv('RUN_NAME', 'env-run-name')

test_logger = MLFlowLogger()
test_logger.init(state=mock_state, logger=MagicMock())

# Assertion to check the correct environment variable is used
assert test_logger.tags['run_name'] == 'env-run-name', "Logger should use the run name from the environment variable."
monkeypatch.delenv('RUN_NAME')

test_logger = MLFlowLogger()
test_logger.init(state=mock_state, logger=MagicMock())
assert test_logger.tags['run_name'] == 'env-run-name'
monkeypatch.delenv('RUN_NAME')

def test_mlflow_logger_uses_state_run_name_if_no_env_var_set(monkeypatch):
"""Test that MLFlowLogger uses the state's run name if no 'RUN_NAME' environment variable is set."""
import mlflow
mlflow = pytest.importorskip('mlflow')

monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock())
monkeypatch.setattr(mlflow, 'start_run', MagicMock())

mock_state = MagicMock()
mock_state.run_name = 'state-run-name'

existing_id = 'dummy-id'
mock_search_runs = MagicMock(return_value=[MagicMock(info=MagicMock(run_id=existing_id))])
monkeypatch.setattr(mlflow, 'search_runs', mock_search_runs)

test_logger = MLFlowLogger()
test_logger.init(state=mock_state, logger=MagicMock())
assert test_logger.tags['run_name'] == 'state-run-name'
Expand Down

0 comments on commit a71bdd2

Please sign in to comment.