Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump sqlmodel, sqlalchemy, pydantic, fastapi #455

Merged
merged 2 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 136 additions & 54 deletions backend/openapi-schema.yml

Large diffs are not rendered by default.

45 changes: 22 additions & 23 deletions backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,21 @@ authors = [
]

dependencies = [
"redis~=5.0.1",
"fastapi~=0.92.0",
"uvicorn[standard]~=0.20.0",
"sqlmodel~=0.0.11",
"alembic~=1.11.1",
"python-multipart~=0.0.6",
"filetype~=1.2.0",
"redis~=5.0",
"fastapi~=0.115",
"uvicorn[standard]~=0.20",
"sqlmodel~=0.0",
"alembic~=1.0",
"python-multipart~=0.0",
"filetype~=1.2",
"websockets~=10.4",
"python-magic~=0.4.27",
"python-magic~=0.4",
"transcribee-proto",
"python-frontmatter~=1.0.0",
"psycopg2~=2.9.9",
"prometheus-fastapi-instrumentator~=6.1.0",
"python-frontmatter~=1.0",
"psycopg2~=2.9",
"prometheus-fastapi-instrumentator~=6.1",
"pydantic~=2.2",
"pydantic-settings>=2.7",
]
requires-python = ">=3.11"
readme = "./README.md"
Expand All @@ -31,27 +33,24 @@ license = { text = "AGPL-3.0" }
[dependency-groups]
dev = [
"pyyaml~=6.0",
"pytest~=7.3.1",
"httpx~=0.24.0",
"pytest-alembic~=0.10.4",
"pyright~=1.1.314",
"pytest~=7.3",
"httpx~=0.24",
"pytest-alembic~=0.10",
"pyright~=1.1",
]
notebooks = [
"jupyter~=1.0.0",
"pandas~=2.0.1",
"tabulate~=0.9.0",
"matplotlib~=3.7.1",
"seaborn~=0.12.2",
"jupyter~=1.0",
"pandas~=2.0",
"tabulate~=0.9",
"matplotlib~=3.7",
"seaborn~=0.12",
]

[project.scripts]
transcribee-migrate = "transcribee_backend.db.run_migrations:main"
transcribee-admin = "transcribee_backend.admin_cli:main"

[tool.uv]
override-dependencies = [
"sqlalchemy==1.4.41"
]
config-settings = { editable_mode = "compat" }

[tool.uv.sources]
Expand Down
22 changes: 15 additions & 7 deletions backend/tests/test_doc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import uuid

import pytest
from fastapi.testclient import TestClient
from sqlmodel import Session
from sqlmodel import Session, col, func, select
from transcribee_backend.auth import generate_share_token
from transcribee_backend.config import settings
from transcribee_backend.models import (
Expand All @@ -23,7 +25,7 @@ def document_id(memory_session: Session, logged_in_client: TestClient):
data={"name": "test document", "model": "tiny", "language": "auto"},
)
assert req.status_code == 200
document_id = req.json()["id"]
document_id = uuid.UUID(req.json()["id"])

memory_session.add(DocumentUpdate(document_id=document_id, change_bytes=b""))
memory_session.commit()
Expand All @@ -50,7 +52,7 @@ def test_doc_delete(
]
counts = {}
for table in checked_tables:
counts[table] = memory_session.query(table).count()
counts[table] = memory_session.exec(select(func.count(col(table.id)))).one()

files = set(str(x) for x in settings.storage_path.glob("*"))

Expand All @@ -60,14 +62,15 @@ def test_doc_delete(
data={"name": "test document", "model": "tiny", "language": "auto"},
)
assert req.status_code == 200
document_id = req.json()["id"]
document_id = uuid.UUID(req.json()["id"])

req = logged_in_client.get(f"/api/v1/documents/{document_id}/tasks/")
task_id = uuid.UUID(req.json()[0]["id"])
assert req.status_code == 200
assert len(req.json()) >= 1

