Skip to content
This repository has been archived by the owner on Feb 22, 2024. It is now read-only.

Add role check #118

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 36 additions & 5 deletions flask_jwt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
'JWT_AUTH_USERNAME_KEY': 'username',
'JWT_AUTH_PASSWORD_KEY': 'password',
'JWT_ALGORITHM': 'HS256',
'JWT_ROLE': 'role',
'JWT_LEEWAY': timedelta(seconds=10),
'JWT_AUTH_HEADER_PREFIX': 'JWT',
'JWT_EXPIRATION_DELTA': timedelta(seconds=300),
Expand Down Expand Up @@ -141,7 +142,21 @@ def _default_jwt_error_handler(error):
])), error.status_code, error.headers


def _jwt_required(realm):
def _force_iterable(input):
"""If role is just a string, force it to an array.
"""
try:
basestring
except NameError:
basestring = str
if isinstance(input, basestring):
return [input]
if not hasattr(input, "__iter__"):
return [input]
return input


def _jwt_required(realm, roles):
"""Does the actual work of verifying the JWT data in the current request.
This is done automatically for you by `jwt_required()` but you could call it manually.
Doing so would be useful in the context of optional JWT access in your APIs.
Expand All @@ -163,17 +178,33 @@ def _jwt_required(realm):

if identity is None:
raise JWTError('Invalid JWT', 'User does not exist')


def jwt_required(realm=None):
if roles:
try:
identity_role = getattr(identity, current_app.config['JWT_ROLE'])
except AttributeError:
try:
identity_role = identity.get(current_app.config['JWT_ROLE'])
except AttributeError:
raise JWTError('Bad Request', 'Invalid credentials')
if not identity_role:
raise JWTError('Bad Request', 'Invalid credentials')
identity_role = _force_iterable(identity_role)
roles = _force_iterable(roles)
if not identity_role or not set(roles).intersection(identity_role):
raise JWTError('Bad Request', 'Invalid credentials')


def jwt_required(realm=None, roles=None):
"""View decorator that requires a valid JWT token to be present in the request

