Skip to content

Commit

Permalink
test: implement file upload interfaces for flask test client
Browse files Browse the repository at this point in the history
  • Loading branch information
jkglasbrenner committed Jan 9, 2025
1 parent 7168e33 commit 8bac475
Showing 1 changed file with 87 additions and 1 deletion.
88 changes: 87 additions & 1 deletion tests/unit/restapi/lib/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@
import structlog
from flask.testing import FlaskClient
from structlog.stdlib import BoundLogger
from werkzeug.datastructures import FileStorage
from werkzeug.test import TestResponse

from dioptra.client.base import (
DioptraClientError,
DioptraFile,
DioptraRequestProtocol,
DioptraResponseProtocol,
DioptraSession,
IllegalArgumentError,
StatusCodeError,
)
from dioptra.restapi.routes import V1_ROOT
Expand Down Expand Up @@ -141,6 +144,46 @@ def is_2xx(status_code: int) -> bool:
return status_code >= HTTPStatus.OK and status_code < HTTPStatus.MULTIPLE_CHOICES


def format_file_for_request(file_: DioptraFile) -> FileStorage:
"""Format the DioptraFile object into a FlaskClient-compatible data structure.
Returns:
The file encoded as a Werkzeug FileStorage object.
"""
if file_.content_type is None:
return FileStorage(stream=file_.stream, filename=file_.filename)

return FileStorage(
stream=file_.stream, filename=file_.filename, content_type=file_.content_type
)


def prepare_data_and_files(
data: dict[str, Any] | None, files: dict[str, DioptraFile] | None
) -> dict[str, Any]:
"""Prepare the data and files for the request.
Args:
data: A dictionary to send in the body of the request as part of a multipart
form.
files: Dictionary of "name": DioptraFile pairs to be uploaded.
Returns:
A dictionary containing the prepared data and files dictionary.
"""
merged: dict[str, Any] = {}

if data is not None:
merged = merged | data

if files is not None:
merged = merged | {
name: format_file_for_request(file_) for name, file_ in files.items()
}

return merged


class DioptraFlaskClientSession(DioptraSession[DioptraResponseProtocol]):
"""
The interface for communicating with the Dioptra API using the FlaskClient.
Expand Down Expand Up @@ -173,6 +216,8 @@ def make_request(
url: str,
params: dict[str, Any] | None = None,
json_: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
files: dict[str, DioptraFile] | None = None,
) -> DioptraResponseProtocol:
"""Make a request to the API.
Expand All @@ -183,6 +228,10 @@ def make_request(
params: The query parameters to include in the request. Optional, defaults
to None.
json_: The JSON data to include in the request. Optional, defaults to None.
data: A dictionary to send in the body of the request as part of a
multipart form. Optional, defaults to None.
files: Dictionary of "name": DioptraFile pairs to be uploaded. Optional,
defaults to None.
Returns:
The response from the API.
Expand Down Expand Up @@ -212,12 +261,41 @@ def make_request(
method = methods_registry[method_name]
method_kwargs: dict[str, Any] = {"follow_redirects": True}

if method_name != "post":
if data:
raise IllegalArgumentError(
"Illegal value for data (reason: data is only supported for POST "
f"requests): {data}."
)

if files:
raise IllegalArgumentError(
"Illegal value for files (reason: files is only supported for POST "
f"requests): {files}."
)

if json_:
if data:
raise IllegalArgumentError(
"Illegal value for json_ (reason: json_ is not supported if data "
f"is not None): {json_}."
)

if files:
raise IllegalArgumentError(
"Illegal value for json_ (reason: json_ is not supported if files "
f"is not None): {json_}."
)

method_kwargs["json"] = json_

if params:
method_kwargs["query_string"] = params

if data or files:
merged_data = prepare_data_and_files(data=data, files=files)
method_kwargs["data"] = merged_data

return method(url, **method_kwargs)

def download(
Expand Down Expand Up @@ -329,6 +407,8 @@ def post(
*parts,
params: dict[str, Any] | None = None,
json_: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
files: dict[str, DioptraFile] | None = None,
) -> DioptraResponseProtocol:
"""Make a POST request to the API.
Expand All @@ -341,11 +421,17 @@ def post(
params: The query parameters to include in the request. Optional, defaults
to None.
json_: The JSON data to include in the request. Optional, defaults to None.
data: A dictionary to send in the body of the request as part of a
multipart form. Optional, defaults to None.
files: Dictionary of "name": DioptraFile pairs to be uploaded. Optional,
defaults to None.
Returns:
A DioptraTestResponse object.
"""
return self._post(endpoint, *parts, params=params, json_=json_)
return self._post(
endpoint, *parts, params=params, json_=json_, data=data, files=files
)

def delete(
self,
Expand Down

0 comments on commit 8bac475

Please sign in to comment.