diff --git a/requirements.txt b/requirements.txt index 83d954c597..954fd5923d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -41,7 +41,7 @@ PyMySQL==1.0.2 pymssql==2.2.8 python-jose[cryptography]==3.3.0 pyyaml==6.0.1 -redis==3.5.3 +redis==4.6.0 rich-click==1.6.1 sendgrid==6.9.7 slowapi==0.1.8 diff --git a/src/fides/api/api/v1/endpoints/privacy_experience_endpoints.py b/src/fides/api/api/v1/endpoints/privacy_experience_endpoints.py index 0a4698f0ad..24b124123f 100644 --- a/src/fides/api/api/v1/endpoints/privacy_experience_endpoints.py +++ b/src/fides/api/api/v1/endpoints/privacy_experience_endpoints.py @@ -1,14 +1,12 @@ import asyncio import uuid +from functools import lru_cache from html import escape, unescape from typing import Dict, List, Optional from fastapi import Depends, HTTPException from fastapi import Query as FastAPIQuery from fastapi import Request, Response -from fastapi_pagination import Page, Params -from fastapi_pagination import paginate as fastapi_paginate -from fastapi_pagination.bases import AbstractPage from loguru import logger from sqlalchemy.orm import Query, Session from starlette.status import ( @@ -25,7 +23,6 @@ ) from fides.api.models.privacy_notice import PrivacyNotice from fides.api.models.privacy_request import ProvidedIdentity -from fides.api.schemas.privacy_experience import PrivacyExperienceResponse from fides.api.util.api_router import APIRouter from fides.api.util.consent_util import ( PRIVACY_EXPERIENCE_ESCAPE_FIELDS, @@ -46,7 +43,12 @@ router = APIRouter(tags=["Privacy Experience"], prefix=urls.V1_URL_PREFIX) +BUST_CACHE_HEADER = "bust-endpoint-cache" +CACHE_HEADER = "X-Endpoint-Cache" +PRIVACY_EXPERIENCE_CACHE: Dict[str, Dict] = {} + +@lru_cache(maxsize=20, typed=True) def get_privacy_experience_or_error( db: Session, experience_id: str ) -> PrivacyExperience: @@ -64,6 +66,7 @@ def get_privacy_experience_or_error( return privacy_experience +@lru_cache(maxsize=20, typed=True) def 
_filter_experiences_by_region_or_country( db: Session, region: Optional[str], experience_query: Query ) -> Query: @@ -119,16 +122,14 @@ def _filter_experiences_by_region_or_country( return db.query(PrivacyExperience).filter(False) +# TODO: readd the fides limiter @router.get( urls.PRIVACY_EXPERIENCE, status_code=HTTP_200_OK, - response_model=Page[PrivacyExperienceResponse], ) -@fides_limiter.limit(CONFIG.security.public_request_rate_limit) async def privacy_experience_list( *, db: Session = Depends(deps.get_db), - params: Params = Depends(), show_disabled: Optional[bool] = True, region: Optional[str] = None, component: Optional[ComponentType] = None, @@ -140,7 +141,7 @@ async def privacy_experience_list( include_meta: Optional[bool] = False, request: Request, # required for rate limiting response: Response, # required for rate limiting -) -> AbstractPage[PrivacyExperience]: +) -> Dict: """ Public endpoint that returns a list of PrivacyExperience records for individual regions with relevant privacy notices or tcf contents embedded in the response. @@ -149,7 +150,6 @@ async def privacy_experience_list( notices as well. :param db: - :param params: :param show_disabled: If False, returns only enabled Experiences and Notices :param region: Return the Experiences for the given region :param component: Returns Experiences of the given component type @@ -163,7 +163,33 @@ async def privacy_experience_list( :param response: :return: """ - logger.info("Finding all Privacy Experiences with pagination params '{}'", params) + + # These are the parameters that get used to create the cache. 
+ param_hash_list = [ + show_disabled, + region, + component, + content_required, + has_config, + fides_user_device_id, + systems_applicable, + include_gvl, + include_meta, + ] + # Create a custom hash that avoids unhashable parameters + cache_hash = "_".join([repr(x) for x in param_hash_list]) + + if request.headers.get(BUST_CACHE_HEADER): + PRIVACY_EXPERIENCE_CACHE.clear() + + if PRIVACY_EXPERIENCE_CACHE.get(cache_hash): + logger.debug("Cache HIT: {}", cache_hash) + response.headers[CACHE_HEADER] = "HIT" + return PRIVACY_EXPERIENCE_CACHE[cache_hash] + + logger.debug("Cache MISS: {}", cache_hash) + response.headers[CACHE_HEADER] = "MISS" + fides_user_provided_identity: Optional[ProvidedIdentity] = None if fides_user_device_id: try: @@ -259,7 +285,11 @@ async def privacy_experience_list( results.append(privacy_experience) - return fastapi_paginate(results, params=params) + # This is structured to look like a paginated result to minimize impact from + # the caching changes + api_result = {"items": results, "total": len(results)} + PRIVACY_EXPERIENCE_CACHE[cache_hash] = api_result + return api_result def embed_experience_details( diff --git a/src/fides/api/util/cache.py b/src/fides/api/util/cache.py index a981a2f307..b6af4870aa 100644 --- a/src/fides/api/util/cache.py +++ b/src/fides/api/util/cache.py @@ -7,7 +7,6 @@ from bson.objectid import ObjectId from loguru import logger from redis import Redis -from redis.client import Script # type: ignore from redis.exceptions import ConnectionError as ConnectionErrorFromRedis from fides.api import common_exceptions @@ -63,6 +62,7 @@ def _custom_decoder(json_dict: Dict[str, Any]) -> Dict[str, Any]: return json_dict +# pylint: disable=abstract-method class FidesopsRedis(Redis): """ An extension to Redis' python bindings to support auto expiring data input. 
This class @@ -95,10 +95,9 @@ def get_keys_by_prefix(self, prefix: str, chunk_size: int = 1000) -> List[str]: def delete_keys_by_prefix(self, prefix: str) -> None: """Delete all keys starting with a given prefix""" - s: Script = self.register_script( + self.register_script( f"for _,k in ipairs(redis.call('keys','{prefix}*')) do redis.call('del',k) end" - ) - s() + )() def get_values(self, keys: List[str]) -> Dict[str, Optional[Any]]: """Retrieve all values corresponding to the set of input keys and return them as a diff --git a/tests/ops/api/v1/endpoints/test_privacy_experience_endpoints.py b/tests/ops/api/v1/endpoints/test_privacy_experience_endpoints.py index 8e55fff0e2..f5e30878c3 100644 --- a/tests/ops/api/v1/endpoints/test_privacy_experience_endpoints.py +++ b/tests/ops/api/v1/endpoints/test_privacy_experience_endpoints.py @@ -1,10 +1,14 @@ from __future__ import annotations +from typing import Dict + import pytest from starlette.status import HTTP_200_OK from starlette.testclient import TestClient from fides.api.api.v1.endpoints.privacy_experience_endpoints import ( + BUST_CACHE_HEADER, + CACHE_HEADER, _filter_experiences_by_region_or_country, ) from fides.api.models.privacy_experience import ComponentType, PrivacyExperience @@ -12,11 +16,33 @@ from fides.common.api.v1.urn_registry import PRIVACY_EXPERIENCE, V1_URL_PREFIX -class TestGetPrivacyExperiences: - @pytest.fixture(scope="function") - def url(self) -> str: - return V1_URL_PREFIX + PRIVACY_EXPERIENCE +def get_cache_bust_headers() -> Dict: + return {BUST_CACHE_HEADER: "true"} + + +@pytest.fixture(scope="function") +def url() -> str: + return V1_URL_PREFIX + PRIVACY_EXPERIENCE + +class TestGetPrivacyExperiencesCaching: + def test_cache_header_hit(self, url, api_client): + """Check that the header describing cache hits/misses is working.""" + api_client.get(url) + resp = api_client.get(url) + cache_header = resp.headers.get(CACHE_HEADER) + assert cache_header + assert cache_header == "HIT" + + def 
test_bust_cache_header(self, url, api_client): + """Check that the header to bust the cache is working.""" + resp = api_client.get(url, headers=get_cache_bust_headers()) + cache_header = resp.headers.get(CACHE_HEADER) + assert cache_header + assert cache_header == "MISS" + + +class TestGetPrivacyExperiences: def test_get_privacy_experiences_unauthenticated(self, url, api_client): """This is a public endpoint""" resp = api_client.get(url) @@ -47,7 +73,7 @@ def test_get_privacy_experiences( privacy_notice, privacy_experience_privacy_center, ): - unescape_header = {"Unescape-Safestr": "true"} + unescape_header = {"Unescape-Safestr": "true", BUST_CACHE_HEADER: "true"} resp = api_client.get(url + "?include_gvl=True", headers=unescape_header) assert resp.status_code == 200 @@ -98,7 +124,7 @@ def test_get_experiences_unescaped( privacy_experience_privacy_center, ): # Assert not escaped without proper request header - resp = api_client.get(url) + resp = api_client.get(url, headers=get_cache_bust_headers()) resp = resp.json()["items"][0] experience_config = resp["experience_config"] assert ( @@ -120,6 +146,7 @@ def test_get_privacy_experiences_show_disabled_filter( ): resp = api_client.get( url, + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 data = resp.json() @@ -128,6 +155,7 @@ def test_get_privacy_experiences_show_disabled_filter( resp = api_client.get( url + "?show_disabled=False", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 data = resp.json() @@ -144,6 +172,7 @@ def test_get_privacy_experiences_show_disabled_filter( privacy_experience_overlay.unlink_experience_config(db) resp = api_client.get( url + "?show_disabled=False", + headers=get_cache_bust_headers(), ) data = resp.json() assert ( @@ -162,6 +191,7 @@ def test_get_privacy_experiences_region_filter( ): resp = api_client.get( url + "?region=us_co", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 data = resp.json() @@ -172,6 +202,7 @@ def 
test_get_privacy_experiences_region_filter( resp = api_client.get( url + "?region=us_ca", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 data = resp.json() @@ -184,6 +215,7 @@ def test_get_privacy_experiences_region_filter( resp = api_client.get( url + "?region=bad_region", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 assert resp.json()["total"] == 0 @@ -204,12 +236,14 @@ def assert_france_experience_and_notices_returned(resp): response = api_client.get( url + "?region=fr_idg", + headers=get_cache_bust_headers(), ) # There are no experiences with "fr_idg" so we fell back to searching for "fr" assert_france_experience_and_notices_returned(response) response = api_client.get( url + "?region=FR-IDG", + headers=get_cache_bust_headers(), ) # Case insensitive and hyphens also work here -" assert_france_experience_and_notices_returned(response) @@ -224,6 +258,7 @@ def test_get_privacy_experiences_components_filter( ): resp = api_client.get( url + "?component=overlay", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 data = resp.json() @@ -236,6 +271,7 @@ def test_get_privacy_experiences_components_filter( resp = api_client.get( url + "?component=privacy_center", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 data = resp.json() @@ -247,6 +283,7 @@ def test_get_privacy_experiences_components_filter( resp = api_client.get( url + "?component=tcf_overlay", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 data = resp.json() @@ -258,6 +295,7 @@ def test_get_privacy_experiences_components_filter( resp = api_client.get( url + "?component=bad_type", + headers=get_cache_bust_headers(), ) assert resp.status_code == 422 @@ -270,6 +308,7 @@ def test_get_privacy_experiences_has_notices_no_notices( ): resp = api_client.get( url + "?has_notices=True", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 data = resp.json() @@ -286,6 +325,7 @@ def 
test_get_privacy_experiences_has_notices_no_regions_overlap( ): resp = api_client.get( url + "?has_notices=True", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 data = resp.json() @@ -309,6 +349,7 @@ def test_get_privacy_experiences_has_notices( ): resp = api_client.get( url + "?has_notices=True", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 data = resp.json() @@ -409,6 +450,7 @@ def assert_expected_filtered_region_response(data): # Filter on exact match region resp = api_client.get( url + "?has_notices=True&region=us_co", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 response_json = resp.json() @@ -418,6 +460,7 @@ def assert_expected_filtered_region_response(data): # Filter on upper case and hyphens resp = api_client.get( url + "?has_notices=True&region=US-CO", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 assert_expected_filtered_region_response(resp.json()) @@ -442,6 +485,7 @@ def test_filter_on_systems_applicable( """For systems applicable filter, notices are only embedded if they are relevant to a system""" resp = api_client.get( url + "?region=us_co", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 data = resp.json() @@ -466,6 +510,7 @@ def test_filter_on_systems_applicable( resp = api_client.get( url + "?region=us_co&systems_applicable=True", + headers=get_cache_bust_headers(), ) notices = resp.json()["items"][0]["privacy_notices"] assert len(notices) == 1 @@ -502,6 +547,7 @@ def test_filter_on_notices_and_region_and_show_disabled_is_false( resp = api_client.get( url + "?has_notices=True&region=us_ca&show_disabled=False", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 data = resp.json() @@ -528,6 +574,7 @@ def test_get_privacy_experiences_show_has_config_filter( ): resp = api_client.get( url + "?has_config=False", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 data = resp.json() @@ -536,6 +583,7 @@ def
test_get_privacy_experiences_show_has_config_filter( resp = api_client.get( url + "?has_config=True", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 data = resp.json() @@ -552,6 +600,7 @@ def test_get_privacy_experiences_show_has_config_filter( privacy_experience_privacy_center.save(db=db) resp = api_client.get( url + "?has_config=False", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 data = resp.json() @@ -566,6 +615,7 @@ def test_get_privacy_experiences_bad_fides_user_device_id_filter( ): resp = api_client.get( url + "?fides_user_device_id=does_not_exist", + headers=get_cache_bust_headers(), ) assert resp.status_code == 422 assert resp.json()["detail"] == "Invalid fides user device id format" @@ -583,6 +633,7 @@ def test_get_privacy_experiences_nonexistent_fides_user_device_id_filter( ): resp = api_client.get( url + "?cd685ccd-0960-4dc1-b9ca-7e810ebc5c1b", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 data = resp.json() @@ -615,6 +666,7 @@ def test_get_privacy_experiences_fides_user_device_id_filter( ): resp = api_client.get( url + "?fides_user_device_id=051b219f-20e4-45df-82f7-5eb68a00889f", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 data = resp.json()["items"][0] @@ -639,6 +691,7 @@ def test_get_privacy_experiences_fides_user_device_id_filter( assert privacy_notice_us_ca_provide.description == "new_description" resp = api_client.get( url + "?fides_user_device_id=051b219f-20e4-45df-82f7-5eb68a00889f", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 data = resp.json()["items"][0] @@ -668,6 +721,7 @@ def test_tcf_not_enabled( ): resp = api_client.get( url + "?region=fr&component=overlay&include_gvl=True&include_meta=True", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 assert len(resp.json()["items"]) == 1 @@ -705,6 +759,7 @@ def test_tcf_enabled_but_no_relevant_systems( ): resp = api_client.get( url + 
"?region=fr&component=overlay&include_gvl=True&include_meta=True", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 assert len(resp.json()["items"]) == 1 @@ -731,6 +786,7 @@ def test_tcf_enabled_but_no_relevant_systems( # Has notices = True flag will keep this experience from appearing altogether resp = api_client.get( url + "?region=fr&has_notices=True", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 assert len(resp.json()["items"]) == 0 @@ -756,6 +812,7 @@ def test_tcf_enabled_with_overlapping_vendors( resp = api_client.get( url + "?region=fr&component=overlay&fides_user_device_id=051b219f-20e4-45df-82f7-5eb68a00889f&has_notices=True&include_gvl=True&include_meta=True", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 assert len(resp.json()["items"]) == 1 @@ -851,6 +908,7 @@ def test_tcf_enabled_with_overlapping_systems( resp = api_client.get( url + "?region=fr&component=overlay&fides_user_device_id=051b219f-20e4-45df-82f7-5eb68a00889f&has_notices=True", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 assert len(resp.json()["items"]) == 1 @@ -914,6 +972,7 @@ def test_tcf_enabled_with_legitimate_interest_purpose( resp = api_client.get( url + "?region=fr&component=overlay&fides_user_device_id=051b219f-20e4-45df-82f7-5eb68a00889f&has_notices=True", + headers=get_cache_bust_headers(), ) assert resp.status_code == 200 assert len(resp.json()["items"]) == 1