Skip to content

Commit

Permalink
fix: test updates
Browse files Browse the repository at this point in the history
  • Loading branch information
TimurSadykov committed Oct 18, 2024
1 parent 17d1d1b commit 93340f7
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 35 deletions.
10 changes: 5 additions & 5 deletions google/auth/iam.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,11 @@
)

_IAM_SIGN_ENDPOINT = (
"https://iamcredentials.{}/v1/projects/-"
+ "/serviceAccounts/{}:signBlob"
"https://iamcredentials.{}/v1/projects/-" + "/serviceAccounts/{}:signBlob"
)

_IAM_IDTOKEN_ENDPOINT = (
"https://iamcredentials.{}/v1/"
+ "projects/-/serviceAccounts/{}:generateIdToken"
"https://iamcredentials.{}/v1/" + "projects/-/serviceAccounts/{}:generateIdToken"
)


Expand Down Expand Up @@ -89,7 +87,9 @@ def _make_signing_request(self, message):
message = _helpers.to_bytes(message)

method = "POST"
url = _IAM_SIGN_ENDPOINT.format(self._credentials.universe_domain, self._service_account_email)
url = _IAM_SIGN_ENDPOINT.format(
self._credentials.universe_domain, self._service_account_email
)
headers = {"Content-Type": "application/json"}
body = json.dumps(
{"payload": base64.b64encode(message).decode("utf-8")}
Expand Down
25 changes: 18 additions & 7 deletions google/auth/impersonated_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@


_REFRESH_ERROR = "Unable to acquire impersonated credentials"
_UNIVERSE_DOMAIN_MATCH_SOURCE_ERROR = (
"The universe_domain "
"is not supported for impersonated credentials. The "
"credential uses the value from source_credentials."
)

_DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds

Expand All @@ -67,7 +72,9 @@ def _make_iam_token_request(
`iamcredentials.googleapis.com` is not enabled or the
`Service Account Token Creator` is not assigned
"""
iam_endpoint = iam_endpoint_override or iam._IAM_ENDPOINT.format(universe_domain, principal)
iam_endpoint = iam_endpoint_override or iam._IAM_ENDPOINT.format(
universe_domain, principal
)

body = json.dumps(body).encode("utf-8")

Expand Down Expand Up @@ -173,6 +180,7 @@ def __init__(
lifetime=_DEFAULT_TOKEN_LIFETIME_SECS,
quota_project_id=None,
iam_endpoint_override=None,
universe_domain=None,
):
"""
Args:
Expand Down Expand Up @@ -219,8 +227,9 @@ def __init__(
and self._source_credentials._always_use_jwt_access
):
self._source_credentials._create_self_signed_jwt(None)
if (self.universe_domain != self._source_credentials.universe_domain):
self._universe_domain = source_credentials.universe_domain
if (universe_domain is not None and universe_domain != self._source_credentials.universe_domain):
raise exceptions.InvalidOperation(_UNIVERSE_DOMAIN_MATCH_SOURCE_ERROR)
self._universe_domain = source_credentials.universe_domain
self._target_principal = target_principal
self._target_scopes = target_scopes
self._delegates = delegates
Expand Down Expand Up @@ -276,14 +285,16 @@ def _update_token(self, request):
universe_domain=self.universe_domain,
iam_endpoint_override=self._iam_endpoint_override,
)

def get_iam_sign_endpoint(self):
return iam._IAM_SIGN_ENDPOINT.format(self.universe_domain, self._target_principal)
return iam._IAM_SIGN_ENDPOINT.format(
self.universe_domain, self._target_principal
)

def sign_bytes(self, message):
from google.auth.transport.requests import AuthorizedSession

iam_sign_endpoint = self.get_iam_sign_endpoint(self)
iam_sign_endpoint = self.get_iam_sign_endpoint()

body = {
"payload": base64.b64encode(message).decode("utf-8"),
Expand Down Expand Up @@ -435,7 +446,7 @@ def refresh(self, request):

iam_sign_endpoint = iam._IAM_IDTOKEN_ENDPOINT.format(
self._target_credentials.universe_domain,
self._target_credentials.signer_email
self._target_credentials.signer_email,
)

body = {
Expand Down
4 changes: 2 additions & 2 deletions google/oauth2/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def jwt_grant(request, token_uri, assertion, can_retry=True):


def call_iam_generate_id_token_endpoint(
request, iam_id_token_endpoint, signer_email, audience, access_token
request, iam_id_token_endpoint, signer_email, audience, access_token, universe_domain
):
"""Call iam.generateIdToken endpoint to get ID token.
Expand All @@ -339,7 +339,7 @@ def call_iam_generate_id_token_endpoint(

response_data = _token_endpoint_request(
request,
iam_id_token_endpoint.format(signer_email),
iam_id_token_endpoint.format(universe_domain, signer_email),
body,
access_token=access_token,
use_json=True,
Expand Down
1 change: 1 addition & 0 deletions google/oauth2/service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,7 @@ def _refresh_with_iam_endpoint(self, request):
self.signer_email,
self._target_audience,
jwt_credentials.token.decode(),
self._universe_domain,
)

@_helpers.copy_docstring(credentials.Credentials)
Expand Down
2 changes: 2 additions & 0 deletions tests/oauth2/test__client.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def test_call_iam_generate_id_token_endpoint():
"fake_email",
"fake_audience",
"fake_access_token",
"googleapis.com",
)

assert (
Expand Down Expand Up @@ -361,6 +362,7 @@ def test_call_iam_generate_id_token_endpoint_no_id_token():
"fake_email",
"fake_audience",
"fake_access_token",
"googleapis.com"
)
assert excinfo.match("No ID token in response")

Expand Down
8 changes: 5 additions & 3 deletions tests/oauth2/test_service_account.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,7 @@ def test_refresh_iam_flow(self, call_iam_generate_id_token_endpoint):
)
request = mock.Mock()
credentials.refresh(request)
req, iam_endpoint, signer_email, target_audience, access_token = call_iam_generate_id_token_endpoint.call_args[
req, iam_endpoint, signer_email, target_audience, access_token, universe_domain = call_iam_generate_id_token_endpoint.call_args[
0
]
assert req == request
Expand All @@ -798,6 +798,7 @@ def test_refresh_iam_flow(self, call_iam_generate_id_token_endpoint):
assert target_audience == "https://example.com"
decoded_access_token = jwt.decode(access_token, verify=False)
assert decoded_access_token["scope"] == "https://www.googleapis.com/auth/iam"
assert universe_domain == "googleapis.com"

@mock.patch(
"google.oauth2._client.call_iam_generate_id_token_endpoint", autospec=True
Expand All @@ -811,18 +812,19 @@ def test_refresh_iam_flow_non_gdu(self, call_iam_generate_id_token_endpoint):
)
request = mock.Mock()
credentials.refresh(request)
req, iam_endpoint, signer_email, target_audience, access_token = call_iam_generate_id_token_endpoint.call_args[
req, iam_endpoint, signer_email, target_audience, access_token, universe_domain = call_iam_generate_id_token_endpoint.call_args[
0
]
assert req == request
assert (
iam_endpoint
== "https://iamcredentials.fake-universe/v1/projects/-/serviceAccounts/{}:generateIdToken"
== "https://iamcredentials.{}/v1/projects/-/serviceAccounts/{}:generateIdToken"
)
assert signer_email == "[email protected]"
assert target_audience == "https://example.com"
decoded_access_token = jwt.decode(access_token, verify=False)
assert decoded_access_token["scope"] == "https://www.googleapis.com/auth/iam"
assert universe_domain == "fake-universe"

@mock.patch("google.oauth2._client.id_token_jwt_grant", autospec=True)
def test_before_request_refreshes(self, id_token_jwt_grant):
Expand Down
50 changes: 32 additions & 18 deletions tests/test_impersonated_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def make_credentials(
lifetime=LIFETIME,
target_principal=TARGET_PRINCIPAL,
iam_endpoint_override=None,
universe_domain=credentials.DEFAULT_UNIVERSE_DOMAIN
universe_domain=None
):

return Credentials(
Expand All @@ -146,22 +146,27 @@ def test_get_cred_info(self):
"credential_source": "/path/to/file",
"credential_type": "impersonated credentials",
"principal": "[email protected]",
"iam_endpoint_override": None,
}

def test_get_cred_info_universe_domain(self):
credentials = self.make_credentials(universe_domain="foo.bar")
assert not credentials.get_cred_info()

credentials._cred_file_path = "/path/to/file"
assert credentials.get_cred_info() == {
"credential_source": "/path/to/file",
"credential_type": "impersonated credentials",
"principal": "[email protected]",
"universe_domain": "foo.bar",
"iam_endpoint_override": "https://iamcredentials.foo.bar/v1/projects/-"
+ "/serviceAccounts/[email protected]:generateAccessToken"
}
def test_explicit_universe_domain_matching_source(self):
source_credentials = service_account.Credentials(
SIGNER, "[email protected]", TOKEN_URI, universe_domain="foo.bar"
)
credentials = self.make_credentials(universe_domain="foo.bar", source_credentials=source_credentials)
assert credentials.universe_domain == "foo.bar"

def test_universe_domain_from_source(self):
source_credentials = service_account.Credentials(
SIGNER, "[email protected]", TOKEN_URI, universe_domain="foo.bar"
)
credentials = self.make_credentials(source_credentials=source_credentials)
assert credentials.universe_domain == "foo.bar"

def test_explicit_universe_domain_not_matching_source(self):
with pytest.raises(exceptions.InvalidOperation) as excinfo:
self.make_credentials(universe_domain="foo.bar")

assert excinfo.match(impersonated_credentials._UNIVERSE_DOMAIN_MATCH_SOURCE_ERROR)

def test__make_copy_get_cred_info(self):
credentials = self.make_credentials()
Expand Down Expand Up @@ -409,9 +414,18 @@ def test_signer_email(self):
assert credentials.signer_email == self.TARGET_PRINCIPAL

def test_sign_endpoint(self):
credentials = self.make_credentials(universe_domain="foo.bar")
assert credentials.get_iam_sign_endpoint == "https://iamcredentials.foo.bar/v1/projects/-"
+ "/serviceAccounts/[email protected]:signBlob"
source_credentials = service_account.Credentials(
SIGNER, "[email protected]", TOKEN_URI, universe_domain="foo.bar"
)
credentials = self.make_credentials(source_credentials=source_credentials)
assert credentials.get_iam_sign_endpoint() == "https://iamcredentials.foo.bar/v1/projects/-/serviceAccounts/[email protected]:signBlob"

def test_sign_endpoint_explicit_universe_domain(self):
source_credentials = service_account.Credentials(
SIGNER, "[email protected]", TOKEN_URI, universe_domain="foo.bar"
)
credentials = self.make_credentials(universe_domain="foo.bar", source_credentials=source_credentials)
assert credentials.get_iam_sign_endpoint() == "https://iamcredentials.foo.bar/v1/projects/-/serviceAccounts/[email protected]:signBlob"

def test_service_account_email(self):
credentials = self.make_credentials(target_principal=self.TARGET_PRINCIPAL)
Expand Down

0 comments on commit 93340f7

Please sign in to comment.