From 93340f7e7cf57dfeba98ffa98c647c37a7df996d Mon Sep 17 00:00:00 2001 From: Timur Sadykov Date: Fri, 18 Oct 2024 08:49:58 -0700 Subject: [PATCH] fix: test updates --- google/auth/iam.py | 10 ++--- google/auth/impersonated_credentials.py | 25 +++++++++---- google/oauth2/_client.py | 4 +- google/oauth2/service_account.py | 1 + tests/oauth2/test__client.py | 2 + tests/oauth2/test_service_account.py | 8 ++-- tests/test_impersonated_credentials.py | 50 ++++++++++++++++--------- 7 files changed, 65 insertions(+), 35 deletions(-) diff --git a/google/auth/iam.py b/google/auth/iam.py index 7367f02fc..bed1930f5 100644 --- a/google/auth/iam.py +++ b/google/auth/iam.py @@ -43,13 +43,11 @@ ) _IAM_SIGN_ENDPOINT = ( - "https://iamcredentials.{}/v1/projects/-" - + "/serviceAccounts/{}:signBlob" + "https://iamcredentials.{}/v1/projects/-" + "/serviceAccounts/{}:signBlob" ) _IAM_IDTOKEN_ENDPOINT = ( - "https://iamcredentials.{}/v1/" - + "projects/-/serviceAccounts/{}:generateIdToken" + "https://iamcredentials.{}/v1/" + "projects/-/serviceAccounts/{}:generateIdToken" ) @@ -89,7 +87,9 @@ def _make_signing_request(self, message): message = _helpers.to_bytes(message) method = "POST" - url = _IAM_SIGN_ENDPOINT.format(self._credentials.universe_domain, self._service_account_email) + url = _IAM_SIGN_ENDPOINT.format( + self._credentials.universe_domain, self._service_account_email + ) headers = {"Content-Type": "application/json"} body = json.dumps( {"payload": base64.b64encode(message).decode("utf-8")} diff --git a/google/auth/impersonated_credentials.py b/google/auth/impersonated_credentials.py index defce023b..9443c593c 100644 --- a/google/auth/impersonated_credentials.py +++ b/google/auth/impersonated_credentials.py @@ -41,6 +41,11 @@ _REFRESH_ERROR = "Unable to acquire impersonated credentials" +_UNIVERSE_DOMAIN_MATCH_SOURCE_ERROR = ( + "The universe_domain " + "is not supported for impersonated credentials. The " + "credential uses the value from source_credentials." +) _DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds @@ -67,7 +72,9 @@ def _make_iam_token_request( `iamcredentials.googleapis.com` is not enabled or the `Service Account Token Creator` is not assigned """ - iam_endpoint = iam_endpoint_override or iam._IAM_ENDPOINT.format(universe_domain, principal) + iam_endpoint = iam_endpoint_override or iam._IAM_ENDPOINT.format( + universe_domain, principal + ) body = json.dumps(body).encode("utf-8") @@ -173,6 +180,7 @@ def __init__( lifetime=_DEFAULT_TOKEN_LIFETIME_SECS, quota_project_id=None, iam_endpoint_override=None, + universe_domain=None, ): """ Args: @@ -219,8 +227,9 @@ def __init__( and self._source_credentials._always_use_jwt_access ): self._source_credentials._create_self_signed_jwt(None) - if (self.universe_domain != self._source_credentials.universe_domain): - self._universe_domain = source_credentials.universe_domain + if (universe_domain is not None and universe_domain != self._source_credentials.universe_domain): + raise exceptions.InvalidOperation(_UNIVERSE_DOMAIN_MATCH_SOURCE_ERROR) + self._universe_domain = source_credentials.universe_domain self._target_principal = target_principal self._target_scopes = target_scopes self._delegates = delegates @@ -276,14 +285,16 @@ def _update_token(self, request): universe_domain=self.universe_domain, iam_endpoint_override=self._iam_endpoint_override, ) - + def get_iam_sign_endpoint(self): - return iam._IAM_SIGN_ENDPOINT.format(self.universe_domain, self._target_principal) + return iam._IAM_SIGN_ENDPOINT.format( + self.universe_domain, self._target_principal + ) def sign_bytes(self, message): from google.auth.transport.requests import AuthorizedSession - iam_sign_endpoint = self.get_iam_sign_endpoint(self) + iam_sign_endpoint = self.get_iam_sign_endpoint() body = { "payload": base64.b64encode(message).decode("utf-8"), @@ -435,7 +446,7 @@ def refresh(self, request): iam_sign_endpoint = iam._IAM_IDTOKEN_ENDPOINT.format( self._target_credentials.universe_domain, - self._target_credentials.signer_email + self._target_credentials.signer_email, ) body = { diff --git a/google/oauth2/_client.py b/google/oauth2/_client.py index 68e13ddc7..98446c53a 100644 --- a/google/oauth2/_client.py +++ b/google/oauth2/_client.py @@ -319,7 +319,7 @@ def jwt_grant(request, token_uri, assertion, can_retry=True): def call_iam_generate_id_token_endpoint( - request, iam_id_token_endpoint, signer_email, audience, access_token +request, iam_id_token_endpoint, signer_email, audience, access_token, universe_domain ): """Call iam.generateIdToken endpoint to get ID token. @@ -339,7 +339,7 @@ def call_iam_generate_id_token_endpoint( response_data = _token_endpoint_request( request, - iam_id_token_endpoint.format(signer_email), + iam_id_token_endpoint.format(universe_domain, signer_email), body, access_token=access_token, use_json=True, diff --git a/google/oauth2/service_account.py b/google/oauth2/service_account.py index 98dafa3e3..3e84194ac 100644 --- a/google/oauth2/service_account.py +++ b/google/oauth2/service_account.py @@ -812,6 +812,7 @@ def _refresh_with_iam_endpoint(self, request): self.signer_email, self._target_audience, jwt_credentials.token.decode(), + self._universe_domain, ) @_helpers.copy_docstring(credentials.Credentials) diff --git a/tests/oauth2/test__client.py b/tests/oauth2/test__client.py index 9da63cbde..50ffa7510 100644 --- a/tests/oauth2/test__client.py +++ b/tests/oauth2/test__client.py @@ -324,6 +324,7 @@ def test_call_iam_generate_id_token_endpoint(): "fake_email", "fake_audience", "fake_access_token", + "googleapis.com", ) assert ( @@ -361,6 +362,7 @@ def test_call_iam_generate_id_token_endpoint_no_id_token(): "fake_email", "fake_audience", "fake_access_token", + "googleapis.com" ) assert excinfo.match("No ID token in response") diff --git a/tests/oauth2/test_service_account.py b/tests/oauth2/test_service_account.py index 2c3fea5b2..45e0d6c91 100644 --- a/tests/oauth2/test_service_account.py +++ b/tests/oauth2/test_service_account.py @@ -789,7 +789,7 @@ def test_refresh_iam_flow(self, call_iam_generate_id_token_endpoint): ) request = mock.Mock() credentials.refresh(request) - req, iam_endpoint, signer_email, target_audience, access_token = call_iam_generate_id_token_endpoint.call_args[ + req, iam_endpoint, signer_email, target_audience, access_token, universe_domain = call_iam_generate_id_token_endpoint.call_args[ 0 ] assert req == request @@ -798,6 +798,7 @@ def test_refresh_iam_flow(self, call_iam_generate_id_token_endpoint): assert target_audience == "https://example.com" decoded_access_token = jwt.decode(access_token, verify=False) assert decoded_access_token["scope"] == "https://www.googleapis.com/auth/iam" + assert universe_domain == "googleapis.com" @mock.patch( "google.oauth2._client.call_iam_generate_id_token_endpoint", autospec=True @@ -811,18 +812,19 @@ def test_refresh_iam_flow_non_gdu(self, call_iam_generate_id_token_endpoint): ) request = mock.Mock() credentials.refresh(request) - req, iam_endpoint, signer_email, target_audience, access_token = call_iam_generate_id_token_endpoint.call_args[ + req, iam_endpoint, signer_email, target_audience, access_token, universe_domain = call_iam_generate_id_token_endpoint.call_args[ 0 ] assert req == request assert ( iam_endpoint - == "https://iamcredentials.fake-universe/v1/projects/-/serviceAccounts/{}:generateIdToken" + == "https://iamcredentials.{}/v1/projects/-/serviceAccounts/{}:generateIdToken" ) assert signer_email == "service-account@example.com" assert target_audience == "https://example.com" decoded_access_token = jwt.decode(access_token, verify=False) assert decoded_access_token["scope"] == "https://www.googleapis.com/auth/iam" + assert universe_domain == "fake-universe" @mock.patch("google.oauth2._client.id_token_jwt_grant", autospec=True) def test_before_request_refreshes(self, id_token_jwt_grant): diff --git a/tests/test_impersonated_credentials.py b/tests/test_impersonated_credentials.py index 696401a38..f80339475 100644 --- a/tests/test_impersonated_credentials.py +++ b/tests/test_impersonated_credentials.py @@ -124,7 +124,7 @@ def make_credentials( lifetime=LIFETIME, target_principal=TARGET_PRINCIPAL, iam_endpoint_override=None, - universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN + universe_domain=None ): return Credentials( @@ -146,22 +146,27 @@ def test_get_cred_info(self): "credential_source": "/path/to/file", "credential_type": "impersonated credentials", "principal": "impersonated@project.iam.gserviceaccount.com", - "iam_endpoint_override": None, } - def test_get_cred_info_universe_domain(self): - credentials = self.make_credentials(universe_domain="foo.bar") - assert not credentials.get_cred_info() - - credentials._cred_file_path = "/path/to/file" - assert credentials.get_cred_info() == { - "credential_source": "/path/to/file", - "credential_type": "impersonated credentials", - "principal": "impersonated@project.iam.gserviceaccount.com", - "universe_domain": "foo.bar", - "iam_endpoint_override": "https://iamcredentials.foo.bar/v1/projects/-" - + "/serviceAccounts/impersonated@project.iam.gserviceaccount.com:generateAccessToken" - } + def test_explicit_universe_domain_matching_source(self): + source_credentials = service_account.Credentials( + SIGNER, "some@email.com", TOKEN_URI, universe_domain="foo.bar" + ) + credentials = self.make_credentials(universe_domain="foo.bar", source_credentials=source_credentials) + assert credentials.universe_domain == "foo.bar" + + def test_universe_domain_from_source(self): + source_credentials = service_account.Credentials( + SIGNER, "some@email.com", TOKEN_URI, universe_domain="foo.bar" + ) + credentials = self.make_credentials(source_credentials=source_credentials) + assert credentials.universe_domain == "foo.bar" + + def test_explicit_universe_domain_not_matching_source(self): + with pytest.raises(exceptions.InvalidOperation) as excinfo: + self.make_credentials(universe_domain="foo.bar") + + assert excinfo.match(impersonated_credentials._UNIVERSE_DOMAIN_MATCH_SOURCE_ERROR) def test__make_copy_get_cred_info(self): credentials = self.make_credentials() @@ -409,9 +414,18 @@ def test_signer_email(self): assert credentials.signer_email == self.TARGET_PRINCIPAL def test_sign_endpoint(self): - credentials = self.make_credentials(universe_domain="foo.bar") - assert credentials.get_iam_sign_endpoint == "https://iamcredentials.foo.bar/v1/projects/-" - + "/serviceAccounts/impersonated@project.iam.gserviceaccount.com:signBlob" + source_credentials = service_account.Credentials( + SIGNER, "some@email.com", TOKEN_URI, universe_domain="foo.bar" + ) + credentials = self.make_credentials(source_credentials=source_credentials) + assert credentials.get_iam_sign_endpoint() == "https://iamcredentials.foo.bar/v1/projects/-/serviceAccounts/impersonated@project.iam.gserviceaccount.com:signBlob" + + def test_sign_endpoint_explicit_universe_domain(self): + source_credentials = service_account.Credentials( + SIGNER, "some@email.com", TOKEN_URI, universe_domain="foo.bar" + ) + credentials = self.make_credentials(universe_domain="foo.bar", source_credentials=source_credentials) + assert credentials.get_iam_sign_endpoint() == "https://iamcredentials.foo.bar/v1/projects/-/serviceAccounts/impersonated@project.iam.gserviceaccount.com:signBlob" def test_service_account_email(self): credentials = self.make_credentials(target_principal=self.TARGET_PRINCIPAL)