From dc1abd279950904397a29e7b8cfd449279649b14 Mon Sep 17 00:00:00 2001 From: Daniel McKnight Date: Tue, 29 Oct 2024 18:41:34 -0700 Subject: [PATCH] Update `client_manager` to use `mq_connector` for authentication via `neon-users-service` Update tokens to include more data, maintaining backwards-compat and adding `TokenConfig` compat. Update tokens for Klat token compat Update permissions handling to respect user configuration values Update auth request to include token_name for User database integration Add UserProfile.from_user_config for database compat. Update MQ connector to integrate with users service --- neon_hana/app/dependencies.py | 2 +- neon_hana/auth/client_manager.py | 82 ++++++++++++++++++++++--------- neon_hana/mq_service_api.py | 63 +++++++++++++++++++++++- neon_hana/mq_websocket_api.py | 2 +- neon_hana/schema/auth_requests.py | 5 +- neon_hana/schema/user_profile.py | 62 +++++++++++++++++++++++ requirements/requirements.txt | 4 +- tests/test_schema.py | 18 +++++++ 8 files changed, 210 insertions(+), 28 deletions(-) create mode 100644 tests/test_schema.py diff --git a/neon_hana/app/dependencies.py b/neon_hana/app/dependencies.py index 0c9dcf5..e6e4726 100644 --- a/neon_hana/app/dependencies.py +++ b/neon_hana/app/dependencies.py @@ -31,5 +31,5 @@ config = Configuration().get("hana") or dict() mq_connector = MQServiceManager(config) -client_manager = ClientManager(config) +client_manager = ClientManager(config, mq_connector) jwt_bearer = UserTokenAuth(client_manager) diff --git a/neon_hana/auth/client_manager.py b/neon_hana/auth/client_manager.py index 578ea13..1e89529 100644 --- a/neon_hana/auth/client_manager.py +++ b/neon_hana/auth/client_manager.py @@ -37,10 +37,12 @@ from token_throttler.storage import RuntimeStorage from neon_hana.auth.permissions import ClientPermissions +from neon_hana.mq_service_api import MQServiceManager +from neon_users_service.models import User, AccessRoles, TokenConfig class ClientManager: - def __init__(self, config: dict): + def __init__(self, config: dict, mq_connector: MQServiceManager): self.rate_limiter = TokenThrottler(cost=1, storage=RuntimeStorage()) self.authorized_clients: Dict[str, dict] = dict() @@ -58,8 +60,9 @@ def __init__(self, config: dict): self._jwt_algo = "HS256" self._connected_streams = 0 self._stream_check_lock = Lock() + self._mq_connector = mq_connector - def _create_tokens(self, encode_data: dict) -> dict: + def _create_tokens(self, encode_data: dict) -> TokenConfig: # Permissions were not included in old tokens, allow refreshing with # default permissions encode_data.setdefault("permissions", ClientPermissions().as_dict()) @@ -69,13 +72,14 @@ def _create_tokens(self, encode_data: dict) -> dict: encode_data['expire'] = time() + self._refresh_token_lifetime encode_data['access_token'] = token refresh = jwt.encode(encode_data, self._refresh_secret, self._jwt_algo) - # TODO: Store refresh token on server to allow invalidating clients - return {"username": encode_data['username'], - "client_id": encode_data['client_id'], - "permissions": encode_data['permissions'], - "access_token": token, - "refresh_token": refresh, - "expiration": token_expiration} + return TokenConfig(**{"username": encode_data['username'], + "client_id": encode_data['client_id'], + "permissions": encode_data['permissions'], + "access_token": token, + "refresh_token": refresh, + "expiration": token_expiration, + "token_name": encode_data['name'], + "refresh_expiration": encode_data['expire']}) def get_permissions(self, client_id: str) -> ClientPermissions: """ @@ -114,6 +118,7 @@ def disconnect_stream(self): def check_auth_request(self, client_id: str, username: str, password: Optional[str] = None, + token_name: Optional[str] = None, origin_ip: str = "127.0.0.1") -> dict: """ Authenticate and Authorize a new client connection with the specified @@ -121,6 +126,7 @@ def check_auth_request(self, client_id: str, username: str, @param client_id: Client ID of the connection to auth @param username: Supplied username to authenticate @param password: Supplied password to authenticate + @param token_name: Token name to add to user database @param origin_ip: Origin IP address of request @return: response tokens, permissions, and other metadata """ @@ -142,23 +148,40 @@ def check_auth_request(self, client_id: str, username: str, detail=f"Too many auth requests from: " f"{origin_ip}. Wait {wait_time}s.") - node_access = False - if username != "guest": - # TODO: Validate password here - pass - if all((self._node_username, username == self._node_username, - password == self._node_password)): - node_access = True - permissions = ClientPermissions(node=node_access) - expiration = time() + self._access_token_lifetime + # TODO: disable "guest" access? + if username == "guest": + user = User(username=username, password=password) + elif all((self._node_username, username == self._node_username, + password == self._node_password)): + user = User(username=username, password=password) + user.permissions.node = AccessRoles.USER + else: + user = self._mq_connector.get_user_profile(username, password) + username = user.username + password = user.password_hash + + # Boolean permissions allow access for any role, including `NODE`. + # Specific endpoints may enforce more granular controls/limits based on + # specific user.permissions values. + permissions = ClientPermissions( + node=user.permissions.node != AccessRoles.NONE, + assist=user.permissions.core != AccessRoles.NONE, + backend=user.permissions.diana != AccessRoles.NONE) + create_time = time() + expiration = create_time + self._access_token_lifetime encode_data = {"client_id": client_id, + "sub": username, # Added for Klat token compat. + "name": token_name, "username": username, "password": password, "permissions": permissions.as_dict(), - "expire": expiration} + "create": create_time, + "expire": expiration, + "last_refresh_timestamp": create_time} auth = self._create_tokens(encode_data) - self.authorized_clients[client_id] = auth - return auth + self._add_token_to_userdb(user, auth) + self.authorized_clients[client_id] = auth.model_dump() + return auth.model_dump() def check_refresh_request(self, access_token: str, refresh_token: str, client_id: str): @@ -185,9 +208,22 @@ def check_refresh_request(self, access_token: str, refresh_token: str, detail="Access token does not match client_id") encode_data = {k: token_data[k] for k in ("client_id", "username", "password")} - encode_data["expire"] = time() + self._access_token_lifetime + refresh_time = time() + encode_data['last_refresh_timestamp'] = refresh_time + encode_data["expire"] = refresh_time + self._access_token_lifetime new_auth = self._create_tokens(encode_data) - return new_auth + user = self._mq_connector.get_user_profile(username=token_data['username'], + password=token_data['password']) + self._add_token_to_userdb(user, new_auth) + return new_auth.model_dump() + + def _add_token_to_userdb(self, user: User, token_data: TokenConfig): + # Enforce unique `creation_timestamp` values to avoid duplicate entries + for idx, token in enumerate(user.tokens): + if token.creation_timestamp == token_data.creation_timestamp: + user.tokens.remove(token) + user.tokens.append(token_data) + self._mq_connector.update_user(user) def get_client_id(self, token: str) -> str: """ diff --git a/neon_hana/mq_service_api.py b/neon_hana/mq_service_api.py index 032e800..3dfcf93 100644 --- a/neon_hana/mq_service_api.py +++ b/neon_hana/mq_service_api.py @@ -27,13 +27,14 @@ import json from time import time -from typing import Optional, Dict, Any, List +from typing import Optional, Dict, Any, List, Union from uuid import uuid4 from fastapi import HTTPException from neon_hana.schema.node_model import NodeData from neon_hana.schema.user_profile import UserProfile from neon_mq_connector.utils.client_utils import send_mq_request +from neon_users_service.models import User class APIError(HTTPException): @@ -77,6 +78,29 @@ def _validate_api_proxy_response(response: dict, query_params: dict): code = response['status_code'] if response['status_code'] > 200 else 500 raise APIError(status_code=code, detail=response['content']) + @staticmethod + def _query_users_api(operation: str, username: Optional[str] = None, + password: Optional[str] = None, + user: Optional[User] = None) -> (bool, Union[User, int, str]): + """ + Query the users API and return a status code and either a valid User or + a string error message + @param operation: Operation to perform (create, read, update, delete) + @param username: Optional username to include + @param password: Optional password to include + @param user: Optional user object to include + @return: success bool, User object or string error message + """ + response = send_mq_request("/neon_users", + {"operation": operation, + "username": username, + "password": password, + "user": user}, + "neon_users_input") + if response.get("success"): + return True, 200, response.get("user") + return False, response.get("code", 500), response.get("error", "") + def get_session(self, node_data: NodeData) -> dict: """ Get a serialized Session object for the specified Node. @@ -89,6 +113,43 @@ def get_session(self, node_data: NodeData) -> dict: "site_id": node_data.location.site_id}) return self.sessions_by_id[session_id] + def get_user_profile(self, username: str, password: str) -> User: + """ + Get a User object for a user. This requires that a valid password be + provided to prevent arbitrary users from reading private profile info. + @param username: Valid username to get a User object for + @param password: Valid password for the input username + @returns: User object from the Users service. + """ + stat, code, err_or_user = self._query_users_api("read", + username=username, + password=password) + if not stat: + raise HTTPException(status_code=code, detail=err_or_user) + return err_or_user + + def create_user(self, user: User) -> User: + """ + Create a new user. + @param user: User object to add to the users service database + @returns: User object added to the database + """ + stat, code, err_or_user = self._query_users_api("create", user=user) + if not stat: + raise HTTPException(status_code=code, detail=err_or_user) + return err_or_user + + def update_user(self, user: User) -> User: + """ + Update an existing user in the database. + @param user: Updated user object to write + @returns: User as read from the database + """ + stat, code, err_or_user = self._query_users_api("update", user=user) + if not stat: + raise HTTPException(status_code=code, detail=err_or_user) + return err_or_user + def query_api_proxy(self, service_name: str, query_params: dict, timeout: int = 10): query_params['service'] = service_name diff --git a/neon_hana/mq_websocket_api.py b/neon_hana/mq_websocket_api.py index 6b9b96e..d9d0bae 100644 --- a/neon_hana/mq_websocket_api.py +++ b/neon_hana/mq_websocket_api.py @@ -33,7 +33,7 @@ from neon_iris.client import NeonAIClient from ovos_bus_client.message import Message from threading import RLock -from ovos_utils import LOG +from ovos_utils.log import LOG class ClientNotKnown(RuntimeError): diff --git a/neon_hana/schema/auth_requests.py b/neon_hana/schema/auth_requests.py index d02724d..c889639 100644 --- a/neon_hana/schema/auth_requests.py +++ b/neon_hana/schema/auth_requests.py @@ -24,6 +24,7 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +from datetime import datetime from typing import Optional from uuid import uuid4 @@ -33,13 +34,15 @@ class AuthenticationRequest(BaseModel): username: str = "guest" password: Optional[str] = None + token_name: str = Field(default_factory=lambda: datetime.utcnow().isoformat()) client_id: str = Field(default_factory=lambda: str(uuid4())) model_config = { "json_schema_extra": { "examples": [{ "username": "guest", - "password": "password" + "password": "password", + "token_name": "My Client" }]}} diff --git a/neon_hana/schema/user_profile.py b/neon_hana/schema/user_profile.py index 91a9a05..2c8bdfb 100644 --- a/neon_hana/schema/user_profile.py +++ b/neon_hana/schema/user_profile.py @@ -24,9 +24,14 @@ # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import pytz +import datetime + from typing import Optional, List from pydantic import BaseModel +from neon_users_service.models import User + class ProfileUser(BaseModel): first_name: str = "" @@ -102,3 +107,60 @@ class UserProfile(BaseModel): location: ProfileLocation = ProfileLocation() response_mode: ProfileResponseMode = ProfileResponseMode() privacy: ProfilePrivacy = ProfilePrivacy() + + @classmethod + def from_user_config(cls, user: User): + user_config = user.neon + today = datetime.date.today() + if user_config.user.dob: + dob = user_config.user.dob + age = today.year - dob.year - ( + (today.month, today.day) < (dob.month, dob.day)) + dob = dob.strftime("%Y/%m/%d") + else: + age = "" + dob = "YYYY/MM/DD" + full_name = " ".join((n for n in (user_config.user.first_name, + user_config.user.middle_name, + user_config.user.last_name) if n)) + user = ProfileUser(about=user_config.user.about, + age=age, dob=dob, + email=user_config.user.email, + email_verified=user_config.user.email_verified, + first_name=user_config.user.first_name, + full_name=full_name, + last_name=user_config.user.last_name, + middle_name=user_config.user.middle_name, + password=user.password_hash or "", + phone=user_config.user.phone, + phone_verified=user_config.user.phone_verified, + picture=user_config.user.avatar_url, + preferred_name=user_config.user.preferred_name, + username=user.username + ) + alt_stt = [lang.split('-')[0] for lang in + user_config.language.input_languages[1:]] + secondary_tts_lang = user_config.language.output_languages[1] if ( + len(user_config.language.output_languages) > 1) else None + speech = ProfileSpeech( + alt_langs=alt_stt, + secondary_tts_gender=user_config.response_mode.tts_gender, + secondary_tts_language=secondary_tts_lang, + speed_multiplier=user_config.response_mode.tts_speed_multiplier, + stt_language=user_config.language.input_languages[0].split('-')[0], + tts_gender=user_config.response_mode.tts_gender, + tts_language=user_config.language.output_languages[0]) + units = ProfileUnits(**user_config.units.model_dump()) + utc_hours = (pytz.timezone(user_config.location.timezone or "UTC") + .utcoffset(datetime.datetime.now()).total_seconds() / 3600) + # TODO: Get city, state, country from lat/lon + location = ProfileLocation(lat=user_config.location.latitude, + lng=user_config.location.longitude, + tz=user_config.location.timezone, + utc=utc_hours) + response_mode = ProfileResponseMode(**user_config.response_mode.model_dump()) + privacy = ProfilePrivacy(**user_config.privacy.model_dump()) + + return UserProfile(location=location, privacy=privacy, + response_mode=response_mode, speech=speech, + units=units, user=user) diff --git a/requirements/requirements.txt b/requirements/requirements.txt index 0e5f1c8..b8e34bd 100644 --- a/requirements/requirements.txt +++ b/requirements/requirements.txt @@ -5,4 +5,6 @@ pydantic~=2.5 pyjwt~=2.8 token-throttler~=1.4 neon-mq-connector~=0.7 -ovos-config~=0.0.12 \ No newline at end of file +ovos-config~=0.0,>=0.0.12 +ovos-utils~=0.0,>=0.0.38 +neon-users-service@git+https://github.com/neongeckocom/neon-users-service@FEAT_InitialImplementation \ No newline at end of file diff --git a/tests/test_schema.py b/tests/test_schema.py new file mode 100644 index 0000000..bc29085 --- /dev/null +++ b/tests/test_schema.py @@ -0,0 +1,18 @@ +from unittest import TestCase + +from neon_hana.schema.user_profile import UserProfile + +from neon_users_service.models import User + + +class TestUserProfile(TestCase): + def test_user_profile(self): + # Test default + profile = UserProfile() + self.assertIsInstance(profile, UserProfile) + + # Test from User + default_user = User(username="test_user") + profile = UserProfile.from_user_config(default_user) + self.assertIsInstance(profile, UserProfile) + self.assertEqual(default_user.username, profile.user.username)