Skip to content

Commit

Permalink
Merge pull request #123 from circulon/feature/responds_alt_schema_per…
Browse files Browse the repository at this point in the history
…_status_code

Feature: Add other schemas to @resonds
  • Loading branch information
apryor6 authored Sep 17, 2024
2 parents 0556746 + eeb3a62 commit 6e808b8
Show file tree
Hide file tree
Showing 16 changed files with 827 additions and 547 deletions.
4 changes: 4 additions & 0 deletions .coveragerc
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[report]
exclude_lines =
# pragma: no cover
if marshmallow_version < \(3, 13, 0\):
73 changes: 70 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
[![codecov](https://codecov.io/gh/apryor6/flask_accepts/branch/master/graph/badge.svg)](https://codecov.io/gh/apryor6/flask_accepts)
from lib2to3.btm_utils import tokens[![codecov](https://codecov.io/gh/apryor6/flask_accepts/branch/master/graph/badge.svg)](https://codecov.io/gh/apryor6/flask_accepts)
[![license](https://img.shields.io/github/license/apryor6/flask_accepts)](https://img.shields.io/github/license/apryor6/flask_accepts)
[![code_style](https://img.shields.io/badge/code%20style-black-000000.svg)](https://img.shields.io/badge/code%20style-black-000000.svg)

---

- [flask_accepts](#flask-accepts)
- [flask_accepts](#flask_accepts)
- [Installation](#installation)
- [Basic usage](#basic-usage)
- [Usage with "vanilla Flask"](#usage-with--vanilla-flask-)
- [Usage with "vanilla Flask"](#usage-with-vanilla-flask)
* [Usage with Marshmallow schemas](#usage-with-marshmallow-schemas)
- [Marshmallow validators](#marshmallow-validators)
- [Default values](#default-values)
* [ Returning Different Response Schemas](#returning-different-response-schemas)
- [Pass-through of arbitrary status codes](#pass-through-of-arbitrary-status-codes)
* [Automatic Swagger documentation](#automatic-swagger-documentation)
- [Defining the model name](#defining-the-model-name)
- [Error handling](#error-handling)
Expand Down Expand Up @@ -187,6 +189,71 @@ You can provide any of the built-in validators to Marshmallow schemas. See [here

Default values provided to Marshmallow schemas will be internally mapped and displayed in the Swagger documentation. See [this example](https://github.com/apryor6/flask_accepts/blob/master/examples/default_values.py) for a usage of `flask_accepts` with nested Marshmallow schemas and default values that will display correctly in Swagger.

## Returning Different Response Schemas

In real world scenarios things dont always go to plan and you may need to return an error code with your response data
or a different schema entirely (eg Internal Serrver Errors)

the `responds` decorator accepts a dictionary of response codes and their associated schema.
when returning response data simply provied the status code as you would for a standard Flask response.
The response will then be loaded into the correct schema.

**Example:**
```python

class LoginSchema(Schema):
username = fields.String()
password = fields.String()

class TokenSchema(Schema):
access_token = fields.String()
id_token = fields.String()
refresh_token = fields.String()

class ErrorSchema(Schema):
error_code = fields.Integer()
errors = fields.List()

@api.route("/restx/update_user")
class LoginResource(Resource):
alt_schemas = {401: ErrorSchema}
@accepts(schema=LoginSchema, api=api)
@responds(schema=TokenSchema, alt_schemas=alt_schemas, api=api)
def post(self):
payload = request.parsed_obj
username = payload.username
password = payload.password

tokens = self.attempt_login(username, password)
if tokens is None:
return {"error_code": 8001, "errors": ["invalid username or password"]}, 401

return tokens
```

### Pass-through of arbitrary status codes
You can also provide a status code that does not have an associated schema
Also works if there are no alternate schemas set.

The response data will be loaded into the default schema and the provided status code passed through.
An example of this usage might be returning a 422 for bad/invalid data.
```python

class ResponseSchema(Schema):
response = fields.String()
errors = fields.List()

@api.route("/restx/update_user")
class UserResource(Resource):
@responds(schema=ResponseSchema, api=api)
def post(self):
# check data
if not data_valid:
return {"response": None, "errors": ["invalid user id"]}, 422

return {"response": "user updated", "errors": []}
```

## Automatic Swagger documentation

The `accepts` decorator will automatically enable Swagger by internally adding the `@api.expects` decorator. If you have provided positional arguments to `accepts`, this involves generating the corresponding `api.parser()` (which is a `reqparse.RequestParser` that includes the Swagger context). If you provide a Marshmallow Schema, an equivalent `api.model` is generated and passed to `@api.expect`. These two can be mixed-and-matched, and the documentation will update accordingly.
Expand Down
45 changes: 23 additions & 22 deletions dev-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
aniso8601==8.0.0
attrs==19.3.0
Click==7.0
Flask==1.0.2
flask-marshmallow==0.10.1
flask-restx==0.5.1
itsdangerous==1.1.0
Jinja2==2.11.3
jsonschema==3.2.0
MarkupSafe==1.1.1
marshmallow==3.2.2
more-itertools==8.0.0
packaging==19.2
pluggy==0.13.1
py==1.10.0
pyparsing==2.4.5
pyrsistent==0.15.6
pytest==5.3.1
pytz==2019.3
six==1.13.0
wcwidth==0.1.7
Werkzeug==0.16.0
aniso8601==9.0.1
attrs==21.4.0
click==8.1.3
flask>=2,<3
flask-restx>=1,<2
flask-marshmallow==0.14.0
itsdangerous==2.1.2
jinja2==3.1.2
jsonschema==4.7.2
MarkupSafe==2.1.1
marshmallow==3.17.0
more-itertools==8.13.0
packaging==21.3
pluggy==1.0.0
py==1.11.0
pyparsing==3.0.9
pyrsistent==0.18.1
pytest==7.1.2
pytz==2022.1
six==1.16.0
wcwidth==0.2.5
werkzeug>=2,<3; python_version < '3.8'
werkzeug>=3,<4; python_version >= '3.8'
17 changes: 8 additions & 9 deletions examples/default_values.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
import datetime
from dataclasses import dataclass
from marshmallow import fields, Schema, post_load
from flask import Flask, jsonify, request
from flask_accepts import accepts, responds


class CogSchema(Schema):
cog_foo = fields.String(default="cog")
cog_baz = fields.Integer(default=999)
cog_foo = fields.String(dump_default="cog")
cog_baz = fields.Integer(dump_default=999)


class WidgetSchema(Schema):
foo = fields.String(default="test string")
baz = fields.Integer(default=42)
flag = fields.Bool(default=False)
date = fields.Date(default="01-01-1900")
dec = fields.Decimal(default=42.42)
dct = fields.Dict(default={"key": "value"})
foo = fields.String(dump_default="test string")
baz = fields.Integer(dump_default=42)
flag = fields.Bool(dump_default=False)
date = fields.Date(dump_default="01-01-1900")
dec = fields.Decimal(dump_default=42.42)
dct = fields.Dict(dump_default={"key": "value"})

cog = fields.Nested(CogSchema)

Expand Down
4 changes: 2 additions & 2 deletions examples/marshmallow_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ class Widget:


class WidgetSchema(Schema):
foo = fields.String(default="test value")
baz = fields.Integer(default=422)
foo = fields.String(dump_default="test value")
baz = fields.Integer(dump_default=422)

@post_load
def make(self, data, **kwargs):
Expand Down
16 changes: 8 additions & 8 deletions examples/nested_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,17 @@


class CogSchema(Schema):
cog_foo = fields.String(default="cog")
cog_baz = fields.Integer(default=999)
cog_foo = fields.String(dump_default="cog")
cog_baz = fields.Integer(dump_default=999)


class WidgetSchema(Schema):
foo = fields.String(default="test string")
baz = fields.Integer(default=42)
flag = fields.Bool(default=False)
date = fields.Date(default="01-01-1900")
dec = fields.Decimal(default=42.42)
dct = fields.Dict(default={"key": "value"})
foo = fields.String(dump_default="test string")
baz = fields.Integer(dump_default=42)
flag = fields.Bool(dump_default=False)
date = fields.Date(dump_default="01-01-1900")
dec = fields.Decimal(dump_default=42.42)
dct = fields.Dict(dump_default={"key": "value"})

cog = fields.Nested(CogSchema)

Expand Down
43 changes: 29 additions & 14 deletions flask_accepts/decorators/decorators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import OrderedDict
from typing import Type, Union
from typing import Type, Union, Dict
from flask import jsonify
from werkzeug.wrappers import Response
from werkzeug.exceptions import BadRequest, InternalServerError
Expand Down Expand Up @@ -126,7 +126,7 @@ def inner(*args, **kwargs):
# Handle Marshmallow schema for request body
if schema:
try:
obj = schema.load(request.get_json(force=True))
obj = schema.load(request.get_json(force=True) or {})
request.parsed_obj = obj
except ValidationError as ex:
schema_error = ex.messages
Expand All @@ -135,9 +135,9 @@ def inner(*args, **kwargs):
f"Error parsing request body: {schema_error}"
)
if hasattr(error, "data"):
error.data["errors"].update({"schema_errors": schema_error})
error.data["errors"].update(schema_error)
else:
error.data = {"schema_errors": schema_error}
error.data = {"errors": schema_error}

# Handle Marshmallow schema for query params
if query_params_schema:
Expand All @@ -155,9 +155,9 @@ def inner(*args, **kwargs):
f"Error parsing query params: {schema_error}"
)
if hasattr(error, "data"):
error.data["errors"].update({"schema_errors": schema_error})
error.data["errors"].update(schema_error)
else:
error.data = {"schema_errors": schema_error}
error.data = {"errors": schema_error}

# Handle Marshmallow schema for headers
if headers_schema:
Expand All @@ -175,9 +175,9 @@ def inner(*args, **kwargs):
f"Error parsing headers: {schema_error}"
)
if hasattr(error, "data"):
error.data["errors"].update({"schema_errors": schema_error})
error.data["errors"].update(schema_error)
else:
error.data = {"schema_errors": schema_error}
error.data = {"errors": schema_error}

# Handle Marshmallow schema for form data
if form_schema:
Expand All @@ -195,9 +195,9 @@ def inner(*args, **kwargs):
f"Error parsing form data: {schema_error}"
)
if hasattr(error, "data"):
error.data["errors"].update({"schema_errors": schema_error})
error.data["errors"].update(schema_error)
else:
error.data = {"schema_errors": schema_error}
error.data = {"errors": schema_error}

# If any parsing produced an error, combine them and re-raise
if error:
Expand Down Expand Up @@ -231,7 +231,8 @@ def inner(*args, **kwargs):
def responds(
*args,
model_name: str = None,
schema=None,
schema: Union[Schema, Type[Schema]] = None,
alt_schemas: Dict[int, Union[Schema, Type[Schema]]] = None,
many: bool = False,
api=None,
envelope=None,
Expand All @@ -250,6 +251,7 @@ def responds(
Args:
schema (bool, optional): Marshmallow schema with which to serialize the output
of the wrapped function.
alt_schemas (dict, optional): Dict of alternate schemas to use based on the status_code
many (bool, optional): (DEPRECATED) The Marshmallow schema `many` parameter, which will
return a list of the corresponding schema objects when set to True.
Expand Down Expand Up @@ -292,17 +294,29 @@ def decorator(func):

@wraps(func)
def inner(*args, **kwargs):
nonlocal schema
nonlocal status_code

rv = func(*args, **kwargs)

# If a Flask response has been made already, it is passed through unchanged
if isinstance(rv, Response):
return rv
if schema:
serialized = schema.dump(rv)

resp_schema = schema
# allow overriding the status code passed to Flask
if isinstance(rv, tuple):
rv, status_code = rv
if alt_schemas and status_code in alt_schemas:
# override the default response schema
resp_schema = _get_or_create_schema(alt_schemas[status_code])

if resp_schema:
serialized = resp_schema.dump(rv)

# Validate data if asked to (throws)
if validate:
errs = schema.validate(serialized)
errs = resp_schema.validate(serialized)
if errs:
raise InternalServerError(
description="Server attempted to return invalid data"
Expand Down Expand Up @@ -338,6 +352,7 @@ def remove_none(obj):
return jsonify(serialized), status_code
return serialized, status_code

nonlocal schema
# Add Swagger
if api and use_swagger and _IS_METHOD:
if schema:
Expand Down
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 6e808b8

Please sign in to comment.