From 10fd1844d8ace715673cf7c78f1365b1f1bb7307 Mon Sep 17 00:00:00 2001 From: Nuno Santos Date: Wed, 28 May 2014 18:50:31 +0200 Subject: [PATCH] Allow overriding of unauthorized callback. Related to issue #255. --- flask_security/core.py | 6 +++++- flask_security/decorators.py | 29 ++++++++++++++++++++++------- tests/test_misc.py | 16 ++++++++++++++++ 3 files changed, 43 insertions(+), 8 deletions(-) diff --git a/flask_security/core.py b/flask_security/core.py index 188c0cb5..02761de8 100644 --- a/flask_security/core.py +++ b/flask_security/core.py @@ -273,7 +273,8 @@ def _get_state(app, datastore, **kwargs): reset_serializer=_get_serializer(app, 'reset'), confirm_serializer=_get_serializer(app, 'confirm'), _context_processors={}, - _send_mail_task=None + _send_mail_task=None, + _unauthorized_callback=None )) for key, value in _default_forms.items(): @@ -381,6 +382,9 @@ def mail_context_processor(self, fn): def send_mail_task(self, fn): self._send_mail_task = fn + def unauthorized_handler(self, fn): + self._unauthorized_callback = fn + class Security(object): """The :class:`Security` class initializes the Flask-Security extension. diff --git a/flask_security/decorators.py b/flask_security/decorators.py index da95e02d..356b5695 100644 --- a/flask_security/decorators.py +++ b/flask_security/decorators.py @@ -90,9 +90,12 @@ def decorator(fn): def wrapper(*args, **kwargs): if _check_http_auth(): return fn(*args, **kwargs) - r = _security.default_http_auth_realm if callable(realm) else realm - h = {'WWW-Authenticate': 'Basic realm="%s"' % r} - return _get_unauthorized_response(headers=h) + if _security._unauthorized_callback: + return _security._unauthorized_callback() + else: + r = _security.default_http_auth_realm if callable(realm) else realm + h = {'WWW-Authenticate': 'Basic realm="%s"' % r} + return _get_unauthorized_response(headers=h) return wrapper if callable(realm): @@ -112,7 +115,10 @@ def auth_token_required(fn): def decorated(*args, **kwargs): if _check_token(): return fn(*args, **kwargs) - return _get_unauthorized_response() + if _security._unauthorized_callback: + return _security._unauthorized_callback() + else: + return _get_unauthorized_response() return decorated @@ -145,7 +151,10 @@ def decorated_view(*args, **kwargs): elif method == 'basic': r = _security.default_http_auth_realm h['WWW-Authenticate'] = 'Basic realm="%s"' % r - return _get_unauthorized_response(headers=h) + if _security._unauthorized_callback: + return _security._unauthorized_callback() + else: + return _get_unauthorized_response() return decorated_view return wrapper @@ -170,7 +179,10 @@ def decorated_view(*args, **kwargs): perms = [Permission(RoleNeed(role)) for role in roles] for perm in perms: if not perm.can(): - return _get_unauthorized_view() + if _security._unauthorized_callback: + return _security._unauthorized_callback() + else: + return _get_unauthorized_view() return fn(*args, **kwargs) return decorated_view return wrapper @@ -196,7 +208,10 @@ def decorated_view(*args, **kwargs): perm = Permission(*[RoleNeed(role) for role in roles]) if perm.can(): return fn(*args, **kwargs) - return _get_unauthorized_view() + if _security._unauthorized_callback: + return _security._unauthorized_callback() + else: + return _get_unauthorized_view() return decorated_view return wrapper diff --git a/tests/test_misc.py b/tests/test_misc.py index 30b0fd04..d9314131 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -192,3 +192,19 @@ def test_password_unicode_password_salt(client): assert response.status_code == 302 response = authenticate(client, follow_redirects=True) assert b'Hello matt@lp.com' in response.data + + +def test_set_unauthorized_handler(app, client): + @app.security.unauthorized_handler + def unauthorized(): + app.unauthorized_handler_set = True + return 'unauthorized-handler-set', 401 + + app.unauthorized_handler_set = False + + authenticate(client, "joe@lp.com") + response = client.get("/admin", follow_redirects=True) + + assert app.unauthorized_handler_set is True + assert b'unauthorized-handler-set' in response.data + assert response.status_code == 401