Skip to content

Commit

Permalink
Properly propagate constructor S3 endpoint URL in S3ArtifactRepository (
Browse files Browse the repository at this point in the history
mlflow#9593)

Signed-off-by: Jerry Liang <[email protected]>
  • Loading branch information
jerrylian-db authored Sep 11, 2023
1 parent e4967ab commit 59880b7
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
1 change: 1 addition & 0 deletions mlflow/store/artifact/s3_artifact_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def _get_s3_client(self, addressing_style="path", s3_endpoint_url=None):
access_key_id=self._access_key_id,
secret_access_key=self._secret_access_key,
session_token=self._session_token,
s3_endpoint_url=s3_endpoint_url,
)

def parse_s3_compliant_uri(self, uri):
Expand Down
18 changes: 18 additions & 0 deletions tests/store/artifact/test_r2_artifact_repo.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import posixpath
from unittest import mock
from unittest.mock import ANY

import pytest

Expand Down Expand Up @@ -40,3 +42,19 @@ def test_convert_r2_uri_to_s3_endpoint_url(r2_artifact_root):

s3_endpoint_url = repo.convert_r2_uri_to_s3_endpoint_url(r2_artifact_root)
assert s3_endpoint_url == "https://account.r2.cloudflarestorage.com"


def test_s3_endpoint_url_is_used_to_get_s3_client(r2_artifact_root):
with mock.patch("boto3.client") as mock_get_s3_client:
artifact_uri = posixpath.join(r2_artifact_root, "some/path")
repo = get_artifact_repository(artifact_uri)
repo._get_s3_client()
mock_get_s3_client.assert_called_with(
"s3",
config=ANY,
endpoint_url="https://account.r2.cloudflarestorage.com",
verify=None,
aws_access_key_id=None,
aws_secret_access_key=None,
aws_session_token=None,
)

0 comments on commit 59880b7

Please sign in to comment.