Skip to content

Commit

Permalink
Merge pull request #985 from materialsproject/configure_query_on_request
Browse files Browse the repository at this point in the history
capability to configure query on request
  • Loading branch information
rkingsbury authored Aug 16, 2024
2 parents bc75106 + 6dcd244 commit 0bb793c
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 3 deletions.
7 changes: 7 additions & 0 deletions src/maggma/api/resource/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pydantic import BaseModel
from starlette.responses import RedirectResponse

from maggma.api.query_operator import QueryOperator
from maggma.api.utils import STORE_PARAMS, api_sanitize
from maggma.utils import dynamic_import

Expand Down Expand Up @@ -106,3 +107,9 @@ def process_header(self, response: Response, request: Request):
It can use data in the upstream request to generate the header.
(https://fastapi.tiangolo.com/advanced/response-headers/#use-a-response-parameter).
"""

@abstractmethod
def configure_query_on_request(self, request: Request, query_operator: QueryOperator) -> STORE_PARAMS:
"""
This method takes in a FastAPI Request object and returns a query to be used in the store.
"""
12 changes: 10 additions & 2 deletions src/maggma/api/resource/read_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
key_fields: Optional[list[str]] = None,
hint_scheme: Optional[HintScheme] = None,
header_processor: Optional[HeaderProcessor] = None,
query_to_configure_on_request: Optional[QueryOperator] = None,
timeout: Optional[int] = None,
enable_get_by_key: bool = False,
enable_default_search: bool = True,
Expand All @@ -49,6 +50,7 @@ def __init__(
query_operators: Operators for the query language
hint_scheme: The hint scheme to use for this resource
header_processor: The header processor to use for this resource
query_to_configure_on_request: Query operator to configure on request
timeout: Time in seconds Pymongo should wait when querying MongoDB
before raising a timeout error
key_fields: List of fields to always project. Default uses SparseFieldsQuery
Expand All @@ -66,6 +68,7 @@ def __init__(
self.tags = tags or []
self.hint_scheme = hint_scheme
self.header_processor = header_processor
self.query_to_configure_on_request = query_to_configure_on_request
self.key_fields = key_fields
self.versioned = False
self.enable_get_by_key = enable_get_by_key
Expand Down Expand Up @@ -196,10 +199,16 @@ def search(**queries: dict[str, STORE_PARAMS]) -> Union[dict, Response]:
request: Request = queries.pop("request") # type: ignore
temp_response: Response = queries.pop("temp_response") # type: ignore

if self.query_to_configure_on_request is not None:
# give the key name "request", arbitrary choice, as only the value gets merged into the query
queries["groups"] = self.header_processor.configure_query_on_request(
request=request, query_operator=self.query_to_configure_on_request
)
# allowed query parameters
query_params = [
entry for _, i in enumerate(self.query_operators) for entry in signature(i.query).parameters
]

# check for overlap between allowed query parameters and request query parameters
overlap = [key for key in request.query_params if key not in query_params]
if any(overlap):
if "limit" in overlap or "skip" in overlap:
Expand All @@ -208,7 +217,6 @@ def search(**queries: dict[str, STORE_PARAMS]) -> Union[dict, Response]:
detail="'limit' and 'skip' parameters have been renamed. "
"Please update your API client to the newest version.",
)

else:
raise HTTPException(
status_code=400,
Expand Down
29 changes: 28 additions & 1 deletion tests/api/test_read_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from maggma.api.query_operator import NumericQuery, SparseFieldsQuery, StringQueryOperator
from maggma.api.resource import ReadOnlyResource
from maggma.api.resource.core import HintScheme
from maggma.api.resource.core import HeaderProcessor, HintScheme
from maggma.stores import AliasingStore, MemoryStore


Expand All @@ -32,6 +32,18 @@ class Owner(BaseModel):
total_owners = len(owners)


# Create a subclass of the header processor to prevent TypeErrors:
# Can't instantiate abstract class HeaderProcessor with abstract methods
class TestHeaderProcessor(HeaderProcessor):
def configure_query_on_request(self, request, query_operator):
# Implement the method
return {"name": "PersonAge9"}

def process_header(self, response, request):
# Implement the method
pass


@pytest.fixture()
def owner_store():
store = MemoryStore("owners", key="name")
Expand Down Expand Up @@ -134,6 +146,8 @@ def search_helper(payload, base: str = "/?", debug=True) -> Response:
NumericQuery(model=Owner),
SparseFieldsQuery(model=Owner),
],
header_processor=TestHeaderProcessor(),
query_to_configure_on_request=StringQueryOperator(model=Owner),
disable_validation=True,
)
app = FastAPI()
Expand Down Expand Up @@ -214,3 +228,16 @@ def test_resource_compound():
assert len(data) == 1
assert data[0]["name"] == "PersonAge20Weight200"
assert "weight" not in data[0]


def test_configure_query_on_request():
payload = {
"name": "PersonAge20Weight200",
"_all_fields": False,
"_fields": "name,age",
"weight_min": 199.3,
"weight_max": 201.9,
"age": 20,
}
res, data = search_helper(payload=payload, base="/?", debug=True)
assert res.status_code == 200

0 comments on commit 0bb793c

Please sign in to comment.