Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: Update SQLAlchemy to 2.0 #68

Merged
merged 1 commit into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions plugin_store/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from fastapi import Depends
from sqlalchemy import asc, desc
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession, create_async_engine
from sqlalchemy.sql import delete, select, update

from constants import SortDirection, SortType
Expand All @@ -21,21 +20,21 @@
from .models.Version import Version

if TYPE_CHECKING:
from typing import AsyncIterator, Iterable
from typing import AsyncIterator, Iterable, Sequence

logger = logging.getLogger()

UTC = ZoneInfo("UTC")


db_url = getenv("DB_URL")
if not db_url:
raise Exception("DB_URL not provided or invalid!")
async_engine = create_async_engine(
getenv("DB_URL"),
db_url,
pool_pre_ping=True,
# echo=settings.ECHO_SQL,
)
AsyncSessionLocal = sessionmaker(
bind=async_engine, autoflush=False, future=True, expire_on_commit=False, class_=AsyncSession
)
AsyncSessionLocal = async_sessionmaker(bind=async_engine, autoflush=False, future=True, expire_on_commit=False)

db_lock = Lock()

Expand Down Expand Up @@ -158,7 +157,7 @@ async def search(
sort_direction: SortDirection = SortDirection.DESC,
limit: int = 50,
page: int = 0,
) -> list["Artifact"]:
) -> "Sequence[Artifact]":
statement = select(Artifact).offset(limit * page)
if name:
statement = statement.where(Artifact.name.like(f"%{name}%"))
Expand Down
5 changes: 4 additions & 1 deletion plugin_store/database/migrations/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,11 @@ async def run_migrations_online() -> None:
and associate a connection with the context.
"""
db_url = getenv("DB_URL")
if not db_url:
raise Exception("DB_URL not provided or invalid!")
connectable = create_async_engine(
getenv("DB_URL"),
db_url,
poolclass=pool.NullPool,
future=True,
)
Expand Down
27 changes: 14 additions & 13 deletions plugin_store/database/models/Artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from urllib.parse import quote

from sqlalchemy import Boolean, Column, ForeignKey, func, Integer, select, Table, Text, UniqueConstraint
from sqlalchemy.orm import column_property, relationship
from sqlalchemy.orm import column_property, Mapped, relationship

import constants

Expand All @@ -29,30 +29,31 @@ class Tag(Base):
class Artifact(Base):
__tablename__ = "artifacts"

id: int = Column(Integer, autoincrement=True, primary_key=True)
name: str = Column(Text)
author: str = Column(Text)
description: str = Column(Text)
_image_path: "str | None" = Column("image_path", Text, nullable=True)
tags: "list[Tag]" = relationship(
id: Mapped[int] = Column(Integer, autoincrement=True, primary_key=True)
name: Mapped[str] = Column(Text)
author: Mapped[str] = Column(Text)
description: Mapped[str] = Column(Text)
_image_path: Mapped[str | None] = Column("image_path", Text, nullable=True)
tags: "Mapped[list[Tag]]" = relationship(
"Tag", secondary=PluginTag, cascade="all, delete", order_by="Tag.tag", lazy="selectin"
)
versions: "list[Version]" = relationship(
versions: "Mapped[list[Version]]" = relationship(
"Version", cascade="all, delete", lazy="selectin", order_by="Version.created.desc(), Version.id.asc()"
)
visible: bool = Column(Boolean, default=True)
visible: Mapped[bool] = Column(Boolean, default=True)

downloads: int = column_property(
# Properties computed from relations
downloads: Mapped[int] = column_property(
select(func.sum(Version.downloads)).where(Version.artifact_id == id).correlate_except(Version).scalar_subquery()
)
updates: int = column_property(
updates: Mapped[int] = column_property(
select(func.sum(Version.updates)).where(Version.artifact_id == id).correlate_except(Version).scalar_subquery()
)

created: datetime = column_property(
created: Mapped[datetime] = column_property(
select(func.min(Version.created)).where(Version.artifact_id == id).correlate_except(Version).scalar_subquery()
)
updated: datetime = column_property(
updated: Mapped[datetime] = column_property(
select(func.max(Version.created)).where(Version.artifact_id == id).correlate_except(Version).scalar_subquery()
)

Expand Down
6 changes: 4 additions & 2 deletions plugin_store/database/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
from sqlalchemy.types import TypeDecorator

if TYPE_CHECKING:
from typing import Any

from sqlalchemy.engine import Dialect

UTC = ZoneInfo("UTC")


class TZDateTime(TypeDecorator):
class TZDateTime(TypeDecorator[datetime]):
"""
A DateTime type which can only store tz-aware DateTimes.
"""
Expand All @@ -26,7 +28,7 @@ def process_bind_param(self, value: "datetime | None", dialect: "Dialect"):
return value.astimezone(UTC)
return value

def process_result_value(self, value: "datetime", dialect: "Dialect") -> "datetime | None":
def process_result_value(self, value: "Any | None", dialect: "Dialect") -> "datetime | None":
if isinstance(value, datetime) and value.tzinfo is None:
return value.replace(tzinfo=UTC)
return value
Expand Down
Loading
Loading