Skip to content

Commit

Permalink
Allow selective disabling of blocklist check (#501)
Browse files Browse the repository at this point in the history
* Allow selective disabling of blocklist check

Fixes #499

* Update the documentation

---------

Co-authored-by: Indrajeet Khandekar <[email protected]>
  • Loading branch information
indrajeet307 and Indrajeet Khandekar authored May 26, 2023
1 parent 0803dc5 commit c77d756
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 3 deletions.
30 changes: 27 additions & 3 deletions flask_jwt_extended/view_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def verify_jwt_in_request(
refresh: bool = False,
locations: Optional[LocationType] = None,
verify_type: bool = True,
skip_revocation_check: bool = False,
) -> Optional[Tuple[dict, dict]]:
"""
Verify that a valid JWT is present in the request, unless ``optional=True`` in
Expand Down Expand Up @@ -76,6 +77,14 @@ def verify_jwt_in_request(
to the ``refresh`` argument. If ``False``, type will not be checked and both
access and refresh tokens will be accepted.
:param skip_revocation_check:
If ``True``, revocation status of the token will be *not* checked. If ``False``,
revocation status of the token will be checked.
:param skip_revocation_check:
If ``True``, revocation status of the token will be *not* checked. If ``False``,
revocation status of the token will be checked.
:return:
A tuple containing the jwt_header and the jwt_data if a valid JWT is
present in the request. If ``optional=True`` and no JWT is in the request,
Expand All @@ -87,7 +96,11 @@ def verify_jwt_in_request(

try:
jwt_data, jwt_header, jwt_location = _decode_jwt_from_request(
locations, fresh, refresh=refresh, verify_type=verify_type
locations,
fresh,
refresh=refresh,
verify_type=verify_type,
skip_revocation_check=skip_revocation_check,
)

except NoAuthorizationError:
Expand Down Expand Up @@ -115,6 +128,7 @@ def jwt_required(
refresh: bool = False,
locations: Optional[LocationType] = None,
verify_type: bool = True,
skip_revocation_check: bool = False,
) -> Any:
"""
A decorator to protect a Flask endpoint with JSON Web Tokens.
Expand Down Expand Up @@ -145,12 +159,18 @@ def jwt_required(
If ``True``, the token type (access or refresh) will be checked according
to the ``refresh`` argument. If ``False``, type will not be checked and both
access and refresh tokens will be accepted.
:param skip_revocation_check:
If ``True``, revocation status of the token will be *not* checked. If ``False``,
revocation status of the token will be checked.
"""

def wrapper(fn):
@wraps(fn)
def decorator(*args, **kwargs):
verify_jwt_in_request(optional, fresh, refresh, locations, verify_type)
verify_jwt_in_request(
optional, fresh, refresh, locations, verify_type, skip_revocation_check
)
return current_app.ensure_sync(fn)(*args, **kwargs)

return decorator
Expand Down Expand Up @@ -284,6 +304,7 @@ def _decode_jwt_from_request(
fresh: bool,
refresh: bool = False,
verify_type: bool = True,
skip_revocation_check: bool = False,
) -> Tuple[dict, dict, str]:
# Figure out what locations to look for the JWT in this request
if isinstance(locations, str):
Expand Down Expand Up @@ -346,7 +367,10 @@ def _decode_jwt_from_request(

if fresh:
_verify_token_is_fresh(jwt_header, decoded_token)
verify_token_not_blocklisted(jwt_header, decoded_token)

if not skip_revocation_check:
verify_token_not_blocklisted(jwt_header, decoded_token)

custom_verification_for_token(jwt_header, decoded_token)

return decoded_token, jwt_header, jwt_location
52 changes: 52 additions & 0 deletions tests/test_blocklist.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ def app():
def access_protected():
return jsonify(foo="bar")

@app.route("/protected_skip_blocklist", methods=["GET"])
@jwt_required(verify_type=False, skip_revocation_check=True)
def access_protected_skip_blocklist():
return jsonify(foo="bar")

@app.route("/protected_noskip_blocklist", methods=["GET"])
@jwt_required(verify_type=False)
def access_protected_no_skip_blocklist():
return jsonify(foo="bar")

@app.route("/refresh_protected", methods=["GET"])
@jwt_required(refresh=True)
def refresh_protected():
Expand All @@ -29,6 +39,48 @@ def refresh_protected():
return app


@pytest.mark.parametrize("blocklist_type", [["access"], ["refresh", "access"]])
def test_blocklisted_access_token_revocation_skip(app, blocklist_type):
jwt = get_jwt_manager(app)

@jwt.token_in_blocklist_loader
def check_blocklisted(jwt_header, jwt_data):
assert jwt_header["alg"] == "HS256"
assert jwt_data["sub"] == "username"
return True

with app.test_request_context():
access_token = create_access_token("username")

test_client = app.test_client()
response = test_client.get(
"/protected_skip_blocklist", headers=make_headers(access_token)
)
assert response.get_json() == {"foo": "bar"}
assert response.status_code == 200


@pytest.mark.parametrize("blocklist_type", [["access"], ["refresh", "access"]])
def test_blocklisted_access_token_revocation_no_skip(app, blocklist_type):
jwt = get_jwt_manager(app)

@jwt.token_in_blocklist_loader
def check_blocklisted(jwt_header, jwt_data):
assert jwt_header["alg"] == "HS256"
assert jwt_data["sub"] == "username"
return True

with app.test_request_context():
access_token = create_access_token("username")

test_client = app.test_client()
response = test_client.get(
"/protected_noskip_blocklist", headers=make_headers(access_token)
)
assert response.get_json() == {"msg": "Token has been revoked"}
assert response.status_code == 401


@pytest.mark.parametrize("blocklist_type", [["access"], ["refresh", "access"]])
def test_non_blocklisted_access_token(app, blocklist_type):
jwt = get_jwt_manager(app)
Expand Down

0 comments on commit c77d756

Please sign in to comment.