Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AzureAD] Add configuration: allowed_groups, admin_groups #466

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions oauthenticator/azuread.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from jupyterhub.auth import LocalAuthenticator
from tornado.httpclient import HTTPRequest
from traitlets import default
from traitlets import List
from traitlets import Unicode

from .oauth2 import OAuthenticator
Expand Down Expand Up @@ -50,6 +51,22 @@ def _token_url_default(self):
self.tenant_id
)

allowed_groups = List(
Unicode(),
config=True,
help="Automatically allow members of selected groups",
)

admin_groups = List(
Unicode(),
config=True,
help="Groups whose members should have Jupyterhub admin privileges",
)

@staticmethod
def check_user_in_groups(member_groups, allowed_groups):
return bool(set(member_groups) & set(allowed_groups))

async def authenticate(self, handler, data=None):
code = handler.get_argument("code")

Expand Down Expand Up @@ -94,6 +111,15 @@ async def authenticate(self, handler, data=None):
# results in a decoded JWT for the user data
auth_state['user'] = decoded

groups = self.allowed_groups + self.admin_groups
if groups:
ad_groups = decoded.get('groups')
if self.check_user_in_groups(ad_groups, groups):
userdict['admin'] = self.check_user_in_groups(
ad_groups, self.admin_groups
)
else:
userdict = None
return userdict


Expand Down
110 changes: 88 additions & 22 deletions oauthenticator/tests/test_azuread.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import re
import time
import uuid
from functools import partial
from unittest import mock

import jwt
import pytest
from pytest import fixture

from ..azuread import AzureAdOAuthenticator
from .mocks import setup_oauth_mock
Expand All @@ -19,26 +21,30 @@ def test_tenant_id_from_env():
assert aad.tenant_id == tenant_id


def user_model(tenant_id, client_id, name):
def user_model(tenant_id, client_id, name, **kwargs):
"""Return a user model"""
# model derived from https://docs.microsoft.com/en-us/azure/active-directory/develop/id-tokens#v20
now = int(time.time())

user = {
"ver": "2.0",
"iss": f"https://login.microsoftonline.com/{tenant_id}/v2.0",
"sub": "AAAAAAAAAAAAAAAAAAAAAIkzqFVrSaSaFHy782bbtaQ",
"aud": client_id,
"exp": now + 3600,
"iat": now,
"nbf": now,
"name": name,
"preferred_username": name,
"oid": str(uuid.uuid1()),
"tid": tenant_id,
"nonce": "123523",
"aio": "Df2UVXL1ix!lMCWMSOJBcFatzcGfvFGhjKv8q5g0x732dR5MB5BisvGQO7YWByjd8iQDLq!eGbIDakyp5mnOrcdqHeYSnltepQmRp6AIZ8jY",
}
user.update(kwargs)

id_token = jwt.encode(
{
"ver": "2.0",
"iss": f"https://login.microsoftonline.com/{tenant_id}/v2.0",
"sub": "AAAAAAAAAAAAAAAAAAAAAIkzqFVrSaSaFHy782bbtaQ",
"aud": client_id,
"exp": now + 3600,
"iat": now,
"nbf": now,
"name": name,
"preferred_username": name,
"oid": str(uuid.uuid1()),
"tid": tenant_id,
"nonce": "123523",
"aio": "Df2UVXL1ix!lMCWMSOJBcFatzcGfvFGhjKv8q5g0x732dR5MB5BisvGQO7YWByjd8iQDLq!eGbIDakyp5mnOrcdqHeYSnltepQmRp6AIZ8jY",
},
user,
os.urandom(5),
).decode("ascii")

Expand All @@ -48,6 +54,15 @@ def user_model(tenant_id, client_id, name):
}


def _get_authenticator(**kwargs):
return AzureAdOAuthenticator(
tenant_id=str(uuid.uuid1()),
client_id=str(uuid.uuid1()),
client_secret=str(uuid.uuid1()),
**kwargs,
)


@pytest.fixture
def azure_client(client):
setup_oauth_mock(
Expand All @@ -59,6 +74,11 @@ def azure_client(client):
return client


@fixture
def get_authenticator(azure_client, **kwargs):
return partial(_get_authenticator, http_client=azure_client)


@pytest.mark.parametrize(
'username_claim',
[
Expand All @@ -68,12 +88,8 @@ def azure_client(client):
'preferred_username',
],
)
async def test_azuread(username_claim, azure_client):
authenticator = AzureAdOAuthenticator(
tenant_id=str(uuid.uuid1()),
client_id=str(uuid.uuid1()),
client_secret=str(uuid.uuid1()),
)
async def test_azuread(get_authenticator, username_claim, azure_client):
authenticator = get_authenticator()
if username_claim:
authenticator.username_claim = username_claim

Expand All @@ -95,3 +111,53 @@ async def test_azuread(username_claim, azure_client):

name = user_info['name']
assert name == jwt_user[authenticator.username_claim]


@pytest.mark.parametrize(
'allowed_groups,admin_groups,azuread_groups,expected',
[
(
[],
['jupyterhub-admin'],
['jupyterhub-admin'],
lambda r: bool(r) and r['admin'],
),
([], ['jupyterhub-admin'], ['jupyter-admin'], lambda r: not bool(r)),
(['jupyterhub'], [], ['jupyterhub'], lambda r: bool(r) and not r['admin']),
(['jupyterhub'], [], ['jupyter'], lambda r: not bool(r)),
([], [], ['jupyterhub'], lambda r: bool(r)),
(
['jupyterhub'],
['jupyterhub-admin'],
['jupyterhub', 'jupyterhub-admin'],
lambda r: bool(r) and r['admin'],
),
(['jupyterhub'], [], [], lambda r: not bool(r)),
([], [], [], lambda r: bool(r) and r.get('admin') is None),
],
)
async def test_azuread_groups(
get_authenticator,
azure_client,
allowed_groups,
admin_groups,
azuread_groups,
expected,
):
authenticator = get_authenticator(
scope=['openid', 'profile'],
allowed_groups=allowed_groups,
admin_groups=admin_groups,
)

handler = azure_client.handler_for_user(
user_model(
tenant_id=authenticator.tenant_id,
client_id=authenticator.client_id,
name="somebody",
groups=azuread_groups,
)
)

r = await authenticator.authenticate(handler)
assert expected(r)