Skip to content

Commit

Permalink
Allow overriding of unauthorized callback.
Browse files Browse the repository at this point in the history
Related to issue pallets-eco#255.
  • Loading branch information
nfvs committed May 2, 2015
1 parent 8a62b5f commit 10fd184
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 8 deletions.
6 changes: 5 additions & 1 deletion flask_security/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ def _get_state(app, datastore, **kwargs):
reset_serializer=_get_serializer(app, 'reset'),
confirm_serializer=_get_serializer(app, 'confirm'),
_context_processors={},
_send_mail_task=None
_send_mail_task=None,
_unauthorized_callback=None
))

for key, value in _default_forms.items():
Expand Down Expand Up @@ -381,6 +382,9 @@ def mail_context_processor(self, fn):
def send_mail_task(self, fn):
self._send_mail_task = fn

def unauthorized_handler(self, fn):
self._unauthorized_callback = fn


class Security(object):
"""The :class:`Security` class initializes the Flask-Security extension.
Expand Down
29 changes: 22 additions & 7 deletions flask_security/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,12 @@ def decorator(fn):
def wrapper(*args, **kwargs):
if _check_http_auth():
return fn(*args, **kwargs)
r = _security.default_http_auth_realm if callable(realm) else realm
h = {'WWW-Authenticate': 'Basic realm="%s"' % r}
return _get_unauthorized_response(headers=h)
if _security._unauthorized_callback:
return _security._unauthorized_callback()
else:
r = _security.default_http_auth_realm if callable(realm) else realm
h = {'WWW-Authenticate': 'Basic realm="%s"' % r}
return _get_unauthorized_response(headers=h)
return wrapper

if callable(realm):
Expand All @@ -112,7 +115,10 @@ def auth_token_required(fn):
def decorated(*args, **kwargs):
if _check_token():
return fn(*args, **kwargs)
return _get_unauthorized_response()
if _security._unauthorized_callback:
return _security._unauthorized_callback()
else:
return _get_unauthorized_response()
return decorated


Expand Down Expand Up @@ -145,7 +151,10 @@ def decorated_view(*args, **kwargs):
elif method == 'basic':
r = _security.default_http_auth_realm
h['WWW-Authenticate'] = 'Basic realm="%s"' % r
return _get_unauthorized_response(headers=h)
if _security._unauthorized_callback:
return _security._unauthorized_callback()
else:
return _get_unauthorized_response()
return decorated_view
return wrapper

Expand All @@ -170,7 +179,10 @@ def decorated_view(*args, **kwargs):
perms = [Permission(RoleNeed(role)) for role in roles]
for perm in perms:
if not perm.can():
return _get_unauthorized_view()
if _security._unauthorized_callback:
return _security._unauthorized_callback()
else:
return _get_unauthorized_view()
return fn(*args, **kwargs)
return decorated_view
return wrapper
Expand All @@ -196,7 +208,10 @@ def decorated_view(*args, **kwargs):
perm = Permission(*[RoleNeed(role) for role in roles])
if perm.can():
return fn(*args, **kwargs)
return _get_unauthorized_view()
if _security._unauthorized_callback:
return _security._unauthorized_callback()
else:
return _get_unauthorized_view()
return decorated_view
return wrapper

Expand Down
16 changes: 16 additions & 0 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,19 @@ def test_password_unicode_password_salt(client):
assert response.status_code == 302
response = authenticate(client, follow_redirects=True)
assert b'Hello [email protected]' in response.data


def test_set_unauthorized_handler(app, client):
@app.security.unauthorized_handler
def unauthorized():
app.unauthorized_handler_set = True
return 'unauthorized-handler-set', 401

app.unauthorized_handler_set = False

authenticate(client, "[email protected]")
response = client.get("/admin", follow_redirects=True)

assert app.unauthorized_handler_set is True
assert b'unauthorized-handler-set' in response.data
assert response.status_code == 401

0 comments on commit 10fd184

Please sign in to comment.