From a96840f2a1200cf0cf977ddc2b949682036d9c9c Mon Sep 17 00:00:00 2001 From: Vincent Chen Date: Tue, 16 Apr 2024 11:09:16 -0700 Subject: [PATCH] more mocks --- tests/loggers/test_mlflow_logger.py | 34 ++++++++++++++++------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/tests/loggers/test_mlflow_logger.py b/tests/loggers/test_mlflow_logger.py index 4f784c479a..43193885b3 100644 --- a/tests/loggers/test_mlflow_logger.py +++ b/tests/loggers/test_mlflow_logger.py @@ -190,36 +190,39 @@ def test_mlflow_experiment_init_existing_composer_run(monkeypatch): assert test_logger._run_id == existing_id -def test_mlflow_logger_uses_env_var_run_name(monkeypatch): +@pytest.fixture +def mock_mlflow_client(): + with patch('mlflow.tracking.MlflowClient') as MockClient: + mock_create_run = MagicMock(return_value=MagicMock(info=MagicMock(run_id='mock-run-id'))) + MockClient.return_value.create_run = mock_create_run + yield MockClient + + +def test_mlflow_logger_uses_env_var_run_name(monkeypatch, mock_mlflow_client): """Test that MLFlowLogger uses the 'RUN_NAME' environment variable if set.""" mlflow = pytest.importorskip('mlflow') monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock()) monkeypatch.setattr(mlflow, 'start_run', MagicMock()) - mock_create_run = MagicMock(return_value=MagicMock(info=MagicMock(run_id='mock-run-id'))) - with patch('mlflow.tracking.MlflowClient') as MockClient: - MockClient.return_value.create_run = mock_create_run - - from composer.loggers.mlflow_logger import MLFlowLogger - mock_state = MagicMock() - mock_state.run_name = 'dummy-run-name' - monkeypatch.setenv('RUN_NAME', 'env-run-name') + from composer.loggers.mlflow_logger import MLFlowLogger + mock_state = MagicMock() + mock_state.run_name = 'dummy-run-name' + monkeypatch.setenv('RUN_NAME', 'env-run-name') - test_logger = MLFlowLogger() - test_logger.init(state=mock_state, logger=MagicMock()) + test_logger = MLFlowLogger() + test_logger.init(state=mock_state, logger=MagicMock()) - assert test_logger.tags['run_name'] == 'env-run-name', "Logger should use the run name from the environment variable." - monkeypatch.delenv('RUN_NAME') + assert test_logger.tags['run_name'] == 'env-run-name', "Logger should use the run name from the environment variable." + monkeypatch.delenv('RUN_NAME') -def test_mlflow_logger_uses_state_run_name_if_no_env_var_set(monkeypatch): +def test_mlflow_logger_uses_state_run_name_if_no_env_var_set(monkeypatch, mock_mlflow_client): """Test that MLFlowLogger uses the state's run name if no 'RUN_NAME' environment variable is set.""" mlflow = pytest.importorskip('mlflow') monkeypatch.setattr(mlflow, 'set_tracking_uri', MagicMock()) monkeypatch.setattr(mlflow, 'start_run', MagicMock()) - mock_state = MagicMock() mock_state.run_name = 'state-run-name' @@ -227,6 +230,7 @@ def test_mlflow_logger_uses_state_run_name_if_no_env_var_set(monkeypatch): mock_search_runs = MagicMock(return_value=[MagicMock(info=MagicMock(run_id=existing_id))]) monkeypatch.setattr(mlflow, 'search_runs', mock_search_runs) + from composer.loggers.mlflow_logger import MLFlowLogger test_logger = MLFlowLogger() test_logger.init(state=mock_state, logger=MagicMock()) assert test_logger.tags['run_name'] == 'state-run-name'