:param realm: an optional realm
:param roles: an optional list of roles allowed,
the role is pick in JWT_ROLE field of identity
"""
def wrapper(fn):
@wraps(fn)
def decorator(*args, **kwargs):
_jwt_required(realm or current_app.config['JWT_DEFAULT_REALM'])
_jwt_required(realm or current_app.config['JWT_DEFAULT_REALM'], roles)
return fn(*args, **kwargs)
return decorator
return wrapper
Expand Down
128 changes: 127 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,20 @@
import pytest

from flask import Flask
from datetime import datetime, timedelta

import flask_jwt

logging.basicConfig(level=logging.DEBUG)


class User(object):
def __init__(self, id, username, password):
def __init__(self, id, username, password, role=None):
self.id = id
self.username = username
self.password = password
if role:
self.role = role

def __str__(self):
return "User(id='%s')" % self.id
Expand All @@ -37,6 +40,16 @@ def user():
return User(id=1, username='joe', password='pass')


@pytest.fixture(scope='function')
def user_with_role():
return User(id=2, username='jane', password='pass', role='user')


@pytest.fixture(scope='function')
def user_with_roles():
return User(id=3, username='alice', password='pass', role=['user', 'foo', 'bar'])


@pytest.fixture(scope='function')
def app(jwt, user):
app = Flask(__name__)
Expand Down Expand Up @@ -64,6 +77,119 @@ def protected():
return app


@pytest.fixture(scope='function')
def app_with_role(jwt, user, user_with_role, user_with_roles):
app = Flask(__name__)
app.debug = True
app.config['SECRET_KEY'] = 'super-secret'
users = [user, user_with_role, user_with_roles]

@jwt.authentication_handler
def authenticate(username, password):
for u in users:
if username == u.username and password == u.password:
return u
return None

@jwt.identity_handler
def load_user(payload):
for u in users:
if payload['identity'] == u.id:
return u

@jwt.jwt_payload_handler
def make_payload(identity):
iat = datetime.utcnow()
exp = iat + timedelta(seconds=300)
nbf = iat
id = getattr(identity, 'id')
try:
role = getattr(identity, 'role')
return {'exp': exp, 'iat': iat, 'nbf': nbf, 'identity': id, 'role': role}
except AttributeError:
return {'exp': exp, 'iat': iat, 'nbf': nbf, 'identity': id}

jwt.init_app(app)

@app.route('/protected')
@flask_jwt.jwt_required()
def protected():
return 'success'

@app.route('/role/protected/admin')
@flask_jwt.jwt_required(roles='admin')
def admin_protected():
return 'success'

@app.route('/role/protected/multi')
@flask_jwt.jwt_required(roles=['admin', 'user'])
def admin_user_protected():
return 'success'

@app.route('/role/protected/user')
@flask_jwt.jwt_required(roles='user')
def user_protected():
return 'success'

return app


@pytest.fixture(scope='function')
def app_with_role_trust_jwt(jwt, user, user_with_role, user_with_roles):
app = Flask(__name__)
app.debug = True
app.config['SECRET_KEY'] = 'super-secret'
app.config['JWT_ROLE'] = 'my_role'
users = [user, user_with_role, user_with_roles]

@jwt.authentication_handler
def authenticate(username, password):
for u in users:
if username == u.username and password == u.password:
return u
return None

@jwt.identity_handler
def load_user(payload):
return payload

@jwt.jwt_payload_handler
def make_payload(identity):
iat = datetime.utcnow()
exp = iat + timedelta(seconds=300)
nbf = iat
id = getattr(identity, 'id')
try:
role = getattr(identity, 'role')
return {'exp': exp, 'iat': iat, 'nbf': nbf, 'identity': id, 'my_role': role}
except AttributeError:
return {'exp': exp, 'iat': iat, 'nbf': nbf, 'identity': id}

jwt.init_app(app)

@app.route('/protected')
@flask_jwt.jwt_required()
def protected():
return 'success'

@app.route('/role/protected/user')
@flask_jwt.jwt_required(roles='user')
def user_protected():
return 'success'

@app.route('/role/protected/multi')
@flask_jwt.jwt_required(roles=['admin', 'user'])
def admin_user_protected():
return 'success'

@app.route('/role/protected/admin')
@flask_jwt.jwt_required(roles='admin')
def admin_protected():
return 'success'

return app


@pytest.fixture(scope='function')
def client(app):
return app.test_client()
143 changes: 143 additions & 0 deletions tests/test_jwt.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,3 +291,146 @@ def custom_auth_request_handler():
with app.test_client() as c:
resp, jdata = post_json(c, '/auth', {})
assert jdata == {'hello': 'world'}


def test_role_required(app_with_role, user_with_role):
with app_with_role.test_client() as c:
resp, jdata = post_json(
c, '/auth', {'username': user_with_role.username, 'password': user_with_role.password})
token = jdata['access_token']

# check if protected works with role set but not asked for this path
resp = c.get('/protected', headers={'authorization': 'JWT ' + token})
assert resp.status_code == 200
assert resp.data == b'success'

# check if protected works wit role set but not asked for this path
resp = c.get('/role/protected/user', headers={'Authorization': 'JWT ' + token})

assert resp.status_code == 200
assert resp.data == b'success'


def test_role_required_bad(app_with_role, user, user_with_role):
with app_with_role.test_client() as c:

# test bad role
resp, jdata = post_json(
c, '/auth', {'username': user_with_role.username, 'password': user_with_role.password})

token = jdata['access_token']
resp = c.get('/role/protected/admin', headers={'Authorization': 'JWT ' + token})

assert resp.status_code == 401

# test no role
resp, jdata = post_json(
c, '/auth', {'username': user.username, 'password': user.password})

token = jdata['access_token']
resp = c.get('/role/protected/admin', headers={'Authorization': 'JWT ' + token})

assert resp.status_code == 401


def test_role_required_multi(app_with_role, user_with_roles):
with app_with_role.test_client() as c:
resp, jdata = post_json(c, '/auth', {'username': user_with_roles.username,
'password': user_with_roles.password})
token = jdata['access_token']

# check if protected works with role set but not asked for this path
resp = c.get('/protected', headers={'authorization': 'JWT ' + token})
assert resp.status_code == 200
assert resp.data == b'success'

resp = c.get('/role/protected/user', headers={'Authorization': 'JWT ' + token})

assert resp.status_code == 200
assert resp.data == b'success'


def test_role_required_multi_bad(app_with_role, user_with_roles):
with app_with_role.test_client() as c:
resp, jdata = post_json(c, '/auth', {'username': user_with_roles.username,
'password': user_with_roles.password})

token = jdata['access_token']
resp = c.get('/role/protected/admin', headers={'Authorization': 'JWT ' + token})

assert resp.status_code == 401


def test_multirole_required_multi(app_with_role, user, user_with_roles):
with app_with_role.test_client() as c:
resp, jdata = post_json(c, '/auth', {'username': user_with_roles.username,
'password': user_with_roles.password})
token = jdata['access_token']

# check if protected works with role set but not asked for this path
resp = c.get('/protected', headers={'authorization': 'JWT ' + token})
assert resp.status_code == 200
assert resp.data == b'success'

resp = c.get('/role/protected/multi', headers={'Authorization': 'JWT ' + token})

assert resp.status_code == 200
assert resp.data == b'success'

# test no role
resp, jdata = post_json(
c, '/auth', {'username': user.username, 'password': user.password})

token = jdata['access_token']
resp = c.get('/role/protected/multi', headers={'Authorization': 'JWT ' + token})

assert resp.status_code == 401


def test_role_custom(app_with_role_trust_jwt, user, user_with_role, user_with_roles):
with app_with_role_trust_jwt.test_client() as c:
resp, jdata = post_json(c, '/auth', {'username': user_with_role.username,
'password': user_with_role.password})
token = jdata['access_token']

# check if protected works with role set but not asked for this path
resp = c.get('/protected', headers={'authorization': 'JWT ' + token})
assert resp.status_code == 200
assert resp.data == b'success'

# check unauthorized role protection
resp = c.get('/role/protected/admin', headers={'Authorization': 'JWT ' + token})

assert resp.status_code == 401

resp = c.get('/role/protected/multi', headers={'Authorization': 'JWT ' + token})

assert resp.status_code == 200
assert resp.data == b'success'

resp = c.get('/role/protected/user', headers={'Authorization': 'JWT ' + token})

assert resp.status_code == 200
assert resp.data == b'success'

resp, jdata = post_json(c, '/auth', {'username': user_with_roles.username,
'password': user_with_roles.password})
token = jdata['access_token']

# check if protected works with role set but not asked for this path
resp = c.get('/protected', headers={'authorization': 'JWT ' + token})
assert resp.status_code == 200
assert resp.data == b'success'

resp = c.get('/role/protected/multi', headers={'Authorization': 'JWT ' + token})

assert resp.status_code == 200
assert resp.data == b'success'
# test no role
resp, jdata = post_json(
c, '/auth', {'username': user.username, 'password': user.password})

token = jdata['access_token']
resp = c.get('/role/protected/multi', headers={'Authorization': 'JWT ' + token})

assert resp.status_code == 401