memory_session.add(DocumentUpdate(document_id=document_id, change_bytes=b""))
memory_session.add(TaskAttempt(task_id=req.json()[0]["id"], attempt_number=1))
memory_session.add(TaskAttempt(task_id=task_id, attempt_number=1))
memory_session.add(
generate_share_token(
document_id=document_id, name="Test Token", valid_until=None, can_write=True
Expand All @@ -76,7 +79,9 @@ def test_doc_delete(
memory_session.commit()

for table in checked_tables:
assert counts[table] < memory_session.query(table).count()
assert (
counts[table] < memory_session.exec(select(func.count(col(table.id)))).one()
)

assert files < set(str(x) for x in settings.storage_path.glob("*"))

Expand All @@ -87,7 +92,10 @@ def test_doc_delete(
assert req.status_code == 200

for table in checked_tables:
assert counts[table] == memory_session.query(table).count()
assert (
counts[table]
== memory_session.exec(select(func.count(col(table.id)))).one()
)

assert files == set(str(x) for x in settings.storage_path.glob("*"))

Expand Down
2 changes: 1 addition & 1 deletion backend/transcribee_backend/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def validate_user_authorization(session: Session, authorization: str):
raise HTTPException(status_code=400, detail="Invalid Token")
user_id, provided_token = token_data.split(":", maxsplit=1)
statement = select(UserToken).where(
UserToken.user_id == user_id, UserToken.valid_until >= now_tz_aware()
UserToken.user_id == uuid.UUID(user_id), UserToken.valid_until >= now_tz_aware()
)
results = session.exec(statement)
for token in results:
Expand Down
31 changes: 17 additions & 14 deletions backend/transcribee_backend/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,32 @@
from typing import Dict, List, Optional

import frontmatter
from pydantic import BaseModel, BaseSettings, parse_file_as, parse_obj_as
from pydantic import BaseModel, TypeAdapter
from pydantic_settings import BaseSettings

pages = None


class Settings(BaseSettings):
storage_path: Path = Path("storage/")
secret_key = "insecure-secret-key"
worker_timeout = 60 # in seconds
media_signature_max_age = 3600 # in seconds
task_attempt_limit = 5
secret_key: str = "insecure-secret-key"
worker_timeout: int = 60 # in seconds
media_signature_max_age: int = 3600 # in seconds
task_attempt_limit: int = 5

media_url_base = "http://localhost:8000/"
media_url_base: str = "http://localhost:8000/"
logged_out_redirect_url: None | str = None

model_config_path: Path = Path(__file__).parent.resolve() / Path(
"default_models.json"
)
pages_dir: Path = Path("data/pages/")

metrics_username = "transcribee"
metrics_password = "transcribee"
metrics_username: str = "transcribee"
metrics_password: str = "transcribee"

redis_host = "localhost"
redis_port = 6379
redis_host: str = "localhost"
redis_port: int = 6379


class ModelConfig(BaseModel):
Expand All @@ -37,20 +38,22 @@ class ModelConfig(BaseModel):

class PublicConfig(BaseModel):
models: Dict[str, ModelConfig]
logged_out_redirect_url: str | None
logged_out_redirect_url: str | None = None


class ShortPageConfig(BaseModel):
name: str
footer_position: Optional[int]
footer_position: Optional[int] = None


class PageConfig(ShortPageConfig):
text: str


def get_model_config():
return parse_file_as(Dict[str, ModelConfig], settings.model_config_path)
return TypeAdapter(Dict[str, ModelConfig]).validate_json(
Path(settings.model_config_path).read_text()
)


def load_pages_from_disk() -> Dict[str, PageConfig]:
Expand All @@ -75,7 +78,7 @@ def get_page_config():


def get_short_page_config() -> Dict[str, ShortPageConfig]:
return parse_obj_as(Dict[str, ShortPageConfig], get_page_config())
return TypeAdapter(Dict[str, ShortPageConfig]).validate_python(get_page_config())


def get_public_config():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def upgrade() -> None:
op.create_table(
"user",
sa.Column("username", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("password_hash", sa.LargeBinary(), nullable=False),
sa.Column("password_salt", sa.LargeBinary(), nullable=False),
sa.PrimaryKeyConstraint("id"),
Expand All @@ -35,7 +35,7 @@ def upgrade() -> None:
"worker",
sa.Column("last_seen", sa.DateTime(timezone=True), nullable=True),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("token", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.PrimaryKeyConstraint("id"),
)
Expand All @@ -47,8 +47,8 @@ def upgrade() -> None:
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("changed_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("user_id", sa.Uuid, nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
Expand All @@ -61,8 +61,8 @@ def upgrade() -> None:
op.create_table(
"usertoken",
sa.Column("valid_until", sa.DateTime(timezone=True), nullable=False),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("user_id", sa.Uuid, nullable=False),
sa.Column("token_hash", sa.LargeBinary(), nullable=False),
sa.Column("token_salt", sa.LargeBinary(), nullable=False),
sa.ForeignKeyConstraint(
Expand All @@ -78,8 +78,8 @@ def upgrade() -> None:
"documentmediafile",
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("changed_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("document_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("document_id", sa.Uuid, nullable=False),
sa.Column("file", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("content_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.ForeignKeyConstraint(
Expand All @@ -96,8 +96,8 @@ def upgrade() -> None:
op.create_table(
"documentupdate",
sa.Column("change_bytes", sa.LargeBinary(), nullable=False),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("document_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("document_id", sa.Uuid, nullable=False),
sa.ForeignKeyConstraint(
["document_id"],
["document.id"],
Expand All @@ -111,10 +111,10 @@ def upgrade() -> None:
"task",
sa.Column("task_parameters", sa.JSON(), nullable=False),
sa.Column("task_type", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("document_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("document_id", sa.Uuid, nullable=False),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("progress", sa.Float(), nullable=True),
sa.Column("assigned_worker_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.Column("assigned_worker_id", sa.Uuid, nullable=True),
sa.Column("assigned_at", sa.DateTime(), nullable=True),
sa.Column("last_keepalive", sa.DateTime(), nullable=True),
sa.Column("is_completed", sa.Boolean(), nullable=False),
Expand All @@ -134,9 +134,9 @@ def upgrade() -> None:

op.create_table(
"documentmediatag",
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("tag", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("media_file_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("media_file_id", sa.Uuid, nullable=False),
sa.ForeignKeyConstraint(
["media_file_id"],
["documentmediafile.id"],
Expand All @@ -150,9 +150,9 @@ def upgrade() -> None:

op.create_table(
"taskdependency",
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("dependent_task_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("dependant_on_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("dependent_task_id", sa.Uuid, nullable=False),
sa.Column("dependant_on_id", sa.Uuid, nullable=False),
sa.ForeignKeyConstraint(
["dependant_on_id"],
["task.id"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ def upgrade_with_autocommit() -> None:
TaskAttempt = op.create_table(
"taskattempt",
sa.Column("extra_data", sa.JSON(), nullable=True),
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("task_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("assigned_worker_id", sqlmodel.sql.sqltypes.GUID(), nullable=True),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("task_id", sa.Uuid, nullable=False),
sa.Column("assigned_worker_id", sa.Uuid, nullable=True),
sa.Column("attempt_number", sa.Integer(), nullable=False),
sa.Column("started_at", sa.DateTime(), nullable=True),
sa.Column("last_keepalive", sa.DateTime(), nullable=True),
Expand All @@ -51,9 +51,7 @@ def upgrade_with_autocommit() -> None:
batch_op.create_index(batch_op.f("ix_taskattempt_id"), ["id"], unique=False)

with op.batch_alter_table("task", schema=None) as batch_op:
batch_op.add_column(
sa.Column("current_attempt_id", sqlmodel.sql.sqltypes.GUID(), nullable=True)
)
batch_op.add_column(sa.Column("current_attempt_id", sa.Uuid, nullable=True))
batch_op.add_column(
sa.Column(
"attempt_counter", sa.Integer(), nullable=True, server_default="0"
Expand Down Expand Up @@ -90,20 +88,20 @@ def upgrade_with_autocommit() -> None:

Task = sa.table(
"task",
sa.column("id", sqlmodel.sql.sqltypes.GUID()),
sa.column("id", sa.Uuid),
sa.column("assigned_at", sa.DateTime()),
sa.column("last_keepalive", sa.DateTime()),
sa.column("completed_at", sa.DateTime()),
sa.column("is_completed", sa.Boolean()),
sa.column("completion_data", sa.JSON()),
sa.column("assigned_worker_id", sqlmodel.sql.sqltypes.GUID()),
sa.column("assigned_worker_id", sa.Uuid),
sa.column("state_changed_at", sa.DateTime()),
sa.column(
"state",
sa.Enum("NEW", "ASSIGNED", "COMPLETED", "FAILED", name="taskstate"),
),
sa.column("remaining_attempts", sa.Integer()),
sa.column("current_attempt_id", sqlmodel.sql.sqltypes.GUID()),
sa.column("current_attempt_id", sa.Uuid),
sa.column("progress", sa.Float()),
sa.column("attempt_counter", sa.Integer()),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"apitoken",
sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
sa.Column("id", sa.Uuid, nullable=False),
sa.Column("name", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.Column("token", sqlmodel.sql.sqltypes.AutoString(), nullable=False),
sa.PrimaryKeyConstraint("id"),
Expand Down
Loading
Loading