From df74553ec484ad729fcd75ccbc1f5f18e7f34dc8 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Wed, 2 Aug 2023 07:14:07 +0100 Subject: [PATCH] Refactor account_url use in WasbHook (#32980) * Refactor account_url use in WasbHook This PR moves the account_url setting to one place. Tested this by making connections to Azure using the different methods; however, I was not able to connect using the tenant_id in the extra field. This looks like a bug because ClientSecretCredential is not among the credentials to use in BlobServiceClient. The credentials to use include AzureNamedKeyCredential, AzureSasCredential, AsyncTokenCredential. So this will need special debugging. * fixup! Refactor account_url use in WasbHook --- .../providers/microsoft/azure/hooks/wasb.py | 32 ++++++++++--------- .../microsoft/azure/hooks/test_wasb.py | 2 +- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/airflow/providers/microsoft/azure/hooks/wasb.py b/airflow/providers/microsoft/azure/hooks/wasb.py index 6c4b76cdc47b..d129cc4ce595 100644 --- a/airflow/providers/microsoft/azure/hooks/wasb.py +++ b/airflow/providers/microsoft/azure/hooks/wasb.py @@ -132,7 +132,7 @@ def get_ui_field_behaviour() -> dict[str, Any]: "relabeling": { "login": "Blob Storage Login (optional)", "password": "Blob Storage Key (optional)", - "host": "Account Name (Active Directory Auth)", + "host": "Account URL (Active Directory Auth)", }, "placeholders": { "login": "account name", @@ -154,7 +154,7 @@ def __init__( super().__init__() self.conn_id = wasb_conn_id self.public_read = public_read - self.blob_service_client = self.get_conn() + self.blob_service_client: BlobServiceClient = self.get_conn() logger = logging.getLogger("azure.core.pipeline.policies.http_logging_policy") try: @@ -184,15 +184,19 @@ def get_conn(self) -> BlobServiceClient: # connection_string auth takes priority return BlobServiceClient.from_connection_string(connection_string, **extra) + account_url = ( + conn.host + if conn.host and 
conn.host.startswith("https://") + else f"https://{conn.login}.blob.core.windows.net/" + ) + tenant = self._get_field(extra, "tenant_id") if tenant: # use Active Directory auth app_id = conn.login app_secret = conn.password token_credential = ClientSecretCredential(tenant, app_id, app_secret, **client_secret_auth_config) - return BlobServiceClient(account_url=conn.host, credential=token_credential, **extra) - - account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/" + return BlobServiceClient(account_url=account_url, credential=token_credential, **extra) if self.public_read: # Here we use anonymous public read @@ -210,9 +214,6 @@ def get_conn(self) -> BlobServiceClient: if sas_token.startswith("https"): return BlobServiceClient(account_url=sas_token, **extra) else: - if not account_url.startswith("https://"): - # TODO: require url in the host field in the next major version? - account_url = f"https://{conn.login}.blob.core.windows.net" return BlobServiceClient(account_url=f"{account_url.rstrip('/')}/{sas_token}", **extra) # Fall back to old auth (password) or use managed identity if not provided. @@ -220,9 +221,6 @@ def get_conn(self) -> BlobServiceClient: if not credential: credential = DefaultAzureCredential() self.log.info("Using DefaultAzureCredential as credential") - if not account_url.startswith("https://"): - # TODO: require url in the host field in the next major version? 
- account_url = f"https://{conn.login}.blob.core.windows.net/" return BlobServiceClient( account_url=account_url, credential=credential, @@ -589,6 +587,12 @@ async def get_async_conn(self) -> AsyncBlobServiceClient: ) return self.blob_service_client + account_url = ( + conn.host + if conn.host and conn.host.startswith("https://") + else f"https://{conn.login}.blob.core.windows.net/" + ) + tenant = self._get_field(extra, "tenant_id") if tenant: # use Active Directory auth @@ -598,12 +602,10 @@ async def get_async_conn(self) -> AsyncBlobServiceClient: tenant, app_id, app_secret, **client_secret_auth_config ) self.blob_service_client = AsyncBlobServiceClient( - account_url=conn.host, credential=token_credential, **extra # type:ignore[arg-type] + account_url=account_url, credential=token_credential, **extra # type:ignore[arg-type] ) return self.blob_service_client - account_url = conn.host if conn.host else f"https://{conn.login}.blob.core.windows.net/" - if self.public_read: # Here we use anonymous public read # more info @@ -625,7 +627,7 @@ async def get_async_conn(self) -> AsyncBlobServiceClient: self.blob_service_client = AsyncBlobServiceClient(account_url=sas_token, **extra) else: self.blob_service_client = AsyncBlobServiceClient( - account_url=f"{account_url}/{sas_token}", **extra + account_url=f"{account_url.rstrip('/')}/{sas_token}", **extra ) return self.blob_service_client diff --git a/tests/providers/microsoft/azure/hooks/test_wasb.py b/tests/providers/microsoft/azure/hooks/test_wasb.py index 5deca1b80553..52826781bba7 100644 --- a/tests/providers/microsoft/azure/hooks/test_wasb.py +++ b/tests/providers/microsoft/azure/hooks/test_wasb.py @@ -223,7 +223,7 @@ def test_azure_directory_connection(self, mock_get_conn, mock_credential, mock_b authority=self.client_secret_auth_config["authority"], ) mock_blob_service_client.assert_called_once_with( - account_url=conn.host, + account_url=f"https://{conn.login}.blob.core.windows.net/", 
credential=mock_credential.return_value, tenant_id=conn.extra_dejson["tenant_id"], proxies=conn.extra_dejson["proxies"],