Skip to content

Commit

Permalink
fix(flask): implement default schema serializer (#350)
Browse files Browse the repository at this point in the history
This corrects an issue that caused the Flask extension to use the incorrect serializer for encoding JSON
  • Loading branch information
cofin authored Jan 18, 2025
1 parent 1421e23 commit 13f4bde
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 65 deletions.
70 changes: 35 additions & 35 deletions advanced_alchemy/extensions/flask/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,52 +6,52 @@
Example:
Basic usage with synchronous SQLAlchemy:
```python
from flask import Flask
from advanced_alchemy.extensions.flask import (
AdvancedAlchemy,
SQLAlchemySyncConfig,
EngineConfig,
)
.. code-block:: python
app = Flask(__name__)
from flask import Flask
from advanced_alchemy.extensions.flask import (
AdvancedAlchemy,
SQLAlchemySyncConfig,
EngineConfig,
)
db_config = SQLAlchemySyncConfig(
engine_config=EngineConfig(url="sqlite:///db.sqlite3"),
create_all=True, # Create tables on startup
)
app = Flask(__name__)
db = AdvancedAlchemy(config=db_config)
db.init_app(app)
db_config = SQLAlchemySyncConfig(
engine_config=EngineConfig(url="sqlite:///db.sqlite3"),
create_all=True, # Create tables on startup
)
db = AdvancedAlchemy(config=db_config)
db.init_app(app)
# Get a session in your route
@app.route("/")
def index():
session = db.get_session()
# Use session...
```
# Get a session in your route
@app.route("/")
def index():
session = db.get_session()
# Use session...
Using async SQLAlchemy:
```python
from advanced_alchemy.extensions.flask import (
AdvancedAlchemy,
SQLAlchemyAsyncConfig,
)
.. code-block:: python
from advanced_alchemy.extensions.flask import (
AdvancedAlchemy,
SQLAlchemyAsyncConfig,
)
app = Flask(__name__)
app = Flask(__name__)
db_config = SQLAlchemyAsyncConfig(
engine_config=EngineConfig(
url="postgresql+asyncpg://user:pass@localhost/db"
),
create_all=True,
)
db_config = SQLAlchemyAsyncConfig(
engine_config=EngineConfig(
url="postgresql+asyncpg://user:pass@localhost/db"
),
create_all=True,
)
db = AdvancedAlchemy(config=db_config)
db.init_app(app)
```
db = AdvancedAlchemy(config=db_config)
db.init_app(app)
"""

from advanced_alchemy import base, exceptions, filters, mixins, operations, repository, service, types, utils
Expand Down
8 changes: 6 additions & 2 deletions advanced_alchemy/extensions/flask/alembic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
"""Alembic integration for Flask applications."""
"""Flask application integration for Alembic database migrations.
This module provides integration between Flask applications and Alembic
for managing database schema migrations.
"""

from __future__ import annotations

Expand All @@ -16,7 +20,7 @@


def get_sqlalchemy_extension(app: Flask) -> AdvancedAlchemy:
"""Retrieve Advanced Alchemy extension from the Flask application.
"""Retrieve the Advanced Alchemy extension instance from a Flask application.
Args:
app: The :class:`flask.Flask` application instance.
Expand Down
14 changes: 8 additions & 6 deletions advanced_alchemy/extensions/flask/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@

from click import echo
from flask import g, has_request_context
from litestar.serialization import decode_json, encode_json
from sqlalchemy.exc import OperationalError
from typing_extensions import Literal

from advanced_alchemy._serialization import decode_json, encode_json
from advanced_alchemy.base import metadata_registry
from advanced_alchemy.config import EngineConfig as _EngineConfig
from advanced_alchemy.config.asyncio import SQLAlchemyAsyncConfig as _SQLAlchemyAsyncConfig
from advanced_alchemy.config.sync import SQLAlchemySyncConfig as _SQLAlchemySyncConfig
from advanced_alchemy.exceptions import ImproperConfigurationError
from advanced_alchemy.service import schema_dump

if TYPE_CHECKING:
from typing import Any
Expand All @@ -30,7 +31,6 @@

from advanced_alchemy.utils.portals import Portal


__all__ = ("EngineConfig", "SQLAlchemyAsyncConfig", "SQLAlchemySyncConfig")

ConfigT = TypeVar("ConfigT", bound="Union[SQLAlchemySyncConfig, SQLAlchemyAsyncConfig]")
Expand All @@ -39,13 +39,16 @@
def serializer(value: Any) -> str:
"""Serialize JSON field values.
Calls the `:func:schema_dump` function to convert the value to a built-in before encoding.
Args:
value: Any JSON serializable value.
Returns:
str: JSON string representation of the value.
"""
return encode_json(value).decode("utf-8")

return encode_json(schema_dump(value))


@dataclass
Expand All @@ -63,10 +66,9 @@ class EngineConfig(_EngineConfig):

json_deserializer: Callable[[str], Any] = decode_json
"""For dialects that support the :class:`~sqlalchemy.types.JSON` datatype, this is a Python callable that will
convert a JSON string to a Python object. By default, this is set to Litestar's decode_json function."""
convert a JSON string to a Python object."""
json_serializer: Callable[[Any], str] = serializer
"""For dialects that support the JSON datatype, this is a Python callable that will render a given object as JSON.
By default, Litestar's encode_json function is used."""
"""For dialects that support the JSON datatype, this is a Python callable that will render a given object as JSON."""


@dataclass
Expand Down
44 changes: 24 additions & 20 deletions advanced_alchemy/extensions/flask/service.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
"""Flask-specific service classes."""
"""Flask-specific service classes.
This module provides Flask-specific service mixins and utilities for integrating
with the Advanced Alchemy service layer.
"""

from __future__ import annotations

Expand All @@ -10,31 +14,31 @@


class FlaskServiceMixin:
"""Mixin to add Flask-specific functionality to services.
"""A mixin class that adds Flask-specific functionality to services.
Example:
.. code-block:: python
This mixin provides methods and utilities for handling Flask-specific operations
when working with service classes.
:param serializer: The serializer instance to use for data transformation
:type serializer: :class:`advanced_alchemy.extensions.flask.config.Serializer`
from advanced_alchemy.service import (
SQLAlchemyAsyncRepositoryService,
)
from advanced_alchemy.extensions.flask import (
FlaskServiceMixin,
)
Example:
-------
.. code-block:: python
class UserService(
FlaskServiceMixin,
SQLAlchemyAsyncRepositoryService[User],
):
class Repo(repository.SQLAlchemySyncRepository[User]):
model_type = User
from advanced_alchemy.service import (
SQLAlchemyAsyncRepositoryService,
)
from advanced_alchemy.extensions.flask import (
FlaskServiceMixin,
)
repository_type = Repo
def get_user_response(self, user_id: int) -> Response:
user = self.get(user_id)
return self.jsonify(user.dict())
class MyService(
FlaskServiceMixin, SQLAlchemyAsyncRepositoryService
):
pass
"""

def jsonify(
Expand Down
69 changes: 67 additions & 2 deletions tests/unit/test_extensions/test_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pytest
from flask import Flask, Response
from msgspec import Struct
from pydantic import BaseModel
from sqlalchemy import String, select, text
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
Expand Down Expand Up @@ -47,6 +48,12 @@ class UserSchema(Struct):
name: str


class UserPydantic(BaseModel):
"""Test user pydantic model."""

name: str


class UserService(SQLAlchemySyncRepositoryService[User], FlaskServiceMixin):
"""Test user service."""

Expand Down Expand Up @@ -544,7 +551,7 @@ async def get_user() -> User | None:
extension.portal_provider.stop()


def test_sync_service_jsonify(setup_database: Path) -> None:
def test_sync_service_jsonify_msgspec(setup_database: Path) -> None:
app = Flask(__name__)

with app.test_client() as client:
Expand All @@ -571,7 +578,7 @@ def test_route() -> Response:
assert result.scalar_one().name == "service_test"


def test_async_service_jsonify(setup_database: Path) -> None:
def test_async_service_jsonify_msgspec(setup_database: Path) -> None:
app = Flask(__name__)

with app.test_client() as client:
Expand Down Expand Up @@ -599,3 +606,61 @@ def test_route() -> Response:
assert result
assert result.name == "async_service_test"
extension.portal_provider.stop()


def test_sync_service_jsonify_pydantic(setup_database: Path) -> None:
app = Flask(__name__)

with app.test_client() as client:
config = SQLAlchemySyncConfig(
connection_string=f"sqlite:///{setup_database}", metadata=metadata, commit_mode="autocommit"
)
extension = AdvancedAlchemy(config, app)

@app.route("/test", methods=["POST"])
def test_route() -> Response:
service = UserService(extension.get_sync_session())
user = service.create({"name": "test_sync_service_jsonify_pydantic"})
return service.jsonify(service.to_schema(user, schema_type=UserPydantic))

# Test successful response (should commit)
response = client.post("/test")
assert response.status_code == 200

# Verify the data was committed
session = extension.get_session()
assert isinstance(session, Session)
result = session.execute(select(User).where(User.name == "test_sync_service_jsonify_pydantic"))
assert result.scalar_one().name == "test_sync_service_jsonify_pydantic"


def test_async_service_jsonify_pydantic(setup_database: Path) -> None:
app = Flask(__name__)

with app.test_client() as client:
config = SQLAlchemyAsyncConfig(
connection_string=f"sqlite+aiosqlite:///{setup_database}", metadata=metadata, commit_mode="autocommit"
)
extension = AdvancedAlchemy(config, app)

@app.route("/test", methods=["POST"])
def test_route() -> Response:
service = AsyncUserService(extension.get_async_session())
user = extension.portal_provider.portal.call(
service.create, {"name": "test_async_service_jsonify_pydantic"}
)
return service.jsonify(service.to_schema(user, schema_type=UserPydantic))

# Test successful response (should commit)
response = client.post("/test")
assert response.status_code == 200

# Verify the data was committed
session = extension.get_session()
assert isinstance(session, AsyncSession)
result = extension.portal_provider.portal.call(
session.scalar, select(User).where(User.name == "test_async_service_jsonify_pydantic")
)
assert result
assert result.name == "test_async_service_jsonify_pydantic"
extension.portal_provider.stop()

0 comments on commit 13f4bde

Please sign in to comment.