diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..c6c915d --- /dev/null +++ b/.coveragerc @@ -0,0 +1,4 @@ +[report] +exclude_lines = + # pragma: no cover + if marshmallow_version < \(3, 13, 0\): diff --git a/README.md b/README.md index efbad89..a2c7e5e 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,18 @@ -[![codecov](https://codecov.io/gh/apryor6/flask_accepts/branch/master/graph/badge.svg)](https://codecov.io/gh/apryor6/flask_accepts) +from lib2to3.btm_utils import tokens[![codecov](https://codecov.io/gh/apryor6/flask_accepts/branch/master/graph/badge.svg)](https://codecov.io/gh/apryor6/flask_accepts) [![license](https://img.shields.io/github/license/apryor6/flask_accepts)](https://img.shields.io/github/license/apryor6/flask_accepts) [![code_style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://img.shields.io/badge/code%20style-black-000000.svg) --- -- [flask_accepts](#flask-accepts) +- [flask_accepts](#flask_accepts) - [Installation](#installation) - [Basic usage](#basic-usage) - - [Usage with "vanilla Flask"](#usage-with--vanilla-flask-) + - [Usage with "vanilla Flask"](#usage-with-vanilla-flask) * [Usage with Marshmallow schemas](#usage-with-marshmallow-schemas) - [Marshmallow validators](#marshmallow-validators) - [Default values](#default-values) + * [ Returning Different Response Schemas](#returning-different-response-schemas) + - [Pass-through of arbitrary status codes](#pass-through-of-arbitrary-status-codes) * [Automatic Swagger documentation](#automatic-swagger-documentation) - [Defining the model name](#defining-the-model-name) - [Error handling](#error-handling) @@ -187,6 +189,71 @@ You can provide any of the built-in validators to Marshmallow schemas. See [here Default values provided to Marshmallow schemas will be internally mapped and displayed in the Swagger documentation. See [this example](https://github.com/apryor6/flask_accepts/blob/master/examples/default_values.py) for a usage of `flask_accepts` with nested Marshmallow schemas and default values that will display correctly in Swagger. +## Returning Different Response Schemas + +In real world scenarios things dont always go to plan and you may need to return an error code with your response data +or a different schema entirely (eg Internal Serrver Errors) + +the `responds` decorator accepts a dictionary of response codes and their associated schema. +when returning response data simply provied the status code as you would for a standard Flask response. +The response will then be loaded into the correct schema. + +**Example:** +```python + +class LoginSchema(Schema): + username = fields.String() + password = fields.String() + +class TokenSchema(Schema): + access_token = fields.String() + id_token = fields.String() + refresh_token = fields.String() + +class ErrorSchema(Schema): + error_code = fields.Integer() + errors = fields.List() + +@api.route("/restx/update_user") +class LoginResource(Resource): + alt_schemas = {401: ErrorSchema} + @accepts(schema=LoginSchema, api=api) + @responds(schema=TokenSchema, alt_schemas=alt_schemas, api=api) + def post(self): + payload = request.parsed_obj + username = payload.username + password = payload.password + + tokens = self.attempt_login(username, password) + if tokens is None: + return {"error_code": 8001, "errors": ["invalid username or password"]}, 401 + + return tokens +``` + +### Pass-through of arbitrary status codes +You can also provide a status code that does not have an associated schema +Also works if there are no alternate schemas set. + +The response data will be loaded into the default schema and the provided status code passed through. +An example of this usage might be returning a 422 for bad/invalid data. +```python + +class ResponseSchema(Schema): + response = fields.String() + errors = fields.List() + +@api.route("/restx/update_user") +class UserResource(Resource): + @responds(schema=ResponseSchema, api=api) + def post(self): + # check data + if not data_valid: + return {"response": None, "errors": ["invalid user id"]}, 422 + + return {"response": "user updated", "errors": []} +``` + ## Automatic Swagger documentation The `accepts` decorator will automatically enable Swagger by internally adding the `@api.expects` decorator. If you have provided positional arguments to `accepts`, this involves generating the corresponding `api.parser()` (which is a `reqparse.RequestParser` that includes the Swagger context). If you provide a Marshmallow Schema, an equivalent `api.model` is generated and passed to `@api.expect`. These two can be mixed-and-matched, and the documentation will update accordingly. diff --git a/dev-requirements.txt b/dev-requirements.txt index 9da31ba..eb219b4 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,22 +1,23 @@ -aniso8601==8.0.0 -attrs==19.3.0 -Click==7.0 -Flask==1.0.2 -flask-marshmallow==0.10.1 -flask-restx==0.5.1 -itsdangerous==1.1.0 -Jinja2==2.11.3 -jsonschema==3.2.0 -MarkupSafe==1.1.1 -marshmallow==3.2.2 -more-itertools==8.0.0 -packaging==19.2 -pluggy==0.13.1 -py==1.10.0 -pyparsing==2.4.5 -pyrsistent==0.15.6 -pytest==5.3.1 -pytz==2019.3 -six==1.13.0 -wcwidth==0.1.7 -Werkzeug==0.16.0 +aniso8601==9.0.1 +attrs==21.4.0 +click==8.1.3 +flask>=2,<3 +flask-restx>=1,<2 +flask-marshmallow==0.14.0 +itsdangerous==2.1.2 +jinja2==3.1.2 +jsonschema==4.7.2 +MarkupSafe==2.1.1 +marshmallow==3.17.0 +more-itertools==8.13.0 +packaging==21.3 +pluggy==1.0.0 +py==1.11.0 +pyparsing==3.0.9 +pyrsistent==0.18.1 +pytest==7.1.2 +pytz==2022.1 +six==1.16.0 +wcwidth==0.2.5 +werkzeug>=2,<3; python_version < '3.8' +werkzeug>=3,<4; python_version >= '3.8' diff --git a/examples/default_values.py b/examples/default_values.py index d209db5..8a4980d 100644 --- a/examples/default_values.py +++ b/examples/default_values.py @@ -1,4 +1,3 @@ -import datetime from dataclasses import dataclass from marshmallow import fields, Schema, post_load from flask import Flask, jsonify, request @@ -6,17 +5,17 @@ class CogSchema(Schema): - cog_foo = fields.String(default="cog") - cog_baz = fields.Integer(default=999) + cog_foo = fields.String(dump_default="cog") + cog_baz = fields.Integer(dump_default=999) class WidgetSchema(Schema): - foo = fields.String(default="test string") - baz = fields.Integer(default=42) - flag = fields.Bool(default=False) - date = fields.Date(default="01-01-1900") - dec = fields.Decimal(default=42.42) - dct = fields.Dict(default={"key": "value"}) + foo = fields.String(dump_default="test string") + baz = fields.Integer(dump_default=42) + flag = fields.Bool(dump_default=False) + date = fields.Date(dump_default="01-01-1900") + dec = fields.Decimal(dump_default=42.42) + dct = fields.Dict(dump_default={"key": "value"}) cog = fields.Nested(CogSchema) diff --git a/examples/marshmallow_example.py b/examples/marshmallow_example.py index 8ef0821..ceec07d 100644 --- a/examples/marshmallow_example.py +++ b/examples/marshmallow_example.py @@ -11,8 +11,8 @@ class Widget: class WidgetSchema(Schema): - foo = fields.String(default="test value") - baz = fields.Integer(default=422) + foo = fields.String(dump_default="test value") + baz = fields.Integer(dump_default=422) @post_load def make(self, data, **kwargs): diff --git a/examples/nested_schemas.py b/examples/nested_schemas.py index 4363a9e..80bba0e 100644 --- a/examples/nested_schemas.py +++ b/examples/nested_schemas.py @@ -10,17 +10,17 @@ class CogSchema(Schema): - cog_foo = fields.String(default="cog") - cog_baz = fields.Integer(default=999) + cog_foo = fields.String(dump_default="cog") + cog_baz = fields.Integer(dump_default=999) class WidgetSchema(Schema): - foo = fields.String(default="test string") - baz = fields.Integer(default=42) - flag = fields.Bool(default=False) - date = fields.Date(default="01-01-1900") - dec = fields.Decimal(default=42.42) - dct = fields.Dict(default={"key": "value"}) + foo = fields.String(dump_default="test string") + baz = fields.Integer(dump_default=42) + flag = fields.Bool(dump_default=False) + date = fields.Date(dump_default="01-01-1900") + dec = fields.Decimal(dump_default=42.42) + dct = fields.Dict(dump_default={"key": "value"}) cog = fields.Nested(CogSchema) diff --git a/flask_accepts/decorators/decorators.py b/flask_accepts/decorators/decorators.py index 84d4f18..6d5d799 100644 --- a/flask_accepts/decorators/decorators.py +++ b/flask_accepts/decorators/decorators.py @@ -1,5 +1,5 @@ from collections import OrderedDict -from typing import Type, Union +from typing import Type, Union, Dict from flask import jsonify from werkzeug.wrappers import Response from werkzeug.exceptions import BadRequest, InternalServerError @@ -126,7 +126,7 @@ def inner(*args, **kwargs): # Handle Marshmallow schema for request body if schema: try: - obj = schema.load(request.get_json(force=True)) + obj = schema.load(request.get_json(force=True) or {}) request.parsed_obj = obj except ValidationError as ex: schema_error = ex.messages @@ -135,9 +135,9 @@ def inner(*args, **kwargs): f"Error parsing request body: {schema_error}" ) if hasattr(error, "data"): - error.data["errors"].update({"schema_errors": schema_error}) + error.data["errors"].update(schema_error) else: - error.data = {"schema_errors": schema_error} + error.data = {"errors": schema_error} # Handle Marshmallow schema for query params if query_params_schema: @@ -155,9 +155,9 @@ def inner(*args, **kwargs): f"Error parsing query params: {schema_error}" ) if hasattr(error, "data"): - error.data["errors"].update({"schema_errors": schema_error}) + error.data["errors"].update(schema_error) else: - error.data = {"schema_errors": schema_error} + error.data = {"errors": schema_error} # Handle Marshmallow schema for headers if headers_schema: @@ -175,9 +175,9 @@ def inner(*args, **kwargs): f"Error parsing headers: {schema_error}" ) if hasattr(error, "data"): - error.data["errors"].update({"schema_errors": schema_error}) + error.data["errors"].update(schema_error) else: - error.data = {"schema_errors": schema_error} + error.data = {"errors": schema_error} # Handle Marshmallow schema for form data if form_schema: @@ -195,9 +195,9 @@ def inner(*args, **kwargs): f"Error parsing form data: {schema_error}" ) if hasattr(error, "data"): - error.data["errors"].update({"schema_errors": schema_error}) + error.data["errors"].update(schema_error) else: - error.data = {"schema_errors": schema_error} + error.data = {"errors": schema_error} # If any parsing produced an error, combine them and re-raise if error: @@ -231,7 +231,8 @@ def inner(*args, **kwargs): def responds( *args, model_name: str = None, - schema=None, + schema: Union[Schema, Type[Schema]] = None, + alt_schemas: Dict[int, Union[Schema, Type[Schema]]] = None, many: bool = False, api=None, envelope=None, @@ -250,6 +251,7 @@ def responds( Args: schema (bool, optional): Marshmallow schema with which to serialize the output of the wrapped function. + alt_schemas (dict, optional): Dict of alternate schemas to use based on the status_code many (bool, optional): (DEPRECATED) The Marshmallow schema `many` parameter, which will return a list of the corresponding schema objects when set to True. @@ -292,17 +294,29 @@ def decorator(func): @wraps(func) def inner(*args, **kwargs): + nonlocal schema + nonlocal status_code + rv = func(*args, **kwargs) # If a Flask response has been made already, it is passed through unchanged if isinstance(rv, Response): return rv - if schema: - serialized = schema.dump(rv) + + resp_schema = schema + # allow overriding the status code passed to Flask + if isinstance(rv, tuple): + rv, status_code = rv + if alt_schemas and status_code in alt_schemas: + # override the default response schema + resp_schema = _get_or_create_schema(alt_schemas[status_code]) + + if resp_schema: + serialized = resp_schema.dump(rv) # Validate data if asked to (throws) if validate: - errs = schema.validate(serialized) + errs = resp_schema.validate(serialized) if errs: raise InternalServerError( description="Server attempted to return invalid data" @@ -338,6 +352,7 @@ def remove_none(obj): return jsonify(serialized), status_code return serialized, status_code + nonlocal schema # Add Swagger if api and use_swagger and _IS_METHOD: if schema: diff --git a/flask_accepts/test/__init__.py b/flask_accepts/tests/__init__.py similarity index 100% rename from flask_accepts/test/__init__.py rename to flask_accepts/tests/__init__.py diff --git a/flask_accepts/test/fixtures.py b/flask_accepts/tests/fixtures.py similarity index 100% rename from flask_accepts/test/fixtures.py rename to flask_accepts/tests/fixtures.py diff --git a/flask_accepts/decorators/decorators_test.py b/flask_accepts/tests/test_accepts.py similarity index 59% rename from flask_accepts/decorators/decorators_test.py rename to flask_accepts/tests/test_accepts.py index bbdfb0a..c93490d 100644 --- a/flask_accepts/decorators/decorators_test.py +++ b/flask_accepts/tests/test_accepts.py @@ -1,11 +1,10 @@ -from flask import request +from flask import jsonify, request from flask_restx import Resource, Api from marshmallow import Schema, fields -from werkzeug.datastructures import MultiDict +from werkzeug.exceptions import InternalServerError from flask_accepts.decorators import accepts, responds -from flask_accepts.decorators.decorators import _convert_multidict_values_to_schema -from flask_accepts.test.fixtures import app, client # noqa +from flask_accepts.tests.fixtures import app, client # noqa def test_arguments_are_added_to_request(app, client): # noqa @@ -53,11 +52,11 @@ class TestResource(Resource): def post(self): assert request.parsed_obj assert request.parsed_obj["_id"] == 42 - assert request.parsed_obj["name"] == "test name" + assert request.parsed_obj["name"] == "tests name" return "success" with client as cl: - resp = cl.post("/test?foo=3", json={"_id": 42, "name": "test name"}) + resp = cl.post("/test?foo=3", json={"_id": 42, "name": "tests name"}) assert resp.status_code == 200 @@ -81,11 +80,11 @@ class TestResource(Resource): def post(self): assert request.parsed_obj assert request.parsed_obj["_id"] == 42 - assert request.parsed_obj["name"] == "test name" + assert request.parsed_obj["name"] == "tests name" return "success" with client as cl: - resp = cl.post("/test?foo=3", json={"_id": 42, "name": "test name"}) + resp = cl.post("/test?foo=3", json={"_id": 42, "name": "tests name"}) assert resp.status_code == 200 @@ -112,10 +111,10 @@ def post(self): with client as cl: resp = cl.post( "/test?foo=3", - json={"_id": "this is not an integer and will error", "name": "test name"}, + json={"_id": "this is not an integer and will error", "name": "tests name"}, ) assert resp.status_code == 400 - assert "Not a valid integer." in resp.json["schema_errors"]["_id"] + assert "Not a valid integer." in resp.json["errors"]["_id"] def test_validation_errors_from_all_added_to_request_with_Resource_and_schema( @@ -142,11 +141,11 @@ def post(self): with client as cl: resp = cl.post( "/test?foo=not_int", - json={"_id": "this is not an integer and will error", "name": "test name"}, + json={"_id": "this is not an integer and will error", "name": "tests name"}, ) assert resp.status_code == 400 - assert "Not a valid integer." in resp.json["errors"]["schema_errors"]["_id"] + assert "Not a valid integer." in resp.json["errors"]["_id"] def test_dict_arguments_are_correctly_added(app, client): # noqa @@ -404,7 +403,6 @@ class TestSchema(Schema): class TestResource(Resource): @accepts("TestSchema", form_schema=TestSchema, api=api) def post(self): - assert request.parsed_form["foo"] == 3 assert request.parsed_args["foo"] == 3 return "success" @@ -513,7 +511,7 @@ class QueryParamsSchema(Schema): class HeadersSchema(Schema): Header = fields.Integer(required=True) - + class FormSchema(Schema): form = fields.String(required=True) @@ -560,272 +558,6 @@ def post(self): obj = resp.json assert obj == [{"_id": 42, "name": "Jon Snow"}] - -def test_responds(app, client): # noqa - class TestSchema(Schema): - _id = fields.Integer() - name = fields.String() - - api = Api(app) - - @api.route("/test") - class TestResource(Resource): - @responds(schema=TestSchema, api=api) - def get(self): - obj = {"_id": 42, "name": "Jon Snow"} - return obj - - with client as cl: - resp = cl.get("/test") - obj = resp.json - assert obj["_id"] == 42 - assert obj["name"] == "Jon Snow" - - -def test_respond_schema_instance(app, client): # noqa - class TestSchema(Schema): - _id = fields.Integer() - name = fields.String() - - api = Api(app) - - @api.route("/test") - class TestResource(Resource): - @responds(schema=TestSchema(), api=api) - def get(self): - obj = {"_id": 42, "name": "Jon Snow"} - return obj - - with client as cl: - resp = cl.get("/test") - obj = resp.json - assert obj["_id"] == 42 - assert obj["name"] == "Jon Snow" - - -def test_respond_schema_instance_respects_exclude(app, client): # noqa - class TestSchema(Schema): - _id = fields.Integer() - name = fields.String() - - api = Api(app) - - @api.route("/test") - class TestResource(Resource): - @responds(schema=TestSchema(exclude=("_id",)), api=api) - def get(self): - obj = {"_id": 42, "name": "Jon Snow"} - return obj - - with client as cl: - resp = cl.get("/test") - obj = resp.json - assert "_id" not in obj - assert obj["name"] == "Jon Snow" - - -def test_respond_schema_respects_many(app, client): # noqa - class TestSchema(Schema): - _id = fields.Integer() - name = fields.String() - - api = Api(app) - - @api.route("/test") - class TestResource(Resource): - @responds(schema=TestSchema, many=True, api=api) - def get(self): - obj = [{"_id": 42, "name": "Jon Snow"}] - return obj - - with client as cl: - resp = cl.get("/test") - obj = resp.json - assert obj == [{"_id": 42, "name": "Jon Snow"}] - - -def test_respond_schema_instance_respects_many(app, client): # noqa - class TestSchema(Schema): - _id = fields.Integer() - name = fields.String() - - api = Api(app) - - @api.route("/test") - class TestResource(Resource): - @responds(schema=TestSchema(many=True), api=api) - def get(self): - obj = [{"_id": 42, "name": "Jon Snow"}] - return obj - - with client as cl: - resp = cl.get("/test") - obj = resp.json - assert obj == [{"_id": 42, "name": "Jon Snow"}] - - -def test_responds_regular_route(app, client): # noqa - class TestSchema(Schema): - _id = fields.Integer() - name = fields.String() - - @app.route("/test", methods=["GET"]) - @responds(schema=TestSchema) - def get(): - obj = {"_id": 42, "name": "Jon Snow"} - return obj - - with client as cl: - resp = cl.get("/test") - obj = resp.json - assert obj["_id"] == 42 - assert obj["name"] == "Jon Snow" - - -def test_responds_passes_raw_responses_through_untouched(app, client): # noqa - class TestSchema(Schema): - _id = fields.Integer() - name = fields.String() - - api = Api(app) - - @api.route("/test") - class TestResource(Resource): - @responds(schema=TestSchema, api=api) - def get(self): - from flask import make_response, Response - - obj = {"_id": 42, "name": "Jon Snow"} - return Response("A prebuild response that won't be serialised", 201) - - with client as cl: - resp = cl.get("/test") - assert resp.status_code == 201 - - -def test_responds_with_parser(app, client): # noqa - - api = Api(app) - - @api.route("/test") - class TestResource(Resource): - @responds( - "King", - dict(name="_id", type=int), - dict(name="name", type=str), - dict(name="value", type=float), - dict(name="status", choices=("alive", "dead")), - dict(name="todos", action="append"), - api=api, - ) - def get(self): - from flask import make_response, Response - - return { - "_id": 42, - "name": "Jon Snow", - "value": 100.0, - "status": "alive", - "todos": ["one", "two"], - } - - with client as cl: - resp = cl.get("/test") - assert resp.status_code == 200 - assert resp.json == { - "_id": 42, - "name": "Jon Snow", - "value": 100.0, - "status": "alive", - "todos": ["one", "two"], - } - - -def test_responds_respects_status_code(app, client): # noqa - class TestSchema(Schema): - _id = fields.Integer() - name = fields.String() - - api = Api(app) - - @api.route("/test") - class TestResource(Resource): - @responds(schema=TestSchema, api=api, status_code=999) - def get(self): - from flask import make_response, Response - - obj = {"_id": 42, "name": "Jon Snow"} - return obj - - with client as cl: - resp = cl.get("/test") - assert resp.status_code == 999 - - -def test_responds_respects_envelope(app, client): # noqa - class TestSchema(Schema): - _id = fields.Integer() - name = fields.String() - - api = Api(app) - - @api.route("/test") - class TestResource(Resource): - @responds(schema=TestSchema, api=api, envelope='test-data') - def get(self): - from flask import make_response, Response - - obj = {"_id": 42, "name": "Jon Snow"} - return obj - - with client as cl: - resp = cl.get("/test") - assert resp.status_code == 200 - assert resp.json == {'test-data': {'_id': 42, 'name': 'Jon Snow'}} - - -def test_responds_skips_none_false(app, client): - class TestSchema(Schema): - _id = fields.Integer() - name = fields.String() - - api = Api(app) - - @api.route("/test") - class TestResource(Resource): - @responds(schema=TestSchema, api=api) - def get(self): - return {"_id": 42, "name": None} - - with client as cl: - resp = cl.get("/test") - assert resp.status_code == 200 - assert resp.json == {'_id': 42, 'name': None} - - -def test_responds_with_nested_skips_none_true(app, client): - class NestSchema(Schema): - _id = fields.Integer() - name = fields.String() - - class TestSchema(Schema): - name = fields.String() - child = fields.Nested(NestSchema) - - api = Api(app) - - @api.route("/test") - class TestResource(Resource): - @responds(schema=TestSchema, api=api, skip_none=True, many=True) - def get(self): - return [{"name": None, "child": {"_id": 42, "name": None}}] - - with client as cl: - resp = cl.get("/test") - assert resp.status_code == 200 - assert resp.json == [{"child": {'_id': 42}}] - - def test_accepts_with_nested_schema(app, client): # noqa class TestSchema(Schema): _id = fields.Integer() @@ -847,14 +579,14 @@ class TestResource(Resource): ) def post(self): assert request.parsed_obj - assert request.parsed_obj["child"] == {"_id": 42, "name": "test name"} - assert request.parsed_obj["name"] == "test host" + assert request.parsed_obj["child"] == {"_id": 42, "name": "tests name"} + assert request.parsed_obj["name"] == "tests host" return "success" with client as cl: resp = cl.post( "/test?foo=3", - json={"name": "test host", "child": {"_id": 42, "name": "test name"}}, + json={"name": "tests host", "child": {"_id": 42, "name": "tests name"}}, ) assert resp.status_code == 200 @@ -886,23 +618,23 @@ def post(self): assert request.parsed_obj assert request.parsed_obj["child"]["child"] == { "_id": 42, - "name": "test name", + "name": "tests name", } assert request.parsed_obj["child"] == { - "name": "test host", - "child": {"_id": 42, "name": "test name"}, + "name": "tests host", + "child": {"_id": 42, "name": "tests name"}, } - assert request.parsed_obj["name"] == "test host host" + assert request.parsed_obj["name"] == "tests host host" return "success" with client as cl: resp = cl.post( "/test?foo=3", json={ - "name": "test host host", + "name": "tests host host", "child": { - "name": "test host", - "child": {"_id": 42, "name": "test name"}, + "name": "tests host", + "child": {"_id": 42, "name": "tests name"}, }, }, ) @@ -910,10 +642,6 @@ def post(self): def test_responds_with_validate(app, client): # noqa - import pytest - from flask import jsonify - from werkzeug.exceptions import InternalServerError - class TestSchema(Schema): _id = fields.Integer(required=True) name = fields.String(required=True) @@ -936,10 +664,6 @@ def get(): def test_responds_with_validate(app, client): # noqa - import pytest - from flask import jsonify - from werkzeug.exceptions import InternalServerError - class TestDataObj: def __init__(self, wrong_field, name): self.wrong_field = wrong_field @@ -965,148 +689,3 @@ def get(): obj = resp.json assert resp.status_code == 500 assert resp.json == {"message": "Server attempted to return invalid data"} - - -def test_multidict_single_values_interpreted_correctly(app, client): # noqa - class TestSchema(Schema): - name = fields.String(required=True) - - multidict = MultiDict([("name", "value"), ("new_value", "still_here")]) - result = _convert_multidict_values_to_schema(multidict, TestSchema()) - - # `name` should be left a single value - assert result["name"] == "value" - - # `new_value` should *not* be removed here, even though it"s not in the - # schema. - assert result["new_value"] == "still_here" - - # Also makes sure that if multiple values are found in the multidict, then - # only the first one is returned. - multidict = MultiDict([ - ("name", "value"), - ("name", "value2"), - ]) - result = _convert_multidict_values_to_schema(multidict, TestSchema()) - assert result["name"] == "value" - - -def test_multidict_list_values_interpreted_correctly(app, client): # noqa - class TestSchema(Schema): - name = fields.List(fields.String(), required=True) - - multidict = MultiDict([ - ("name", "value"), - ("new_value", "still_here") - ]) - result = _convert_multidict_values_to_schema(multidict, TestSchema()) - - # `name` should be converted to a list. - assert result["name"] == ["value"] - - # `new_value` should *not* be removed here, even though it"s not in the schema. - assert result["new_value"] == "still_here" - - # Also makes sure handling a list with >1 values also works. - multidict = MultiDict([ - ("name", "value"), - ("name", "value2"), - ]) - result = _convert_multidict_values_to_schema(multidict, TestSchema()) - assert result["name"] == ["value", "value2"] - - -def test_no_schema_generates_correct_swagger(app, client): # noqa - class TestSchema(Schema): - _id = fields.Integer() - name = fields.String() - - api = Api(app) - route = "/test" - - @api.route(route) - class TestResource(Resource): - @responds(api=api, status_code=201, description="My description") - def post(self): - obj = [{"_id": 42, "name": "Jon Snow"}] - return obj - - with client as cl: - cl.post(route, data='[{"_id": 42, "name": "Jon Snow"}]', content_type='application/json') - route_docs = api.__schema__["paths"][route]["post"] - - responses_docs = route_docs['responses']['201'] - - assert responses_docs['description'] == "My description" - - -def test_schema_generates_correct_swagger(app, client): # noqa - class TestSchema(Schema): - _id = fields.Integer() - name = fields.String() - - api = Api(app) - route = "/test" - - @api.route(route) - class TestResource(Resource): - @accepts(model_name="MyRequest", schema=TestSchema(many=False), api=api) - @responds(model_name="MyResponse", schema=TestSchema(many=False), api=api, description="My description") - def post(self): - obj = {"_id": 42, "name": "Jon Snow"} - return obj - - with client as cl: - cl.post(route, data='{"_id": 42, "name": "Jon Snow"}', content_type='application/json') - route_docs = api.__schema__["paths"][route]["post"] - responses_docs = route_docs['responses']['200'] - - assert responses_docs['description'] == "My description" - assert responses_docs['schema'] == {'$ref': '#/definitions/MyResponse'} - assert route_docs['parameters'][0]['schema'] == {'$ref': '#/definitions/MyRequest'} - - -def test_schema_generates_correct_swagger_for_many(app, client): # noqa - class TestSchema(Schema): - _id = fields.Integer() - name = fields.String() - - api = Api(app) - route = "/test" - - @api.route(route) - class TestResource(Resource): - @accepts(schema=TestSchema(many=True), api=api) - @responds(schema=TestSchema(many=True), api=api, description="My description") - def post(self): - obj = [{"_id": 42, "name": "Jon Snow"}] - return obj - - with client as cl: - resp = cl.post(route, data='[{"_id": 42, "name": "Jon Snow"}]', content_type='application/json') - route_docs = api.__schema__["paths"][route]["post"] - assert route_docs['responses']['200']['schema'] == {"type": "array", "items": {"$ref": "#/definitions/Test"}} - assert route_docs['parameters'][0]['schema'] == {"type": "array", "items": {"$ref": "#/definitions/Test"}} - - -def test_swagger_respects_existing_response_docs(app, client): # noqa - class TestSchema(Schema): - _id = fields.Integer() - name = fields.String() - - api = Api(app) - route = "/test" - - @api.route(route) - class TestResource(Resource): - @responds(schema=TestSchema(many=True), api=api, description="My description") - @api.doc(responses={401: "Not Authorized", 404: "Not Found"}) - def get(self): - return [{"_id": 42, "name": "Jon Snow"}] - - with client as cl: - cl.get(route) - route_docs = api.__schema__["paths"][route]["get"] - assert route_docs["responses"]["200"]["description"] == "My description" - assert route_docs["responses"]["401"]["description"] == "Not Authorized" - assert route_docs["responses"]["404"]["description"] == "Not Found" diff --git a/flask_accepts/tests/test_decorators.py b/flask_accepts/tests/test_decorators.py new file mode 100644 index 0000000..241b769 --- /dev/null +++ b/flask_accepts/tests/test_decorators.py @@ -0,0 +1,102 @@ +from flask_restx import Resource, Api +from marshmallow import Schema, fields +from werkzeug.datastructures import MultiDict + +from flask_accepts.decorators import accepts, responds +from flask_accepts.decorators.decorators import _convert_multidict_values_to_schema +from flask_accepts.tests.fixtures import app, client # noqa + + +def test_schema_generates_correct_swagger(app, client): # noqa + class TestSchema(Schema): + _id = fields.Integer() + name = fields.String() + + api = Api(app) + route = "/test" + + @api.route(route) + class TestResource(Resource): + @accepts(model_name="MyRequest", schema=TestSchema(many=False), api=api) + @responds(model_name="MyResponse", schema=TestSchema(many=False), api=api, description="My description") + def post(self): + obj = {"_id": 42, "name": "Jon Snow"} + return obj + + with client as cl: + cl.post(route, data='{"_id": 42, "name": "Jon Snow"}', content_type='application/json') + route_docs = api.__schema__["paths"][route]["post"] + responses_docs = route_docs['responses']['200'] + + assert responses_docs['description'] == "My description" + assert responses_docs['schema'] == {'$ref': '#/definitions/MyResponse'} + assert route_docs['parameters'][0]['schema'] == {'$ref': '#/definitions/MyRequest'} + +def test_schema_generates_correct_swagger_for_many(app, client): # noqa + class TestSchema(Schema): + _id = fields.Integer() + name = fields.String() + + api = Api(app) + route = "/test" + + @api.route(route) + class TestResource(Resource): + @accepts(schema=TestSchema(many=True), api=api) + @responds(schema=TestSchema(many=True), api=api, description="My description") + def post(self): + obj = [{"_id": 42, "name": "Jon Snow"}] + return obj + + with client as cl: + resp = cl.post(route, data='[{"_id": 42, "name": "Jon Snow"}]', content_type='application/json') + route_docs = api.__schema__["paths"][route]["post"] + assert route_docs['responses']['200']['schema'] == {"type": "array", "items": {"$ref": "#/definitions/Test"}} + assert route_docs['parameters'][0]['schema'] == {"type": "array", "items": {"$ref": "#/definitions/Test"}} + +def test_multidict_single_values_interpreted_correctly(app, client): # noqa + class TestSchema(Schema): + name = fields.String(required=True) + + multidict = MultiDict([("name", "value"), ("new_value", "still_here")]) + result = _convert_multidict_values_to_schema(multidict, TestSchema()) + + # `name` should be left a single value + assert result["name"] == "value" + + # `new_value` should *not* be removed here, even though it"s not in the + # schema. + assert result["new_value"] == "still_here" + + # Also makes sure that if multiple values are found in the multidict, then + # only the first one is returned. + multidict = MultiDict([ + ("name", "value"), + ("name", "value2"), + ]) + result = _convert_multidict_values_to_schema(multidict, TestSchema()) + assert result["name"] == "value" + +def test_multidict_list_values_interpreted_correctly(app, client): # noqa + class TestSchema(Schema): + name = fields.List(fields.String(), required=True) + + multidict = MultiDict([ + ("name", "value"), + ("new_value", "still_here") + ]) + result = _convert_multidict_values_to_schema(multidict, TestSchema()) + + # `name` should be converted to a list. + assert result["name"] == ["value"] + + # `new_value` should *not* be removed here, even though it"s not in the schema. + assert result["new_value"] == "still_here" + + # Also makes sure handling a list with >1 values also works. + multidict = MultiDict([ + ("name", "value"), + ("name", "value2"), + ]) + result = _convert_multidict_values_to_schema(multidict, TestSchema()) + assert result["name"] == ["value", "value2"] diff --git a/flask_accepts/tests/test_responds.py b/flask_accepts/tests/test_responds.py new file mode 100644 index 0000000..6bc7ce5 --- /dev/null +++ b/flask_accepts/tests/test_responds.py @@ -0,0 +1,516 @@ +import json + +from attr import dataclass +from flask import request, Response, jsonify +from flask_restx import Resource, Api +from marshmallow import Schema, fields +from werkzeug.exceptions import InternalServerError + +from flask_accepts.decorators import accepts, responds +from flask_accepts.tests.fixtures import app, client # noqa + + +def test_responds(app, client): # noqa + class TestSchema(Schema): + _id = fields.Integer() + name = fields.String() + + api = Api(app) + + @api.route("/test") + class TestResource(Resource): + @responds(schema=TestSchema, api=api) + def get(self): + obj = {"_id": 42, "name": "Jon Snow"} + return obj + + with client as cl: + resp = cl.get("/test") + obj = resp.json + assert obj["_id"] == 42 + assert obj["name"] == "Jon Snow" + + +def test_respond_schema_instance(app, client): # noqa + class TestSchema(Schema): + _id = fields.Integer() + name = fields.String() + + api = Api(app) + + @api.route("/test") + class TestResource(Resource): + @responds(schema=TestSchema(), api=api) + def get(self): + obj = {"_id": 42, "name": "Jon Snow"} + return obj + + with client as cl: + resp = cl.get("/test") + obj = resp.json + assert obj["_id"] == 42 + assert obj["name"] == "Jon Snow" + + +def test_respond_schema_instance_respects_exclude(app, client): # noqa + class TestSchema(Schema): + _id = fields.Integer() + name = fields.String() + + api = Api(app) + + @api.route("/test") + class TestResource(Resource): + @responds(schema=TestSchema(exclude=("_id",)), api=api) + def get(self): + obj = {"_id": 42, "name": "Jon Snow"} + return obj + + with client as cl: + resp = cl.get("/test") + obj = resp.json + assert "_id" not in obj + assert obj["name"] == "Jon Snow" + + +def test_respond_schema_respects_many(app, client): # noqa + class TestSchema(Schema): + _id = fields.Integer() + name = fields.String() + + api = Api(app) + + @api.route("/test") + class TestResource(Resource): + @responds(schema=TestSchema, many=True, api=api) + def get(self): + obj = [{"_id": 42, "name": "Jon Snow"}] + return obj + + with client as cl: + resp = cl.get("/test") + obj = resp.json + assert obj == [{"_id": 42, "name": "Jon Snow"}] + + +def test_respond_schema_instance_respects_many(app, client): # noqa + class TestSchema(Schema): + _id = fields.Integer() + name = fields.String() + + api = Api(app) + + @api.route("/test") + class TestResource(Resource): + @responds(schema=TestSchema(many=True), api=api) + def get(self): + obj = [{"_id": 42, "name": "Jon Snow"}] + return obj + + with client as cl: + resp = cl.get("/test") + obj = resp.json + assert obj == [{"_id": 42, "name": "Jon Snow"}] + + +def test_responds_regular_route(app, client): # noqa + class TestSchema(Schema): + _id = fields.Integer() + name = fields.String() + + @app.route("/test", methods=["GET"]) + @responds(schema=TestSchema) + def get(): + obj = {"_id": 42, "name": "Jon Snow"} + return obj + + with client as cl: + resp = cl.get("/test") + obj = resp.json + assert obj["_id"] == 42 + assert obj["name"] == "Jon Snow" + + +def test_responds_passes_raw_responses_through_untouched(app, client): # noqa + class TestSchema(Schema): + _id = fields.Integer() + name = fields.String() + + api = Api(app) + + @api.route("/test") + class TestResource(Resource): + @responds(schema=TestSchema, api=api) + def get(self): + + + obj = {"_id": 42, "name": "Jon Snow"} + return Response("A prebuild response that won't be serialised", 201) + + with client as cl: + resp = cl.get("/test") + assert resp.status_code == 201 + + +def test_responds_with_parser(app, client): # noqa + + api = Api(app) + + @api.route("/test") + class TestResource(Resource): + @responds( + "King", + dict(name="_id", type=int), + dict(name="name", type=str), + dict(name="value", type=float), + dict(name="status", choices=("alive", "dead")), + dict(name="todos", action="append"), + api=api, + ) + def get(self): + return { + "_id": 42, + "name": "Jon Snow", + "value": 100.0, + "status": "alive", + "todos": ["one", "two"], + } + + with client as cl: + resp = cl.get("/test") + assert resp.status_code == 200 + assert resp.json == { + "_id": 42, + "name": "Jon Snow", + "value": 100.0, + "status": "alive", + "todos": ["one", "two"], + } + + +def test_responds_respects_status_code(app, client): # noqa + class TestSchema(Schema): + _id = fields.Integer() + name = fields.String() + + api = Api(app) + + @api.route("/test") + class TestResource(Resource): + @responds(schema=TestSchema, api=api, status_code=999) + def get(self): + obj = {"_id": 42, "name": "Jon Snow"} + return obj + + with client as cl: + resp = cl.get("/test") + assert resp.status_code == 999 + +def test_responds_respects_custom_status_code(app, client): # noqa + class TestSchema(Schema): + _id = fields.Integer() + name = fields.String() + + api = Api(app) + + @api.route("/test") + class TestResource(Resource): + @responds(schema=TestSchema, api=api, status_code=999) + def get(self): + obj = {"_id": 42, "name": "Jon Snow"} + return obj, 888 + + with client as cl: + resp = cl.get("/test") + assert resp.status_code == 888 + +def test_responds_respects_envelope(app, client): # noqa + class TestSchema(Schema): + _id = fields.Integer() + name = fields.String() + + api = Api(app) + + @api.route("/test") + class TestResource(Resource): + @responds(schema=TestSchema, api=api, envelope='tests-data') + def get(self): + obj = {"_id": 42, "name": "Jon Snow"} + return obj + + with client as cl: + resp = cl.get("/test") + assert resp.status_code == 200 + assert resp.json == {'tests-data': {'_id': 42, 'name': 'Jon Snow'}} + + +def test_responds_skips_none_false(app, client): + class TestSchema(Schema): + _id = fields.Integer() + name = fields.String() + + api = Api(app) + + @api.route("/test") + class TestResource(Resource): + @responds(schema=TestSchema, api=api) + def get(self): + return {"_id": 42, "name": None} + + with client as cl: + resp = cl.get("/test") + assert resp.status_code == 200 + assert resp.json == {'_id': 42, 'name': None} + + +def test_responds_with_nested_skips_none_true(app, client): + class NestSchema(Schema): + _id = fields.Integer() + name = fields.String() + + class TestSchema(Schema): + name = fields.String() + child = fields.Nested(NestSchema) + + api = Api(app) + + @api.route("/test") + class TestResource(Resource): + @responds(schema=TestSchema, api=api, skip_none=True, many=True) + def get(self): + return [{"name": None, "child": {"_id": 42, "name": None}}] + + with client as cl: + resp = cl.get("/test") + assert resp.status_code == 200 + assert resp.json == [{"child": {'_id': 42}}] + + +def test_accepts_with_nested_schema(app, client): # noqa + class TestSchema(Schema): + _id = fields.Integer() + name = fields.String() + + class HostSchema(Schema): + name = fields.String() + child = fields.Nested(TestSchema) + + api = Api(app) + + @api.route("/test") + class TestResource(Resource): + @accepts( + "Foo", + dict(name="foo", type=int, help="An important foo"), + schema=HostSchema, + api=api, + ) + def post(self): + assert request.parsed_obj + assert request.parsed_obj["child"] == {"_id": 42, "name": "tests name"} + assert request.parsed_obj["name"] == "tests host" + return "success" + + with client as cl: + resp = cl.post( + "/test?foo=3", + json={"name": "tests host", "child": {"_id": 42, "name": "tests name"}}, + ) + assert resp.status_code == 200 + + +def test_accepts_with_twice_nested_schema(app, client): # noqa + class TestSchema(Schema): + _id = fields.Integer() + name = fields.String() + + class HostSchema(Schema): + name = fields.String() + child = fields.Nested(TestSchema) + + class HostHostSchema(Schema): + name = fields.String() + child = fields.Nested(HostSchema) + + api = Api(app) + + @api.route("/test") + class TestResource(Resource): + @accepts( + "Foo", + dict(name="foo", type=int, help="An important foo"), + schema=HostHostSchema, + api=api, + ) + def post(self): + assert request.parsed_obj + assert request.parsed_obj["child"]["child"] == { + "_id": 42, + "name": "tests name", + } + assert request.parsed_obj["child"] == { + "name": "tests host", + "child": {"_id": 42, "name": "tests name"}, + } + assert request.parsed_obj["name"] == "tests host host" + return "success" + + with client as cl: + resp = cl.post( + "/test?foo=3", + json={ + "name": "tests host host", + "child": { + "name": "tests host", + "child": {"_id": 42, "name": "tests name"}, + }, + }, + ) + assert resp.status_code == 200 + + +def test_responds_with_validate(app, client): # noqa + class TestSchema(Schema): + _id = fields.Integer(required=True) + name = fields.String(required=True) + + @app.errorhandler(InternalServerError) + def payload_validation_failure(err): + return jsonify({"message": "Server attempted to return invalid data"}), 500 + + @app.route("/test") + @responds(schema=TestSchema, validate=True) + def get(): + obj = {"wrong_field": 42, "name": "Jon Snow"} + return obj + + with app.test_client() as cl: + resp = cl.get("/test") + obj = resp.json + assert resp.status_code == 500 + assert resp.json == {"message": "Server attempted to return invalid data"} + + +def test_responds_with_validate(app, client): # noqa + class TestDataObj: + def __init__(self, wrong_field, name): + self.wrong_field = wrong_field + self.name = name + + class TestSchema(Schema): + _id = fields.Integer(required=True) + name = fields.String(required=True) + + @app.errorhandler(InternalServerError) + def payload_validation_failure(err): + return jsonify({"message": "Server attempted to return invalid data"}), 500 + + @app.route("/test") + @responds(schema=TestSchema, validate=True) + def get(): + obj = {"wrong_field": 42, "name": "Jon Snow"} + data = TestDataObj(**obj) + return data + + with app.test_client() as cl: + resp = cl.get("/test") + obj = resp.json + assert resp.status_code == 500 + assert resp.json == {"message": "Server attempted to return invalid data"} + + +def test_no_schema_generates_correct_swagger(app, client): # noqa + class TestSchema(Schema): + _id = fields.Integer() + name = fields.String() + + api = Api(app) + route = "/test" + + @api.route(route) + class TestResource(Resource): + @responds(api=api, status_code=201, description="My description") + def post(self): + obj = [{"_id": 42, "name": "Jon Snow"}] + return obj + + with client as cl: + cl.post(route, data='[{"_id": 42, "name": "Jon Snow"}]', content_type='application/json') + route_docs = api.__schema__["paths"][route]["post"] + + responses_docs = route_docs['responses']['201'] + + assert responses_docs['description'] == "My description" + + +def test_swagger_respects_existing_response_docs(app, client): # noqa + class TestSchema(Schema): + _id = fields.Integer() + name = fields.String() + + api = Api(app) + route = "/test" + + @api.route(route) + class TestResource(Resource): + @responds(schema=TestSchema(many=True), api=api, description="My description") + @api.doc(responses={401: "Not Authorized", 404: "Not Found"}) + def get(self): + return [{"_id": 42, "name": "Jon Snow"}] + + with client as cl: + cl.get(route) + route_docs = api.__schema__["paths"][route]["get"] + assert route_docs["responses"]["200"]["description"] == "My description" + assert route_docs["responses"]["401"]["description"] == "Not Authorized" + assert route_docs["responses"]["404"]["description"] == "Not Found" + +def test_responds_can_use_alt_schema(app, client): # noqa + class DefaultSchema(Schema): + id = fields.Integer() + name = fields.String() + + class ErrorSchema(Schema): + code = fields.String() + error = fields.String() + + class TokenSchema(Schema): + access_token = fields.String() + refresh_token = fields.String() + + api = Api(app) + + @api.route("/test") + class TestResource(Resource): + alt_schemas = { + 888: TokenSchema, + 666: ErrorSchema, + } + @responds(schema=DefaultSchema, api=api, alt_schemas=alt_schemas) + def get(self): + resp_code = int(request.args.get("code")) + + if resp_code == 888: + resp = {"access_token": "test_access_token", "refresh_token": "test_refresh_token"} + elif resp_code == 666: + resp = {"code": "UNKNOWN", "error": "Unhandled Exception"} + else: + resp = {"id": 1234, "name": "Fred Smith"} + + return resp, resp_code + + with client as cl: + # test alternate schema + resp = cl.get("/test?code=666") + assert resp.status_code == 666 + assert resp.json == {"code": "UNKNOWN", "error": "Unhandled Exception"} + + # test different alternate schema + resp = cl.get("/test?code=888") + assert resp.status_code == 888 + assert resp.json == {"access_token": "test_access_token", "refresh_token": "test_refresh_token"} + + # test fallback to default schema with status code passthrough + resp = cl.get("/test?code=401") + assert resp.status_code == 401 + assert resp.json == {"id": 1234, "name": "Fred Smith"} diff --git a/flask_accepts/utils_test.py b/flask_accepts/tests/test_utils.py similarity index 90% rename from flask_accepts/utils_test.py rename to flask_accepts/tests/test_utils.py index 4aa9cde..e78d4ac 100644 --- a/flask_accepts/utils_test.py +++ b/flask_accepts/tests/test_utils.py @@ -8,7 +8,6 @@ from flask import Flask from flask_restx import Api, fields as fr, namespace -# from .utils import unpack_list, unpack_nested import flask_accepts.utils as utils @@ -96,46 +95,40 @@ class IntegerSchema(Schema): def test_get_default_model_name(): - from .utils import get_default_model_name - class TestSchema(Schema): pass - result = get_default_model_name(TestSchema) + result = utils.get_default_model_name(TestSchema) expected = "Test" assert result == expected def test_get_default_model_name_works_with_multiple_schema_in_name(): - from .utils import get_default_model_name - class TestSchemaSchema(Schema): pass - result = get_default_model_name(TestSchemaSchema) + result = utils.get_default_model_name(TestSchemaSchema) expected = "TestSchema" assert result == expected def test_get_default_model_name_that_does_not_end_in_schema(): - from .utils import get_default_model_name - class SomeOtherName(Schema): pass - result = get_default_model_name(SomeOtherName) + result = utils.get_default_model_name(SomeOtherName) expected = "SomeOtherName" assert result == expected def test_get_default_model_name_default_names(): - from .utils import get_default_model_name, num_default_models + from flask_accepts.utils import num_default_models for model_num in range(5): - result = get_default_model_name() + result = utils.get_default_model_name() expected = f"DefaultResponseModel_{model_num + num_default_models}" assert result == expected @@ -199,9 +192,9 @@ class FakeFieldNoRequired(ma.Field): def test__ma_field_to_fr_field_converts_missing_param_to_default_if_present(): @dataclass class FakeFieldWithMissing(ma.Field): - missing: bool + load_default: bool - fr_field_dict = utils._ma_field_to_fr_field(FakeFieldWithMissing(missing=True)) + fr_field_dict = utils._ma_field_to_fr_field(FakeFieldWithMissing(load_default=True)) assert fr_field_dict["default"] is True @dataclass @@ -242,12 +235,12 @@ class FakeFieldNoDescription(ma.Field): def test__ma_field_to_fr_field_converts_default_to_example_if_present(): @dataclass class FakeFieldWithDefault(ma.Field): - default: str + dump_default: str expected_example_value = "test" fr_field_dict = utils._ma_field_to_fr_field( - FakeFieldWithDefault(default=expected_example_value) + FakeFieldWithDefault(dump_default=expected_example_value) ) assert fr_field_dict["example"] == expected_example_value @@ -269,82 +262,70 @@ class FakeFieldWithNoParams(ma.Field): def test_make_type_mapper_works_with_required(): - from flask_accepts.utils import make_type_mapper - app = Flask(__name__) api = Api(app) - mapper = make_type_mapper(fr.Raw) + mapper = utils.make_type_mapper(fr.Raw) result = mapper(ma.Raw(required=True), api=api, model_name="test_model_name", operation="load") assert result.required def test_make_type_mapper_produces_nonrequired_param_by_default(): - from flask_accepts.utils import make_type_mapper - app = Flask(__name__) api = Api(app) - mapper = make_type_mapper(fr.Raw) + mapper = utils.make_type_mapper(fr.Raw) result = mapper(ma.Raw(), api=api, model_name="test_model_name", operation="load") assert not result.required def test__maybe_add_operation_passes_through_if_no_load_only(): - from flask_accepts.utils import _maybe_add_operation - class TestSchema(Schema): _id = ma.Integer() model_name = "TestSchema" operation = "load" - result = _maybe_add_operation(TestSchema(), model_name, operation) + result = utils._maybe_add_operation(TestSchema(), model_name, operation) expected = model_name assert result == expected def test__maybe_add_operation_append_if_load_only(): - from flask_accepts.utils import _maybe_add_operation - class TestSchema(Schema): _id = ma.Integer(load_only=True) model_name = "TestSchema" operation = "load" - result = _maybe_add_operation(TestSchema(), model_name, operation) + result = utils._maybe_add_operation(TestSchema(), model_name, operation) expected = f"{model_name}-load" assert result == expected def test__maybe_add_operation_passes_through_if_no_dump_only(): - from flask_accepts.utils import _maybe_add_operation - class TestSchema(Schema): _id = ma.Integer() model_name = "TestSchema" operation = "dump" - result = _maybe_add_operation(TestSchema(), model_name, operation) + result = utils._maybe_add_operation(TestSchema(), model_name, operation) expected = model_name assert result == expected def test__maybe_add_operation_append_if_dump_only(): - from flask_accepts.utils import _maybe_add_operation - class TestSchema(Schema): _id = ma.Integer(dump_only=True) model_name = "TestSchema" operation = "dump" - result = _maybe_add_operation(TestSchema(), model_name, operation) + result = utils._maybe_add_operation(TestSchema(), model_name, operation) expected = f"{model_name}-dump" assert result == expected diff --git a/flask_accepts/utils.py b/flask_accepts/utils.py index d8e997d..c8d6e47 100644 --- a/flask_accepts/utils.py +++ b/flask_accepts/utils.py @@ -1,8 +1,16 @@ from typing import Optional, Type, Union + from flask_restx import fields as fr, inputs from marshmallow import fields as ma +from marshmallow import __version_info__ as marshmallow_version from marshmallow.schema import Schema, SchemaMeta -import uuid + + +_ma_key_for_fr_example_key = "dump_default" +_ma_key_for_fr_default_key = "load_default" +if marshmallow_version < (3, 13, 0): + _ma_key_for_fr_example_key = "default" + _ma_key_for_fr_default_key = "missing" def unpack_list(val, api, model_name: str = None, operation: str = "dump"): @@ -17,7 +25,7 @@ def unpack_nested(val, api, model_name: str = None, operation: str = "dump"): return unpack_nested_self(val, api, model_name, operation) model_name = get_default_model_name(val.nested) - + if val.many: return fr.List( fr.Nested( @@ -168,8 +176,9 @@ def get_default_model_name(schema: Optional[Union[Schema, Type[Schema]]] = None) def _ma_field_to_fr_field(value: ma.Field) -> dict: fr_field_parameters = {} - if hasattr(value, "default") and type(value.default) != ma.utils._Missing: - fr_field_parameters["example"] = value.default + if hasattr(value, _ma_key_for_fr_example_key) \ + and type(getattr(value, _ma_key_for_fr_example_key)) != ma.utils._Missing: + fr_field_parameters["example"] = getattr(value, _ma_key_for_fr_example_key) if hasattr(value, "required"): fr_field_parameters["required"] = value.required @@ -177,8 +186,9 @@ def _ma_field_to_fr_field(value: ma.Field) -> dict: if hasattr(value, "metadata") and "description" in value.metadata: fr_field_parameters["description"] = value.metadata["description"] - if hasattr(value, "missing") and type(value.missing) != ma.utils._Missing: - fr_field_parameters["default"] = value.missing + if hasattr(value, _ma_key_for_fr_default_key) \ + and type(getattr(value, _ma_key_for_fr_default_key)) != ma.utils._Missing: + fr_field_parameters["default"] = getattr(value, _ma_key_for_fr_default_key) return fr_field_parameters diff --git a/requirements.txt b/requirements.txt index 1ebb6b3..8ad3f9f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,6 @@ -Flask==1.0.2 -flask-restx==0.5.1 +flask>=2,<3; python_version < '3.8' +flask>=3.0; python_version >= '3.8' +flask-restx==1.1; python_version < '3.8' +flask-restx>=1.2; python_version >= '3.8' +werkzeug>=2,<3; python_version < '3.8' +werkzeug>=3,<4; python_version >= '3.8' diff --git a/setup.py b/setup.py index 3119e1d..9ed1760 100644 --- a/setup.py +++ b/setup.py @@ -6,13 +6,15 @@ name="flask_accepts", author='Alan "AJ" Pryor, Jr.', author_email="apryor6@gmail.com", - version="0.18.4", + version="1.0.0", description="Easy, opinionated Flask input/output handling with Flask-restx and Marshmallow", ext_modules=[], packages=find_packages(), install_requires=[ - "marshmallow>=3.0.1", - "flask-restx>=0.2.0", - "Werkzeug" + "marshmallow>=3.17.0", + "flask-restx==1.1.0; python_version < '3.8'", + "flask-restx>=1.2.0; python_version >= '3.8'", + "werkzeug>=2,<3; python_version < '3.8'", + "werkzeug>=3,<4; python_version >= '3.8'", ], )