diff --git a/src/maggma/api/resource/core.py b/src/maggma/api/resource/core.py index 636c76053..4ffef2122 100644 --- a/src/maggma/api/resource/core.py +++ b/src/maggma/api/resource/core.py @@ -6,6 +6,7 @@ from pydantic import BaseModel from starlette.responses import RedirectResponse +from maggma.api.query_operator import QueryOperator from maggma.api.utils import STORE_PARAMS, api_sanitize from maggma.utils import dynamic_import @@ -106,3 +107,9 @@ def process_header(self, response: Response, request: Request): It can use data in the upstream request to generate the header. (https://fastapi.tiangolo.com/advanced/response-headers/#use-a-response-parameter). """ + + @abstractmethod + def configure_query_on_request(self, request: Request, query_operator: QueryOperator) -> STORE_PARAMS: + """ + This method takes in a FastAPI Request object and returns a query to be used in the store. + """ diff --git a/src/maggma/api/resource/read_resource.py b/src/maggma/api/resource/read_resource.py index 1bf3aa5ee..60d97f4a5 100644 --- a/src/maggma/api/resource/read_resource.py +++ b/src/maggma/api/resource/read_resource.py @@ -33,6 +33,7 @@ def __init__( key_fields: Optional[list[str]] = None, hint_scheme: Optional[HintScheme] = None, header_processor: Optional[HeaderProcessor] = None, + query_to_configure_on_request: Optional[QueryOperator] = None, timeout: Optional[int] = None, enable_get_by_key: bool = False, enable_default_search: bool = True, @@ -49,6 +50,7 @@ def __init__( query_operators: Operators for the query language hint_scheme: The hint scheme to use for this resource header_processor: The header processor to use for this resource + query_to_configure_on_request: Query operator to configure on request timeout: Time in seconds Pymongo should wait when querying MongoDB before raising a timeout error key_fields: List of fields to always project. Default uses SparseFieldsQuery @@ -66,6 +68,7 @@ def __init__( self.tags = tags or [] self.hint_scheme = hint_scheme self.header_processor = header_processor + self.query_to_configure_on_request = query_to_configure_on_request self.key_fields = key_fields self.versioned = False self.enable_get_by_key = enable_get_by_key @@ -196,10 +199,16 @@ def search(**queries: dict[str, STORE_PARAMS]) -> Union[dict, Response]: request: Request = queries.pop("request") # type: ignore temp_response: Response = queries.pop("temp_response") # type: ignore + if self.query_to_configure_on_request is not None: + # give the key name "request", arbitrary choice, as only the value gets merged into the query + queries["groups"] = self.header_processor.configure_query_on_request( + request=request, query_operator=self.query_to_configure_on_request + ) + # allowed query parameters query_params = [ entry for _, i in enumerate(self.query_operators) for entry in signature(i.query).parameters ] - + # check for overlap between allowed query parameters and request query parameters overlap = [key for key in request.query_params if key not in query_params] if any(overlap): if "limit" in overlap or "skip" in overlap: @@ -208,7 +217,6 @@ def search(**queries: dict[str, STORE_PARAMS]) -> Union[dict, Response]: detail="'limit' and 'skip' parameters have been renamed. " "Please update your API client to the newest version.", ) - else: raise HTTPException( status_code=400, diff --git a/tests/api/test_read_resource.py b/tests/api/test_read_resource.py index 79ca5cf64..0adc2dec8 100644 --- a/tests/api/test_read_resource.py +++ b/tests/api/test_read_resource.py @@ -11,7 +11,7 @@ from maggma.api.query_operator import NumericQuery, SparseFieldsQuery, StringQueryOperator from maggma.api.resource import ReadOnlyResource -from maggma.api.resource.core import HintScheme +from maggma.api.resource.core import HeaderProcessor, HintScheme from maggma.stores import AliasingStore, MemoryStore @@ -32,6 +32,18 @@ class Owner(BaseModel): total_owners = len(owners) +# Create a subclass of the header processor to prevent TypeErrors: +# Can't instantiate abstract class HeaderProcessor with abstract methods +class TestHeaderProcessor(HeaderProcessor): + def configure_query_on_request(self, request, query_operator): + # Implement the method + return {"name": "PersonAge9"} + + def process_header(self, response, request): + # Implement the method + pass + + @pytest.fixture() def owner_store(): store = MemoryStore("owners", key="name") @@ -134,6 +146,8 @@ def search_helper(payload, base: str = "/?", debug=True) -> Response: NumericQuery(model=Owner), SparseFieldsQuery(model=Owner), ], + header_processor=TestHeaderProcessor(), + query_to_configure_on_request=StringQueryOperator(model=Owner), disable_validation=True, ) app = FastAPI() @@ -214,3 +228,16 @@ def test_resource_compound(): assert len(data) == 1 assert data[0]["name"] == "PersonAge20Weight200" assert "weight" not in data[0] + + +def test_configure_query_on_request(): + payload = { + "name": "PersonAge20Weight200", + "_all_fields": False, + "_fields": "name,age", + "weight_min": 199.3, + "weight_max": 201.9, + "age": 20, + } + res, data = search_helper(payload=payload, base="/?", debug=True) + assert res.status_code == 200