Skip to content

Commit

Permalink
squashme: minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
olevski committed Jan 25, 2024
1 parent a857078 commit 0468c7c
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 12 deletions.
2 changes: 1 addition & 1 deletion renku/ui/service/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@
)
from renku.ui.service.logger import service_log
from renku.ui.service.serializers.headers import JWT_TOKEN_SECRET
from renku.ui.service.utils.json_encoder import SvcJSONProvider
from renku.ui.service.utils import jwk_client
from renku.ui.service.utils.json_encoder import SvcJSONProvider
from renku.ui.service.views import error_response
from renku.ui.service.views.apispec import apispec_blueprint
from renku.ui.service.views.cache import cache_blueprint
Expand Down
21 changes: 17 additions & 4 deletions renku/ui/service/serializers/headers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@
import base64
import binascii
import os
from copy import deepcopy
from typing import cast

import jwt
from flask import app
from marshmallow import Schema, ValidationError, fields, post_load
from flask import current_app
from marshmallow import Schema, ValidationError, fields, post_load, pre_load
from werkzeug.utils import secure_filename

JWT_TOKEN_SECRET = os.getenv("RENKU_JWT_TOKEN_SECRET", "bW9menZ3cnh6cWpkcHVuZ3F5aWJycmJn")
Expand Down Expand Up @@ -96,9 +97,9 @@ def decode_token(token):
def decode_user(data):
"""Extract renku user from the Keycloak ID token which is a JWT."""
try:
jwk = cast(jwt.PyJWKClient, app.config["KEYCLOAK_JWK_CLIENT"])
jwk = cast(jwt.PyJWKClient, current_app.config["KEYCLOAK_JWK_CLIENT"])
key = jwk.get_signing_key_from_jwt(data)
decoded = jwt.decode(data, key=key, algorithms=["RS256"], audience="renku")
decoded = jwt.decode(data, key=key.key, algorithms=["RS256"], audience="renku")
except jwt.PyJWTError:
# NOTE: older tokens used to be signed with HS256 so use this as a backup if the validation with RS256
# above fails. We used to need HS256 because a step that is now removed was generating an ID token and
Expand All @@ -110,6 +111,18 @@ def decode_user(data):
class IdentityHeaders(Schema):
"""User identity schema."""

@pre_load
def lowercase_required_headers(self, data, **kwargs):
# NOTE: App flask headers are immutable and raise an error when modified so we copy them here
data = deepcopy(data)
if "Authorization" in data:
data["authorization"] = data["Authorization"]
if "Renku-User" in data:
data["renku-user"] = data["Renku-User"]
if "Renku-user" in data:
data["renku-user"] = data["Renku-user"]
return data

@post_load
def set_user(self, data, **kwargs):
"""Extract user object from a JWT."""
Expand Down
23 changes: 16 additions & 7 deletions renku/ui/service/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Renku service utility functions."""
import os
import urllib
from time import sleep
from typing import Any, Dict, Optional, overload

import requests
import urllib
from jwt import PyJWKClient

from renku.core.util.requests import get
from renku.ui.service.config import CACHE_PROJECTS_PATH, CACHE_UPLOADS_PATH, OIDC_URL
from renku.ui.service.errors import ProgramInternalError
from renku.ui.service.logger import service_log
from renku.core.util.requests import get


def make_project_path(user, project):
Expand Down Expand Up @@ -101,28 +102,36 @@ def oidc_discovery() -> Dict[str, Any]:
retries = 0
max_retries = 30
sleep_seconds = 2
renku_domain = os.environ.get("RENKU_DOMAIN")
if not renku_domain:
raise ProgramInternalError(
error_message="Cannot perform OIDC discovery without the renku domain expected "
"to be found in the RENKU_DOMAIN environment variable."
)
full_oidc_url = f"http://{renku_domain}{OIDC_URL}"
while True:
retries += 1
try:
res: requests.Response = get(OIDC_URL)
res: requests.Response = get(full_oidc_url)
except (requests.exceptions.HTTPError, urllib.error.HTTPError) as e:
if not retries < max_retries:
service_log.error("Failed to get OIDC discovery data after all retries - the server cannot start.")
raise e
service_log.info(
f"Failed to get OIDC discovery data from {OIDC_URL}, sleeping for {sleep_seconds} seconds and retrying"
f"Failed to get OIDC discovery data from {full_oidc_url}, "
f"sleeping for {sleep_seconds} seconds and retrying"
)
sleep(sleep_seconds)
else:
service_log.info(f"Successfully fetched OIDC discovery data from {OIDC_URL}")
service_log.info(f"Successfully fetched OIDC discovery data from {full_oidc_url}")
return res.json()


def jwk_client() -> PyJWKClient:
"""Return a JWK client for Keycloak that can be used to provide JWT keys for JWT signature validation"""
"""Return a JWK client for Keycloak that can be used to provide JWT keys for JWT signature validation."""
oidc_data = oidc_discovery()
jwks_uri = oidc_data.get("jwks_uri")
if not jwks_uri:
raise ProgramInternalError(error_message="Could not find JWK URI in the OIDC discovery data")
raise ProgramInternalError(error_message="Could not find jwks_uri in the OIDC discovery data")
jwk = PyJWKClient(jwks_uri)
return jwk

0 comments on commit 0468c7c

Please sign in to comment.