diff --git a/luigi/contrib/azureblob.py b/luigi/contrib/azureblob.py index 20de24224a..38c96b6973 100644 --- a/luigi/contrib/azureblob.py +++ b/luigi/contrib/azureblob.py @@ -20,7 +20,7 @@ import logging import datetime -from azure.storage.blob import blockblobservice +from azure.storage.blob import BlobServiceClient from luigi.format import get_default_format from luigi.target import FileAlreadyExists, FileSystem, AtomicLocalFile, FileSystemTarget @@ -62,60 +62,101 @@ def __init__(self, account_name=None, account_key=None, sas_token=None, **kwargs * `custom_domain` - The custom domain to use. This can be set in the Azure Portal. For example, ‘www.mydomain.com’. * `token_credential` - A token credential used to authenticate HTTPS requests. The token value should be updated before its expiration. """ - self.options = {"account_name": account_name, "account_key": account_key, "sas_token": sas_token} + if kwargs.get("custom_domain"): + account_url = "{protocol}://{custom_domain}/{account_name}".format(protocol=kwargs.get("protocol", "https"), + custom_domain=kwargs.get("custom_domain"), + account_name=account_name) + else: + account_url = "{protocol}://{account_name}.blob.{endpoint_suffix}".format(protocol=kwargs.get("protocol", + "https"), + account_name=account_name, + endpoint_suffix=kwargs.get( + "endpoint_suffix", + "core.windows.net")) + + self.options = { + "account_name": account_name, + "account_key": account_key, + "account_url": account_url, + "sas_token": sas_token} self.kwargs = kwargs @property def connection(self): - return blockblobservice.BlockBlobService(account_name=self.options.get("account_name"), - account_key=self.options.get("account_key"), - sas_token=self.options.get("sas_token"), - protocol=self.kwargs.get("protocol"), - connection_string=self.kwargs.get("connection_string"), - endpoint_suffix=self.kwargs.get("endpoint_suffix"), - custom_domain=self.kwargs.get("custom_domain"), - is_emulated=self.kwargs.get("is_emulated") or False) + if self.kwargs.get("connection_string"): + return BlobServiceClient.from_connection_string(conn_str=self.kwargs.get("connection_string"), + **self.kwargs) + else: + return BlobServiceClient(account_url=self.options.get("account_url"), + credential=self.options.get("account_key") or self.options.get("sas_token"), + **self.kwargs) + + def container_client(self, container_name): + return self.connection.get_container_client(container_name) + + def blob_client(self, container_name, blob_name): + container_client = self.container_client(container_name) + return container_client.get_blob_client(blob_name) def upload(self, tmp_path, container, blob, **kwargs): logging.debug("Uploading file '{tmp_path}' to container '{container}' and blob '{blob}'".format( tmp_path=tmp_path, container=container, blob=blob)) self.create_container(container) - lease_id = self.connection.acquire_blob_lease(container, blob)\ - if self.exists("{container}/{blob}".format(container=container, blob=blob)) else None + lease = None + blob_client = self.blob_client(container, blob) + if blob_client.exists(): + lease = blob_client.acquire_lease() try: - self.connection.create_blob_from_path(container, blob, tmp_path, lease_id=lease_id, progress_callback=kwargs.get("progress_callback")) + with open(tmp_path, 'rb') as data: + blob_client.upload_blob(data, + overwrite=True, + lease=lease, + progress_hook=kwargs.get("progress_callback")) finally: - if lease_id is not None: - self.connection.release_blob_lease(container, blob, lease_id) + if lease is not None: + lease.release() def download_as_bytes(self, container, blob, bytes_to_read=None): - start_range, end_range = (0, bytes_to_read-1) if bytes_to_read is not None else (None, None) logging.debug("Downloading from container '{container}' and blob '{blob}' as bytes".format( container=container, blob=blob)) - return self.connection.get_blob_to_bytes(container, blob, start_range=start_range, end_range=end_range).content + blob_client = self.blob_client(container, blob) + download_stream = blob_client.download_blob(offset=0, length=bytes_to_read) if bytes_to_read \ + else blob_client.download_blob() + return download_stream.readall() def download_as_file(self, container, blob, location): logging.debug("Downloading from container '{container}' and blob '{blob}' to {location}".format( container=container, blob=blob, location=location)) - return self.connection.get_blob_to_path(container, blob, location) + blob_client = self.blob_client(container, blob) + with open(location, 'wb') as file: + download_stream = blob_client.download_blob() + file.write(download_stream.readall()) + return blob_client.get_blob_properties() def create_container(self, container_name): - return self.connection.create_container(container_name) + if not self.exists(container_name): + return self.connection.create_container(container_name) def delete_container(self, container_name): - lease_id = self.connection.acquire_container_lease(container_name) - self.connection.delete_container(container_name, lease_id=lease_id) + container_client = self.container_client(container_name) + lease = container_client.acquire_lease() + container_client.delete_container(lease=lease) def exists(self, path): container, blob = self.splitfilepath(path) - return self.connection.exists(container, blob) + if blob is None: + return self.container_client(container).exists() + else: + return self.blob_client(container, blob).exists() def remove(self, path, recursive=True, skip_trash=True): - container, blob = self.splitfilepath(path) if not self.exists(path): return False - lease_id = self.connection.acquire_blob_lease(container, blob) - self.connection.delete_blob(container, blob, lease_id=lease_id) + + container, blob = self.splitfilepath(path) + blob_client = self.blob_client(container, blob) + lease = blob_client.acquire_lease() + blob_client.delete_blob(lease=lease) return True def mkdir(self, path, parents=True, raise_if_exists=False): @@ -148,16 +189,18 @@ def copy(self, path, dest): source_container=source_container, dest_container=dest_container )) - source_lease_id = self.connection.acquire_blob_lease(source_container, source_blob) - destination_lease_id = self.connection.acquire_blob_lease(dest_container, dest_blob) if self.exists(dest) else None + source_blob_client = self.blob_client(source_container, source_blob) + dest_blob_client = self.blob_client(dest_container, dest_blob) + source_lease = source_blob_client.acquire_lease() + destination_lease = dest_blob_client.acquire_lease() if self.exists(dest) else None try: - return self.connection.copy_blob(source_container, dest_blob, self.connection.make_blob_url( - source_container, source_blob), - destination_lease_id=destination_lease_id, source_lease_id=source_lease_id) + return dest_blob_client.start_copy_from_url(source_url=source_blob_client.url, + source_lease=source_lease, + destination_lease=destination_lease) finally: - self.connection.release_blob_lease(source_container, source_blob, source_lease_id) - if destination_lease_id is not None: - self.connection.release_blob_lease(dest_container, dest_blob, destination_lease_id) + source_lease.release() + if destination_lease is not None: + destination_lease.release() def rename_dont_move(self, path, dest): self.move(path, dest) diff --git a/test/contrib/azureblob_test.py b/test/contrib/azureblob_test.py index d587768c2e..0706c90caa 100644 --- a/test/contrib/azureblob_test.py +++ b/test/contrib/azureblob_test.py @@ -30,8 +30,9 @@ account_name = os.environ.get("ACCOUNT_NAME") account_key = os.environ.get("ACCOUNT_KEY") sas_token = os.environ.get("SAS_TOKEN") -is_emulated = False if account_name else True -client = AzureBlobClient(account_name, account_key, sas_token, is_emulated=is_emulated) +custom_domain = os.environ.get("CUSTOM_DOMAIN") +protocol = os.environ.get("PROTOCOL") +client = AzureBlobClient(account_name, account_key, sas_token, custom_domain=custom_domain, protocol=protocol) @pytest.mark.azureblob @@ -96,7 +97,7 @@ def test_upload_copy_move_remove_blob(self): self.assertTrue(self.client.exists(from_path)) # copy - self.assertIn(self.client.copy(from_path, to_path).status, ["success", "pending"]) + self.assertIn(self.client.copy(from_path, to_path)["copy_status"], ["success", "pending"]) self.assertTrue(self.client.exists(to_path)) # remove @@ -121,7 +122,7 @@ def output(self): return AzureBlobTarget("luigi-test", "movie-cheesy.txt", client, download_when_reading=False) def run(self): - client.connection.create_container("luigi-test") + client.create_container("luigi-test") with self.output().open("w") as op: op.write("I'm going to make him an offer he can't refuse.\n") op.write("Toto, I've got a feeling we're not in Kansas anymore.\n")