Skip to content

Commit

Permalink
Cache and run cognito client creation in executor (#649)
Browse files Browse the repository at this point in the history
* Cache and run cognito client creation in executor

* Fix import

* use lru with maxsize

* Do session as well
  • Loading branch information
ludeeus authored Jun 3, 2024
1 parent f294643 commit 61d89b0
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 18 deletions.
72 changes: 56 additions & 16 deletions hass_nabucasa/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import asyncio
from functools import partial
from functools import lru_cache, partial
import logging
import random
from typing import TYPE_CHECKING, Any
Expand Down Expand Up @@ -73,7 +73,7 @@ def __init__(self, cloud: Cloud[_ClientT]) -> None:
"""Configure the auth api."""
self.cloud = cloud
self._refresh_task: asyncio.Task | None = None
self._session = boto3.session.Session()
self._session: boto3.Session | None = None
self._request_lock = asyncio.Lock()

cloud.iot.register_on_connect(self.on_connect)
Expand Down Expand Up @@ -113,7 +113,9 @@ async def async_register(
"""Register a new account."""
try:
async with self._request_lock:
cognito = self._cognito()
cognito = await self.cloud.run_executor(
self._create_cognito_client,
)
await self.cloud.run_executor(
partial(
cognito.register,
Expand All @@ -132,7 +134,9 @@ async def async_resend_email_confirm(self, email: str) -> None:
"""Resend email confirmation."""
try:
async with self._request_lock:
cognito = self._cognito(username=email)
cognito = await self.cloud.run_executor(
partial(self._create_cognito_client, username=email),
)
await self.cloud.run_executor(
partial(
cognito.client.resend_confirmation_code,
Expand All @@ -149,7 +153,9 @@ async def async_forgot_password(self, email: str) -> None:
"""Initialize forgotten password flow."""
try:
async with self._request_lock:
cognito = self._cognito(username=email)
cognito = await self.cloud.run_executor(
partial(self._create_cognito_client, username=email),
)
await self.cloud.run_executor(cognito.initiate_forgot_password)

except ClientError as err:
Expand All @@ -163,7 +169,9 @@ async def async_login(self, email: str, password: str) -> None:
async with self._request_lock:
assert not self.cloud.is_logged_in, "Cannot login if already logged in."

cognito = self._cognito(username=email)
cognito = await self.cloud.run_executor(
partial(self._create_cognito_client, username=email),
)

async with async_timeout.timeout(30):
await self.cloud.run_executor(
Expand Down Expand Up @@ -191,7 +199,8 @@ async def async_login(self, email: str, password: str) -> None:
async def async_check_token(self) -> None:
"""Check that the token is valid and renew if necessary."""
async with self._request_lock:
if not self._authenticated_cognito.check_token(renew=False):
cognito = await self._async_authenticated_cognito()
if not cognito.check_token(renew=False):
return

try:
Expand Down Expand Up @@ -219,7 +228,7 @@ async def _async_renew_access_token(self) -> None:
Does not consume lock.
"""
cognito = self._authenticated_cognito
cognito = await self._async_authenticated_cognito()

try:
await self.cloud.run_executor(cognito.renew_access_token)
Expand All @@ -231,20 +240,28 @@ async def _async_renew_access_token(self) -> None:
except BotoCoreError as err:
raise UnknownError from err

@property
def _authenticated_cognito(self) -> pycognito.Cognito:
async def _async_authenticated_cognito(self) -> pycognito.Cognito:
"""Return an authenticated cognito instance."""
if self.cloud.access_token is None or self.cloud.refresh_token is None:
raise Unauthenticated("No authentication found")

return self._cognito(
access_token=self.cloud.access_token,
refresh_token=self.cloud.refresh_token,
return await self.cloud.run_executor(
partial(
self._create_cognito_client,
access_token=self.cloud.access_token,
refresh_token=self.cloud.refresh_token,
),
)

def _cognito(self, **kwargs: Any) -> pycognito.Cognito:
"""Get the client credentials."""
return pycognito.Cognito(
def _create_cognito_client(self, **kwargs: Any) -> pycognito.Cognito:
"""Create a new cognito client.
NOTE: This will do I/O
"""
if self._session is None:
self._session = boto3.session.Session()

return _cached_cognito(
user_pool_id=self.cloud.user_pool_id,
client_id=self.cloud.cognito_client_id,
user_pool_region=self.cloud.region,
Expand All @@ -258,3 +275,26 @@ def _map_aws_exception(err: ClientError) -> CloudError:
"""Map AWS exception to our exceptions."""
ex = AWS_EXCEPTIONS.get(err.response["Error"]["Code"], UnknownError)
return ex(err.response["Error"]["Message"])


@lru_cache(maxsize=2)
def _cached_cognito(
user_pool_id: str,
client_id: str,
user_pool_region: str,
botocore_config: Any,
session: Any,
**kwargs: Any,
) -> pycognito.Cognito:
"""Create a cached cognito client.
NOTE: This will do I/O
"""
return pycognito.Cognito(
user_pool_id=user_pool_id,
client_id=client_id,
user_pool_region=user_pool_region,
botocore_config=botocore_config,
session=session,
**kwargs,
)
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def cloud_client(cloud_mock):
@pytest.fixture
def mock_cognito():
"""Mock warrant."""
with patch("hass_nabucasa.auth.CognitoAuth._cognito") as mock_cog:
with patch("hass_nabucasa.auth.CognitoAuth._create_cognito_client") as mock_cog:
yield mock_cog()


Expand Down
2 changes: 1 addition & 1 deletion tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,4 +207,4 @@ async def test_guard_no_login_authenticated_cognito(auth_mock_kwargs: dict[str,
"""Test that not authenticated cognito login raises."""
auth = auth_api.CognitoAuth(MagicMock(**auth_mock_kwargs))
with pytest.raises(auth_api.Unauthenticated):
auth._authenticated_cognito # noqa: B018
await auth._async_authenticated_cognito()

0 comments on commit 61d89b0

Please sign in to comment.