Skip to content

Commit

Permalink
Refactor authentication providers to use inheritance (#856)
Browse files Browse the repository at this point in the history
  • Loading branch information
DiamondJoseph authored Jan 27, 2025
1 parent 5a48730 commit 444515c
Show file tree
Hide file tree
Showing 11 changed files with 72 additions and 74 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Write the date in place of the "Unreleased" in the case a new version is release

- Make depedencies shared by client and server into core dependencies.
- Use schemas for describing server configuration on the client side too.
- Refactored Authentication providers to make use of inheritance, adjusted
mode in the `AboutAuthenticationProvider` schema to be `internal`|`external`.

## v0.1.0-b16 (2024-01-23)

Expand Down
2 changes: 1 addition & 1 deletion docs/source/reference/authentication.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ $ http :8000/api/v1/ | jq .authentication
"providers": [
{
"provider": "toy",
"mode": "password",
"mode": "internal",
"links": {
"auth_endpoint": "http://localhost:8000/api/v1/auth/provider/toy/token"
},
Expand Down
6 changes: 3 additions & 3 deletions example_configs/external_service/custom.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
import numpy

from tiled.adapters.array import ArrayAdapter
from tiled.authenticators import Mode, UserSessionState
from tiled.authenticators import UserSessionState
from tiled.server.protocols import InternalAuthenticator
from tiled.structures.core import StructureFamily


class Authenticator:
class Authenticator(InternalAuthenticator):
"This accepts any password and stashes it in session state as 'token'."
mode = Mode.password

async def authenticate(self, username: str, password: str) -> UserSessionState:
return UserSessionState(username, {"token": password})
Expand Down
80 changes: 41 additions & 39 deletions tiled/authenticators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,47 +5,47 @@
import re
import secrets
from collections.abc import Iterable
from typing import Any, cast
from typing import Any, Mapping, Optional, cast

import httpx
from fastapi import APIRouter, Request
from jose import JWTError, jwt
from pydantic import Secret
from starlette.responses import RedirectResponse

from .server.authentication import Mode
from .server.protocols import UserSessionState
from .server.protocols import (
ExternalAuthenticator,
InternalAuthenticator,
UserSessionState,
)
from .server.utils import get_root_url
from .utils import modules_available

logger = logging.getLogger(__name__)


class DummyAuthenticator:
class DummyAuthenticator(InternalAuthenticator):
"""
For test and demo purposes only!
Accept any username and any password.
"""

mode = Mode.password

def __init__(self, confirmation_message=""):
def __init__(self, confirmation_message: str = ""):
self.confirmation_message = confirmation_message

async def authenticate(self, username: str, password: str) -> UserSessionState:
return UserSessionState(username, {})


class DictionaryAuthenticator:
class DictionaryAuthenticator(InternalAuthenticator):
"""
For test and demo purposes only!
Check passwords from a dictionary of usernames mapped to passwords.
"""

mode = Mode.password
configuration_schema = """
$schema": http://json-schema.org/draft-07/schema#
type: object
Expand All @@ -61,11 +61,15 @@ class DictionaryAuthenticator:
description: May be displayed by client after successful login.
"""

def __init__(self, users_to_passwords, confirmation_message=""):
def __init__(
self, users_to_passwords: Mapping[str, str], confirmation_message: str = ""
):
self._users_to_passwords = users_to_passwords
self.confirmation_message = confirmation_message

async def authenticate(self, username: str, password: str) -> UserSessionState:
async def authenticate(
self, username: str, password: str
) -> Optional[UserSessionState]:
true_password = self._users_to_passwords.get(username)
if not true_password:
# Username is not valid.
Expand All @@ -74,8 +78,7 @@ async def authenticate(self, username: str, password: str) -> UserSessionState:
return UserSessionState(username, {})


class PAMAuthenticator:
mode = Mode.password
class PAMAuthenticator(InternalAuthenticator):
configuration_schema = """
$schema": http://json-schema.org/draft-07/schema#
type: object
Expand All @@ -89,7 +92,7 @@ class PAMAuthenticator:
description: May be displayed by client after successful login.
"""

def __init__(self, service="login", confirmation_message=""):
def __init__(self, service: str = "login", confirmation_message: str = ""):
if not modules_available("pamela"):
raise ModuleNotFoundError(
"This PAMAuthenticator requires the module 'pamela' to be installed."
Expand All @@ -98,20 +101,20 @@ def __init__(self, service="login", confirmation_message=""):
self.confirmation_message = confirmation_message
# TODO Try to open a PAM session.

async def authenticate(self, username: str, password: str) -> UserSessionState:
async def authenticate(
self, username: str, password: str
) -> Optional[UserSessionState]:
import pamela

try:
pamela.authenticate(username, password, service=self.service)
return UserSessionState(username, {})
except pamela.PAMError:
# Authentication failed.
return
else:
return UserSessionState(username, {})


class OIDCAuthenticator:
mode = Mode.external
class OIDCAuthenticator(ExternalAuthenticator):
configuration_schema = """
$schema": http://json-schema.org/draft-07/schema#
type: object
Expand Down Expand Up @@ -178,7 +181,7 @@ def authorization_endpoint(self) -> httpx.URL:
cast(str, self._config_from_oidc_url.get("authorization_endpoint"))
)

async def authenticate(self, request: Request) -> UserSessionState:
async def authenticate(self, request: Request) -> Optional[UserSessionState]:
code = request.query_params["code"]
# A proxy in the middle may make the request into something like
# 'http://localhost:8000/...' so we fix the first part but keep
Expand Down Expand Up @@ -216,11 +219,13 @@ async def authenticate(self, request: Request) -> UserSessionState:
return UserSessionState(verified_body["sub"], {})


class KeyNotFoundError(Exception):
pass


async def exchange_code(token_uri, auth_code, client_id, client_secret, redirect_uri):
async def exchange_code(
token_uri: str,
auth_code: str,
client_id: str,
client_secret: str,
redirect_uri: str,
) -> httpx.Response:
"""Method that talks to an IdP to exchange a code for an access_token and/or id_token
Args:
token_url ([type]): [description]
Expand All @@ -241,14 +246,12 @@ async def exchange_code(token_uri, auth_code, client_id, client_secret, redirect
return response


class SAMLAuthenticator:
mode = Mode.external

class SAMLAuthenticator(ExternalAuthenticator):
def __init__(
self,
saml_settings, # See EXAMPLE_SAML_SETTINGS below.
attribute_name, # which SAML attribute to use as 'id' for Idenity
confirmation_message="",
attribute_name: str, # which SAML attribute to use as 'id' for Idenity
confirmation_message: str = "",
):
self.saml_settings = saml_settings
self.attribute_name = attribute_name
Expand All @@ -268,7 +271,7 @@ def __init__(
from onelogin.saml2.auth import OneLogin_Saml2_Auth

@router.get("/login")
async def saml_login(request: Request):
async def saml_login(request: Request) -> RedirectResponse:
req = await prepare_saml_from_fastapi_request(request)
auth = OneLogin_Saml2_Auth(req, self.saml_settings)
# saml_settings = auth.get_settings()
Expand All @@ -279,12 +282,11 @@ async def saml_login(request: Request):
# else:
# print("Error found on Metadata: %s" % (', '.join(errors)))
callback_url = auth.login()
response = RedirectResponse(url=callback_url)
return response
return RedirectResponse(url=callback_url)

self.include_routers = [router]

async def authenticate(self, request) -> UserSessionState:
async def authenticate(self, request: Request) -> Optional[UserSessionState]:
if not modules_available("onelogin"):
raise ModuleNotFoundError(
"This SAMLAuthenticator requires the module 'oneline' to be installed."
Expand All @@ -310,7 +312,7 @@ async def authenticate(self, request) -> UserSessionState:
return None


async def prepare_saml_from_fastapi_request(request, debug=False):
async def prepare_saml_from_fastapi_request(request: Request) -> Mapping[str, str]:
form_data = await request.form()
rv = {
"http_host": request.client.host,
Expand All @@ -336,7 +338,7 @@ async def prepare_saml_from_fastapi_request(request, debug=False):
return rv


class LDAPAuthenticator:
class LDAPAuthenticator(InternalAuthenticator):
"""
The authenticator code is based on https://github.com/jupyterhub/ldapauthenticator
The parameter ``use_tls`` was added for convenience of testing.
Expand Down Expand Up @@ -519,8 +521,6 @@ class LDAPAuthenticator:
id: user02
"""

mode = Mode.password

def __init__(
self,
server_address,
Expand Down Expand Up @@ -733,7 +733,9 @@ async def get_user_attributes(self, conn, userdn):
attrs = conn.entries[0].entry_attributes_as_dict
return attrs

async def authenticate(self, username: str, password: str) -> UserSessionState:
async def authenticate(
self, username: str, password: str
) -> Optional[UserSessionState]:
import ldap3

username_saved = username # Save the user name passed as a parameter
Expand Down
2 changes: 1 addition & 1 deletion tiled/client/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def prompt_for_credentials(http_client, providers: List[AboutAuthenticationProvi
auth_endpoint = spec.links["auth_endpoint"]
provider = spec.provider
mode = spec.mode
if mode == "password":
if mode == "internal":
# Prompt for username, password at terminal.
username = username_input()
PASSWORD_ATTEMPTS = 3
Expand Down
2 changes: 1 addition & 1 deletion tiled/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

class AboutAuthenticationProvider(BaseModel):
provider: str
mode: Literal["password", "external"]
mode: Literal["internal", "external"]
links: Dict[str, str]
confirmation_message: Optional[str] = None

Expand Down
10 changes: 5 additions & 5 deletions tiled/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
HTTP_500_INTERNAL_SERVER_ERROR,
)

from ..authenticators import Mode
from tiled.server.protocols import ExternalAuthenticator, InternalAuthenticator

from ..config import construct_build_app_kwargs
from ..media_type_registration import (
compression_registry as default_compression_registry,
Expand Down Expand Up @@ -384,12 +385,11 @@ async def unhandled_exception_handler(
for spec in authentication["providers"]:
provider = spec["provider"]
authenticator = spec["authenticator"]
mode = authenticator.mode
if mode == Mode.password:
if isinstance(authenticator, InternalAuthenticator):
authentication_router.post(f"/provider/{provider}/token")(
build_handle_credentials_route(authenticator, provider)
)
elif mode == Mode.external:
elif isinstance(authenticator, ExternalAuthenticator):
# Client starts here to create a PendingSession.
authentication_router.post(f"/provider/{provider}/authorize")(
build_device_code_authorize_route(authenticator, provider)
Expand All @@ -414,7 +414,7 @@ async def unhandled_exception_handler(
# build_auth_code_route(authenticator, provider)
# )
else:
raise ValueError(f"unknown authentication mode {mode}")
raise ValueError(f"unknown authenticator type {type(authenticator)}")
for custom_router in getattr(authenticator, "include_routers", []):
authentication_router.include_router(
custom_router, prefix=f"/provider/{provider}"
Expand Down
12 changes: 2 additions & 10 deletions tiled/server/authentication.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import enum
import hashlib
import secrets
import uuid as uuid_module
Expand Down Expand Up @@ -62,7 +61,7 @@
from ..utils import SHARE_TILED_PATH, SpecialUsers
from . import schemas
from .core import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE, json_or_msgpack
from .protocols import UsernamePasswordAuthenticator, UserSessionState
from .protocols import InternalAuthenticator, UserSessionState
from .settings import get_settings
from .utils import API_KEY_COOKIE_NAME, get_authenticators, get_base_url

Expand All @@ -86,11 +85,6 @@ def utcnow():
return datetime.now(timezone.utc).replace(microsecond=0)


class Mode(enum.Enum):
password = "password"
external = "external"


class Token(BaseModel):
access_token: str
token_type: str
Expand Down Expand Up @@ -710,9 +704,7 @@ async def route(
return route


def build_handle_credentials_route(
authenticator: UsernamePasswordAuthenticator, provider
):
def build_handle_credentials_route(authenticator: InternalAuthenticator, provider):
"Register a handle_credentials route function for this Authenticator."

async def route(
Expand Down
17 changes: 9 additions & 8 deletions tiled/server/protocols.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
from abc import ABC
from dataclasses import dataclass
from typing import Protocol
from typing import Optional

from fastapi import Request


@dataclass
class UserSessionState:
"""Data transfer class to communicate custom session state infromation."""
"""Data transfer class to communicate custom session state information."""

user_name: str
state: dict = None


class UsernamePasswordAuthenticator(Protocol):
def authenticate(self, username: str, password: str) -> UserSessionState:
pass
class InternalAuthenticator(ABC):
def authenticate(self, username: str, password: str) -> Optional[UserSessionState]:
raise NotImplementedError


class Authenticator(Protocol):
def authenticate(self, request: Request) -> UserSessionState:
pass
class ExternalAuthenticator(ABC):
def authenticate(self, request: Request) -> Optional[UserSessionState]:
raise NotImplementedError
Loading

0 comments on commit 444515c

Please sign in to comment.