From 39055c2c3521aa6928a8afebb2131e441ddae567 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Wed, 10 Jul 2024 16:12:26 +0200 Subject: [PATCH 01/24] refactor: Clean up DB sessions consistently --- .pre-commit-config.yaml | 1 - environment.yml | 2 +- init_db.py | 7 ++--- .../quetz_conda_suggest/main.py | 4 +-- .../tests/test_quetz_conda_suggest.py | 11 +++----- .../quetz_content_trust/api.py | 8 +++--- .../quetz_content_trust/main.py | 4 +-- .../quetz_repodata_patching/main.py | 4 +-- .../tests/test_main.py | 8 +++--- .../quetz_runexports/quetz_runexports/main.py | 4 +-- .../tests/test_quetz_runexports.py | 4 +-- pyproject.toml | 3 --- quetz/cli.py | 26 +++++++++---------- quetz/database.py | 22 +++++----------- quetz/deps.py | 13 +--------- quetz/testing/fixtures.py | 5 ++-- quetz/tests/test_dao.py | 9 ++++--- 17 files changed, 53 insertions(+), 82 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d364db7a..3a0c58db 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,7 +15,6 @@ repos: - id: mypy files: ^quetz/ additional_dependencies: - - sqlalchemy-stubs - types-click - types-Jinja2 - types-mock diff --git a/environment.yml b/environment.yml index 7c2e78c7..a7bb9189 100644 --- a/environment.yml +++ b/environment.yml @@ -9,7 +9,7 @@ dependencies: - authlib=0.15.5 - psycopg2 - httpx>=0.22.0 - - sqlalchemy + - sqlalchemy >= 2, <3 - sqlalchemy-utils - sqlite - python-multipart diff --git a/init_db.py b/init_db.py index e24dd8f6..b01c419f 100644 --- a/init_db.py +++ b/init_db.py @@ -21,11 +21,10 @@ def init_test_db(): config = Config() init_db(config.sqlalchemy_database_url) - db = get_session(config.sqlalchemy_database_url) - testUsers = [] + with get_session(config) as db: + testUsers = [] - try: for index, username in enumerate(["alice", "bob", "carol", "dave"]): user = User(id=uuid.uuid4().bytes, username=username) @@ -102,8 +101,6 @@ def init_test_db(): db.add(channel_member) db.commit() - finally: - db.close() if __name__ == "__main__": diff --git a/plugins/quetz_conda_suggest/quetz_conda_suggest/main.py b/plugins/quetz_conda_suggest/quetz_conda_suggest/main.py index 7f21c095..d6095f21 100644 --- a/plugins/quetz_conda_suggest/quetz_conda_suggest/main.py +++ b/plugins/quetz_conda_suggest/quetz_conda_suggest/main.py @@ -4,7 +4,7 @@ import quetz from quetz.config import Config -from quetz.database import get_db_manager +from quetz.database import get_session from quetz.db_models import PackageVersion from quetz.utils import add_entry_for_index @@ -45,7 +45,7 @@ def post_add_package_version(version, condainfo): if command not in suggest_map: suggest_map[command] = package - with get_db_manager() as db: + with get_session() as db: if not version.binfiles: metadata = db_models.CondaSuggestMetadata( version_id=version.id, data=json.dumps(suggest_map) diff --git a/plugins/quetz_conda_suggest/tests/test_quetz_conda_suggest.py b/plugins/quetz_conda_suggest/tests/test_quetz_conda_suggest.py index 89746340..d22c9388 100644 --- a/plugins/quetz_conda_suggest/tests/test_quetz_conda_suggest.py +++ b/plugins/quetz_conda_suggest/tests/test_quetz_conda_suggest.py @@ -2,7 +2,6 @@ import shutil import tarfile import tempfile -from contextlib import contextmanager from unittest import mock import pytest @@ -27,13 +26,12 @@ def test_post_add_package_version(package_version, db, config): target.seek(0) condainfo = CondaInfo(target, filename) - @contextmanager def get_db(): - yield db + return db from quetz_conda_suggest import main - with mock.patch("quetz_conda_suggest.main.get_db_manager", get_db): + with mock.patch("quetz_conda_suggest.main.get_session", get_db): main.post_add_package_version(package_version, condainfo) meta = db.query(db_models.CondaSuggestMetadata).first() @@ -50,7 +48,7 @@ def get_db(): b"lib/libtpkg.so\n", b"lib/pkgconfig/libtpkg.pc\n", ] - with mock.patch("quetz_conda_suggest.main.get_db_manager", get_db): + with mock.patch("quetz_conda_suggest.main.get_session", get_db): main.post_add_package_version(package_version, condainfo) meta = db.query(db_models.CondaSuggestMetadata).all() @@ -76,7 +74,6 @@ def test_conda_suggest_endpoint_with_upload( response = client.get("/api/dummylogin/madhurt") filename = "test-package-0.1-0.tar.bz2" - @contextmanager def get_db(): yield db @@ -114,7 +111,7 @@ def get_db(): tar.addfile(t, io.BytesIO(b)) tar.close() - with mock.patch("quetz_conda_suggest.main.get_db_manager", get_db): + with mock.patch("quetz_conda_suggest.main.get_session", get_db): url = f"/api/channels/{channel.name}/files/" files = {"files": (filename, open(filename, "rb"))} response = client.post(url, files=files) diff --git a/plugins/quetz_content_trust/quetz_content_trust/api.py b/plugins/quetz_content_trust/quetz_content_trust/api.py index 5f55dbef..14a05e0f 100644 --- a/plugins/quetz_content_trust/quetz_content_trust/api.py +++ b/plugins/quetz_content_trust/quetz_content_trust/api.py @@ -8,7 +8,7 @@ from quetz import authorization from quetz.config import Config -from quetz.database import get_db_manager +from quetz.database import get_session from quetz.deps import get_rules from . import db_models @@ -102,7 +102,7 @@ def post_role( ): auth.assert_channel_roles(channel, ["owner"]) - with get_db_manager() as db: + with get_session() as db: existing_role_count = ( db.query(db_models.ContentTrustRole) .filter( @@ -190,7 +190,7 @@ def get_role( ): auth.assert_channel_roles(channel, ["owner", "maintainer", "member"]) - with get_db_manager() as db: + with get_session() as db: query = ( db.query(db_models.ContentTrustRole) .filter(db_models.ContentTrustRole.channel == channel) @@ -211,7 +211,7 @@ def get_new_key(secret: bool = False): mamba_key = libmamba_api.Key.from_ed25519(key.public_key) private_key = key.private_key - with get_db_manager() as db: + with get_session() as db: db.add(key) db.commit() diff --git a/plugins/quetz_content_trust/quetz_content_trust/main.py b/plugins/quetz_content_trust/quetz_content_trust/main.py index 5f3723b9..f2f4a996 100644 --- a/plugins/quetz_content_trust/quetz_content_trust/main.py +++ b/plugins/quetz_content_trust/quetz_content_trust/main.py @@ -4,7 +4,7 @@ from sqlalchemy import desc import quetz -from quetz.database import get_db_manager +from quetz.database import get_session from . import db_models from .api import router @@ -21,7 +21,7 @@ def register_router(): def post_index_creation(raw_repodata: dict, channel_name, subdir): """Use available online keys to sign packages""" - with get_db_manager() as db: + with get_session() as db: query = ( db.query(db_models.SigningKey) .join(db_models.RoleDelegation.keys) diff --git a/plugins/quetz_repodata_patching/quetz_repodata_patching/main.py b/plugins/quetz_repodata_patching/quetz_repodata_patching/main.py index 4ed2e43f..57ecca1b 100644 --- a/plugins/quetz_repodata_patching/quetz_repodata_patching/main.py +++ b/plugins/quetz_repodata_patching/quetz_repodata_patching/main.py @@ -9,7 +9,7 @@ import quetz from quetz.config import Config -from quetz.database import get_db_manager +from quetz.database import get_session from quetz.db_models import PackageFormatEnum, PackageVersion from quetz.utils import add_temp_static_file @@ -107,7 +107,7 @@ def _load_instructions(tar, path): @quetz.hookimpl(tryfirst=True) def post_package_indexing(tempdir: Path, channel_name, subdirs, files, packages): - with get_db_manager() as db: + with get_session() as db: query = ( db.query(PackageVersion) .filter( diff --git a/plugins/quetz_repodata_patching/tests/test_main.py b/plugins/quetz_repodata_patching/tests/test_main.py index defb3261..dce0428d 100644 --- a/plugins/quetz_repodata_patching/tests/test_main.py +++ b/plugins/quetz_repodata_patching/tests/test_main.py @@ -300,7 +300,7 @@ def test_post_package_indexing( def get_db(): yield db - with mock.patch("quetz_repodata_patching.main.get_db_manager", get_db): + with mock.patch("quetz_repodata_patching.main.get_session", get_db): indexing.update_indexes(dao, pkgstore, channel_name) ext = "json.bz2" if compressed_repodata else "json" @@ -378,7 +378,7 @@ def test_index_html( def get_db(): yield db - with mock.patch("quetz_repodata_patching.main.get_db_manager", get_db): + with mock.patch("quetz_repodata_patching.main.get_session", get_db): indexing.update_indexes(dao, pkgstore, channel_name) index_path = os.path.join( @@ -419,7 +419,7 @@ def test_patches_for_subdir( def get_db(): yield db - with mock.patch("quetz_repodata_patching.main.get_db_manager", get_db): + with mock.patch("quetz_repodata_patching.main.get_session", get_db): indexing.update_indexes(dao, pkgstore, channel_name) index_path = os.path.join( @@ -471,7 +471,7 @@ def test_no_repodata_patches_package( def get_db(): yield db - with mock.patch("quetz_repodata_patching.main.get_db_manager", get_db): + with mock.patch("quetz_repodata_patching.main.get_session", get_db): indexing.update_indexes(dao, pkgstore, channel_name) index_path = os.path.join( diff --git a/plugins/quetz_runexports/quetz_runexports/main.py b/plugins/quetz_runexports/quetz_runexports/main.py index 4a580b81..ed72fa8e 100644 --- a/plugins/quetz_runexports/quetz_runexports/main.py +++ b/plugins/quetz_runexports/quetz_runexports/main.py @@ -1,7 +1,7 @@ import json import quetz -from quetz.database import get_db_manager +from quetz.database import get_session from . import db_models from .api import router @@ -16,7 +16,7 @@ def register_router(): def post_add_package_version(version, condainfo): run_exports = json.dumps(condainfo.run_exports) - with get_db_manager() as db: + with get_session() as db: if not version.runexports: metadata = db_models.PackageVersionMetadata( version_id=version.id, data=run_exports diff --git a/plugins/quetz_runexports/tests/test_quetz_runexports.py b/plugins/quetz_runexports/tests/test_quetz_runexports.py index 7a02d94f..178eb191 100644 --- a/plugins/quetz_runexports/tests/test_quetz_runexports.py +++ b/plugins/quetz_runexports/tests/test_quetz_runexports.py @@ -62,7 +62,7 @@ def get_db(): from quetz_runexports import main - with mock.patch("quetz_runexports.main.get_db_manager", get_db): + with mock.patch("quetz_runexports.main.get_session", get_db): main.post_add_package_version(package_version, condainfo) meta = db.query(db_models.PackageVersionMetadata).first() @@ -71,7 +71,7 @@ def get_db(): # modify runexport and re-save condainfo.run_exports = {"weak": ["somepackage < 0.3"]} - with mock.patch("quetz_runexports.main.get_db_manager", get_db): + with mock.patch("quetz_runexports.main.get_session", get_db): main.post_add_package_version(package_version, condainfo) meta = db.query(db_models.PackageVersionMetadata).all() diff --git a/pyproject.toml b/pyproject.toml index 2e64b5b7..dd674db8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,9 +59,6 @@ venvPath= "." [tool.mypy] ignore_missing_imports = true -plugins = [ - "sqlmypy" -] disable_error_code = [ "misc" ] diff --git a/quetz/cli.py b/quetz/cli.py index bff53c68..34e313cf 100644 --- a/quetz/cli.py +++ b/quetz/cli.py @@ -359,8 +359,8 @@ def add_user_roles( config = _get_config(path) with working_directory(path): - db = get_session(config.sqlalchemy_database_url) - _set_user_roles(db, config) + with get_session(config) as db: + _set_user_roles(db, config) @app.command() @@ -500,11 +500,11 @@ def create( deployment_folder.joinpath("channels").mkdir(exist_ok=True) with working_directory(db_path): - db = get_session(config.sqlalchemy_database_url) _run_migrations(config.sqlalchemy_database_url) - if dev: - _fill_test_database(db) - _set_user_roles(db, config) + with get_session(config) as db: + if dev: + _fill_test_database(db) + _set_user_roles(db, config) def _get_config(path: Union[Path, str]) -> Config: @@ -758,14 +758,12 @@ def start_supervisor_daemon(path: Path, num_procs=None): # is set there (it only matters for sqlite database). db_path = path if path.joinpath("config.toml").exists() else os.getcwd() with working_directory(db_path): - db = get_session(config.sqlalchemy_database_url) - supervisor = Supervisor(db, manager) - try: - supervisor.run() - except KeyboardInterrupt: - logger.info("stopping supervisor") - finally: - db.close() + with get_session(config) as db: + supervisor = Supervisor(db, manager) + try: + supervisor.run() + except KeyboardInterrupt: + logger.info("stopping supervisor") @app.command() diff --git a/quetz/database.py b/quetz/database.py index 9e77db65..1a9cb97b 100644 --- a/quetz/database.py +++ b/quetz/database.py @@ -2,7 +2,6 @@ # Distributed under the terms of the Modified BSD License. import logging import re -from contextlib import contextmanager from typing import Callable from sqlalchemy import create_engine, event @@ -66,19 +65,16 @@ def get_session_maker(engine) -> Callable[[], Session]: return sessionmaker(autocommit=False, autoflush=True, bind=engine) -def get_session(db_url: str, **kwargs) -> Session: +def get_session(config: Config | None) -> Session: """Get a database session. - - Important note: this function is mocked during tests! + ea + Important note: this function is mocked during tests! """ - return get_session_maker(get_engine(db_url, **kwargs))() - + if config is None: + config = Config() -@contextmanager -def get_db_manager(): - config = Config() - db = get_session( + engine = get_engine( db_url=config.sqlalchemy_database_url, echo=config.sqlalchemy_echo_sql, postgres_kwargs=dict( @@ -86,11 +82,7 @@ def get_db_manager(): max_overflow=config.sqlalchemy_postgres_max_overflow, ), ) - - try: - yield db - finally: - db.close() + return get_session_maker(engine)() def sanitize_db_url(db_url: str) -> str: diff --git a/quetz/deps.py b/quetz/deps.py index 3ef97665..495333fb 100644 --- a/quetz/deps.py +++ b/quetz/deps.py @@ -43,19 +43,8 @@ def get_config(): def get_db(config: Config = Depends(get_config)): - database_url = config.sqlalchemy_database_url - db = get_db_session( - database_url, - echo=config.sqlalchemy_echo_sql, - postgres_kwargs=dict( - pool_size=config.sqlalchemy_postgres_pool_size, - max_overflow=config.sqlalchemy_postgres_max_overflow, - ), - ) - try: + with get_db_session(config) as db: yield db - finally: - db.close() def get_dao(db: Session = Depends(get_db)): diff --git a/quetz/testing/fixtures.py b/quetz/testing/fixtures.py index 80671b39..32a4f4f5 100644 --- a/quetz/testing/fixtures.py +++ b/quetz/testing/fixtures.py @@ -1,7 +1,7 @@ import os import shutil import tempfile -from typing import List +from typing import List, Iterator import pytest from alembic.command import upgrade as alembic_upgrade @@ -13,6 +13,7 @@ from quetz.dao import Dao from quetz.database import get_engine, get_session_maker from quetz.db_models import Base +from sqlalchemy.orm import Session def pytest_configure(config): @@ -118,7 +119,7 @@ def auto_rollback(): @pytest.fixture -def session_maker(sql_connection, create_tables, auto_rollback): +def session_maker(sql_connection, create_tables, auto_rollback) -> Iterator[Session]: # run the tests with a separate external DB transaction # so that we can easily rollback all db changes (even if committed) # done by the test client diff --git a/quetz/tests/test_dao.py b/quetz/tests/test_dao.py index 0d1591ae..5ccaef7c 100644 --- a/quetz/tests/test_dao.py +++ b/quetz/tests/test_dao.py @@ -7,7 +7,7 @@ from quetz import errors, rest_models from quetz.dao import Dao -from quetz.database import get_session +from quetz.database import get_engine, get_session_maker from quetz.db_models import Channel, Package, PackageVersion from quetz.metrics.db_models import IntervalType, PackageVersionMetric, round_timestamp @@ -406,10 +406,11 @@ def db_extra(database_url): Use only for tests that require two sessions concurrently. For most cases you will want to use the db fixture (from quetz.testing.fixtures)""" - session = get_session(database_url) - + engine = get_engine( + db_url=database_url, + ) + session = get_session_maker(engine)() yield session - session.close() From 82c5adaa4dd38be2e5821f30ddb10368fa72cf12 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Wed, 10 Jul 2024 16:26:42 +0200 Subject: [PATCH 02/24] fix --- quetz/tasks/workers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quetz/tasks/workers.py b/quetz/tasks/workers.py index 865bb812..98dbacd3 100644 --- a/quetz/tasks/workers.py +++ b/quetz/tasks/workers.py @@ -148,7 +148,7 @@ def job_wrapper( db = dao.db close_session = False else: - db = get_session(config.sqlalchemy_database_url) + db = get_session(config) close_session = True user_id: Optional[str] From 40743fe2dfc01b72d2ef76fc7053a465a9f75e33 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Thu, 11 Jul 2024 11:29:22 +0200 Subject: [PATCH 03/24] wip --- quetz/database.py | 7 +++++-- quetz/testing/fixtures.py | 35 +++++++++++++++++++++++------------ quetz/tests/test_cli.py | 21 +++++++++++++-------- 3 files changed, 41 insertions(+), 22 deletions(-) diff --git a/quetz/database.py b/quetz/database.py index 1a9cb97b..27cc1c9e 100644 --- a/quetz/database.py +++ b/quetz/database.py @@ -4,6 +4,7 @@ import re from typing import Callable +import sqlalchemy from sqlalchemy import create_engine, event from sqlalchemy.engine import Engine from sqlalchemy.engine.url import make_url @@ -61,8 +62,10 @@ def on_close(dbapi_conn, conn_record): return engine -def get_session_maker(engine) -> Callable[[], Session]: - return sessionmaker(autocommit=False, autoflush=True, bind=engine) +def get_session_maker( + bind: sqlalchemy.engine.Engine | sqlalchemy.engine.Connection, +) -> Callable[[], Session]: + return sessionmaker(autocommit=False, autoflush=True, bind=bin) def get_session(config: Config | None) -> Session: diff --git a/quetz/testing/fixtures.py b/quetz/testing/fixtures.py index 32a4f4f5..c03852c1 100644 --- a/quetz/testing/fixtures.py +++ b/quetz/testing/fixtures.py @@ -1,9 +1,10 @@ import os import shutil import tempfile -from typing import List, Iterator +from typing import List, Iterator, Callable import pytest +import sqlalchemy.orm from alembic.command import upgrade as alembic_upgrade from fastapi.testclient import TestClient @@ -30,13 +31,13 @@ def pytest_unconfigure(config): @pytest.fixture -def sqlite_in_memory(): +def sqlite_in_memory() -> bool: """whether to create a sqlite DB in memory or on the filesystem.""" return True @pytest.fixture -def sqlite_url(sqlite_in_memory): +def sqlite_url(sqlite_in_memory: bool) -> str: if sqlite_in_memory: yield "sqlite:///:memory:" else: @@ -48,19 +49,19 @@ def sqlite_url(sqlite_in_memory): @pytest.fixture -def database_url(sqlite_url): +def database_url(sqlite_url: str) -> str: db_url = os.environ.get("QUETZ_TEST_DATABASE", sqlite_url) return db_url @pytest.fixture -def sql_echo(): +def sql_echo() -> bool: """whether to activate SQL echo during the tests or not.""" return False @pytest.fixture -def engine(database_url, sql_echo): +def engine(database_url: str, sql_echo: bool) -> sqlalchemy.Engine: sql_echo = bool(os.environ.get("QUETZ_TEST_ECHO_SQL", sql_echo)) engine = get_engine(database_url, echo=sql_echo, reuse_engine=False) yield engine @@ -119,7 +120,9 @@ def auto_rollback(): @pytest.fixture -def session_maker(sql_connection, create_tables, auto_rollback) -> Iterator[Session]: +def session_maker( + sql_connection: sqlalchemy.Connection, create_tables, auto_rollback: bool +) -> Iterator[sqlalchemy.orm.sessionmaker]: # run the tests with a separate external DB transaction # so that we can easily rollback all db changes (even if committed) # done by the test client @@ -142,14 +145,22 @@ def session_maker(sql_connection, create_tables, auto_rollback) -> Iterator[Sess @pytest.fixture -def expires_on_commit(): - return True +def session_maker_expire_on_commit( + session_maker, +) -> Callable[[], sqlalchemy.orm.sessionmaker]: + def maker(*args, **kwargs) -> sqlalchemy.orm.Session: + session = session_maker() + session.expire_on_commit = True + return session + + return maker @pytest.fixture -def db(session_maker, expires_on_commit): - session = session_maker() - session.expire_on_commit = expires_on_commit +def db( + session_maker_expire_on_commit: sqlalchemy.orm.sessionmaker, +) -> Iterator[sqlalchemy.orm.Session]: + session = session_maker_expire_on_commit() yield session session.close() diff --git a/quetz/tests/test_cli.py b/quetz/tests/test_cli.py index f6c7769b..a6f8a2bc 100644 --- a/quetz/tests/test_cli.py +++ b/quetz/tests/test_cli.py @@ -8,6 +8,7 @@ from unittest.mock import MagicMock import pytest +import sqlalchemy import sqlalchemy as sa from alembic.script import ScriptDirectory from pytest_mock.plugin import MockerFixture @@ -49,14 +50,14 @@ def user_with_identity(user, db): return identity -def get_user(db, config_dir, username="bartosz"): - def get_db(_): - return db - - with mock.patch("quetz.cli.get_session", get_db): +def get_user( + session_maker: sqlalchemy.orm.sessionmaker, config_dir, username="bartosz" +): + with mock.patch("quetz.cli.get_session", session_maker): cli.add_user_roles(config_dir) - return db.query(User).filter(User.username == username).one_or_none() + with session_maker() as db: + return db.query(User).filter(User.username == username).one_or_none() def test_init_db(db, config, config_dir, mocker): @@ -81,9 +82,13 @@ def test_create_user_from_config( @pytest.mark.parametrize("user_group", [None]) def test_set_user_roles_no_user( - db, config, config_dir, user_group, mocker: MockerFixture + session_maker_expire_on_commit, + config, + config_dir, + user_group, + mocker: MockerFixture, ): - user = get_user(db, config_dir) + user = get_user(session_maker_expire_on_commit, config_dir) assert user is None From 8718ffe547c8211666ec161142a074683a585aaa Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Thu, 11 Jul 2024 12:00:40 +0200 Subject: [PATCH 04/24] wip --- quetz/database.py | 4 ++-- quetz/testing/fixtures.py | 2 +- quetz/tests/test_cli.py | 30 +++++++++++++++++------------- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/quetz/database.py b/quetz/database.py index 27cc1c9e..9810b12d 100644 --- a/quetz/database.py +++ b/quetz/database.py @@ -64,8 +64,8 @@ def on_close(dbapi_conn, conn_record): def get_session_maker( bind: sqlalchemy.engine.Engine | sqlalchemy.engine.Connection, -) -> Callable[[], Session]: - return sessionmaker(autocommit=False, autoflush=True, bind=bin) +) -> Callable[[], sessionmaker]: + return sessionmaker(autocommit=False, autoflush=True, bind=bind) def get_session(config: Config | None) -> Session: diff --git a/quetz/testing/fixtures.py b/quetz/testing/fixtures.py index c03852c1..deba4dba 100644 --- a/quetz/testing/fixtures.py +++ b/quetz/testing/fixtures.py @@ -146,7 +146,7 @@ def session_maker( @pytest.fixture def session_maker_expire_on_commit( - session_maker, + session_maker: sqlalchemy.orm.sessionmaker, ) -> Callable[[], sqlalchemy.orm.sessionmaker]: def maker(*args, **kwargs) -> sqlalchemy.orm.Session: session = session_maker() diff --git a/quetz/tests/test_cli.py b/quetz/tests/test_cli.py index a6f8a2bc..76116371 100644 --- a/quetz/tests/test_cli.py +++ b/quetz/tests/test_cli.py @@ -4,6 +4,7 @@ import tempfile from multiprocessing import Process from pathlib import Path +from typing import Callable from unittest import mock from unittest.mock import MagicMock @@ -71,9 +72,9 @@ def test_init_db(db, config, config_dir, mocker): [("admins", "owner"), ("maintainers", "maintainer"), ("members", "member")], ) def test_create_user_from_config( - db, config, config_dir, user_group, expected_role, mocker, user_with_identity + session_maker_expire_on_commit, config, config_dir, user_group, expected_role, mocker, user_with_identity ): - user = get_user(db, config_dir) + user = get_user(session_maker_expire_on_commit, config_dir) assert user assert user.role == expected_role @@ -94,9 +95,9 @@ def test_set_user_roles_no_user( def test_set_user_roles_user_exists( - db, config, config_dir, user, mocker, user_with_identity + session_maker_expire_on_commit, config, config_dir, user, mocker, user_with_identity ): - user = get_user(db, config_dir) + user = get_user(session_maker_expire_on_commit, config_dir) assert user assert user.role == "owner" @@ -106,13 +107,17 @@ def test_set_user_roles_user_exists( @pytest.mark.parametrize("default_role", [None, "member"]) @pytest.mark.parametrize("current_role", ["owner", "member", "maintainer"]) def test_set_user_roles_user_has_role( - db, config, config_dir, user, mocker, user_with_identity, current_role, default_role + session_maker_expire_on_commit: sqlalchemy.orm.sessionmaker, config: Config, config_dir: str, user: User, mocker, user_with_identity: Identity, current_role: str, default_role: str | None ): - user.role = current_role - db.commit() - user = get_user(db, config_dir) + + with session_maker_expire_on_commit() as db: + user.role = current_role + db.commit() + + user = get_user(session_maker_expire_on_commit, config_dir) assert user + # TODO: I do not understand this test. Why is default_role parametrized? # role shouldn't be changed unless it's default role if current_role != default_role: assert user.role == current_role @@ -122,20 +127,19 @@ def test_set_user_roles_user_has_role( @pytest.mark.parametrize("config_extra", ['[users]\nadmins = ["dummy:alice"]\n']) -def test_init_db_create_test_users(db, config, mocker, config_dir): +def test_init_db_create_test_users(session_maker_expire_on_commit: Callable[[], sqlalchemy.orm.Session], config, mocker, config_dir): _run_migrations: MagicMock = mocker.patch("quetz.cli._run_migrations") - def get_db(_): - return db - with mock.patch("quetz.cli.get_session", get_db): + with mock.patch("quetz.cli.get_session", session_maker_expire_on_commit): cli.create( Path(config_dir) / "new-deployment", copy_conf="config.toml", dev=True, ) - user = db.query(User).filter(User.username == "alice").one_or_none() + with session_maker_expire_on_commit() as db: + user = db.query(User).filter(User.username == "alice").one_or_none() assert user.role == "owner" From bad5a881e473fecae6da26480a4178ea265fcc64 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Fri, 12 Jul 2024 11:36:25 +0200 Subject: [PATCH 05/24] wip --- quetz/tests/test_cli.py | 35 ++++++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 9 deletions(-) diff --git a/quetz/tests/test_cli.py b/quetz/tests/test_cli.py index 76116371..106942c2 100644 --- a/quetz/tests/test_cli.py +++ b/quetz/tests/test_cli.py @@ -72,7 +72,13 @@ def test_init_db(db, config, config_dir, mocker): [("admins", "owner"), ("maintainers", "maintainer"), ("members", "member")], ) def test_create_user_from_config( - session_maker_expire_on_commit, config, config_dir, user_group, expected_role, mocker, user_with_identity + session_maker_expire_on_commit, + config, + config_dir, + user_group, + expected_role, + mocker, + user_with_identity, ): user = get_user(session_maker_expire_on_commit, config_dir) assert user @@ -107,30 +113,41 @@ def test_set_user_roles_user_exists( @pytest.mark.parametrize("default_role", [None, "member"]) @pytest.mark.parametrize("current_role", ["owner", "member", "maintainer"]) def test_set_user_roles_user_has_role( - session_maker_expire_on_commit: sqlalchemy.orm.sessionmaker, config: Config, config_dir: str, user: User, mocker, user_with_identity: Identity, current_role: str, default_role: str | None + session_maker_expire_on_commit: sqlalchemy.orm.sessionmaker, + config: Config, + config_dir: str, + user: User, + mocker, + user_with_identity: Identity, + current_role: str, + default_role: str | None, ): - + # Arrange: The user has `current_role` before we call the CLI with session_maker_expire_on_commit() as db: user.role = current_role db.commit() + # Act: Call the CLI user = get_user(session_maker_expire_on_commit, config_dir) assert user - # TODO: I do not understand this test. Why is default_role parametrized? - # role shouldn't be changed unless it's default role + # Assert: role shouldn't be changed unless it's default role if current_role != default_role: - assert user.role == current_role + assert current_role == user.role else: - assert user.role == "owner" + assert "owner" == user.role assert user.username == "bartosz" @pytest.mark.parametrize("config_extra", ['[users]\nadmins = ["dummy:alice"]\n']) -def test_init_db_create_test_users(session_maker_expire_on_commit: Callable[[], sqlalchemy.orm.Session], config, mocker, config_dir): +def test_init_db_create_test_users( + session_maker_expire_on_commit: Callable[[], sqlalchemy.orm.Session], + config, + mocker, + config_dir, +): _run_migrations: MagicMock = mocker.patch("quetz.cli._run_migrations") - with mock.patch("quetz.cli.get_session", session_maker_expire_on_commit): cli.create( Path(config_dir) / "new-deployment", From b3b8d28be0e519d4bb322f3036fb0bf45e1d8e29 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Fri, 12 Jul 2024 11:45:18 +0200 Subject: [PATCH 06/24] wip? --- .../tests/test_quetz_conda_suggest.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/plugins/quetz_conda_suggest/tests/test_quetz_conda_suggest.py b/plugins/quetz_conda_suggest/tests/test_quetz_conda_suggest.py index d22c9388..1e2e49c7 100644 --- a/plugins/quetz_conda_suggest/tests/test_quetz_conda_suggest.py +++ b/plugins/quetz_conda_suggest/tests/test_quetz_conda_suggest.py @@ -17,7 +17,9 @@ def test_conda_suggest_endpoint_without_upload(client, channel, subdir): assert response.json() == None # noqa: E711 -def test_post_add_package_version(package_version, db, config): +def test_post_add_package_version( + package_version, db, config, session_maker_expire_on_commit +): filename = "test-package-0.1-0.tar.bz2" with tempfile.SpooledTemporaryFile(mode="wb") as target: @@ -26,12 +28,11 @@ def test_post_add_package_version(package_version, db, config): target.seek(0) condainfo = CondaInfo(target, filename) - def get_db(): - return db - from quetz_conda_suggest import main - with mock.patch("quetz_conda_suggest.main.get_session", get_db): + with mock.patch( + "quetz_conda_suggest.main.get_session", session_maker_expire_on_commit + ): main.post_add_package_version(package_version, condainfo) meta = db.query(db_models.CondaSuggestMetadata).first() @@ -48,7 +49,9 @@ def get_db(): b"lib/libtpkg.so\n", b"lib/pkgconfig/libtpkg.pc\n", ] - with mock.patch("quetz_conda_suggest.main.get_session", get_db): + with mock.patch( + "quetz_conda_suggest.main.get_session", session_maker_expire_on_commit + ): main.post_add_package_version(package_version, condainfo) meta = db.query(db_models.CondaSuggestMetadata).all() From b3a120f537be1a91e6d3d2bbdfd90e2597ce39b9 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Fri, 12 Jul 2024 15:29:51 +0200 Subject: [PATCH 07/24] fix --- quetz/tests/test_cli.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/quetz/tests/test_cli.py b/quetz/tests/test_cli.py index 106942c2..fe8f8e3f 100644 --- a/quetz/tests/test_cli.py +++ b/quetz/tests/test_cli.py @@ -122,9 +122,11 @@ def test_set_user_roles_user_has_role( current_role: str, default_role: str | None, ): - # Arrange: The user has `current_role` before we call the CLI + # Arrange: Assign `current_role` to the user before we call the CLI with session_maker_expire_on_commit() as db: + user = db.query(User).filter(User.username == "bartosz").one_or_none() user.role = current_role + assert user.role == current_role db.commit() # Act: Call the CLI From 2da1136075b924e90dbd0463b536b55ebb1c85f3 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Fri, 12 Jul 2024 16:34:30 +0200 Subject: [PATCH 08/24] fix --- .../tests/test_quetz_conda_suggest.py | 12 +++------ plugins/quetz_runexports/tests/conftest.py | 21 +++++++++------ .../tests/test_quetz_runexports.py | 27 ++++++++++++------- quetz/testing/fixtures.py | 24 ++++------------- quetz/tests/test_cli.py | 24 ++++++++--------- quetz/tests/test_dao.py | 3 ++- 6 files changed, 53 insertions(+), 58 deletions(-) diff --git a/plugins/quetz_conda_suggest/tests/test_quetz_conda_suggest.py b/plugins/quetz_conda_suggest/tests/test_quetz_conda_suggest.py index 1e2e49c7..a780820e 100644 --- a/plugins/quetz_conda_suggest/tests/test_quetz_conda_suggest.py +++ b/plugins/quetz_conda_suggest/tests/test_quetz_conda_suggest.py @@ -17,9 +17,7 @@ def test_conda_suggest_endpoint_without_upload(client, channel, subdir): assert response.json() == None # noqa: E711 -def test_post_add_package_version( - package_version, db, config, session_maker_expire_on_commit -): +def test_post_add_package_version(package_version, db, config, session_maker): filename = "test-package-0.1-0.tar.bz2" with tempfile.SpooledTemporaryFile(mode="wb") as target: @@ -30,9 +28,7 @@ def test_post_add_package_version( from quetz_conda_suggest import main - with mock.patch( - "quetz_conda_suggest.main.get_session", session_maker_expire_on_commit - ): + with mock.patch("quetz_conda_suggest.main.get_session", session_maker): main.post_add_package_version(package_version, condainfo) meta = db.query(db_models.CondaSuggestMetadata).first() @@ -49,9 +45,7 @@ def test_post_add_package_version( b"lib/libtpkg.so\n", b"lib/pkgconfig/libtpkg.pc\n", ] - with mock.patch( - "quetz_conda_suggest.main.get_session", session_maker_expire_on_commit - ): + with mock.patch("quetz_conda_suggest.main.get_session", session_maker): main.post_add_package_version(package_version, condainfo) meta = db.query(db_models.CondaSuggestMetadata).all() diff --git a/plugins/quetz_runexports/tests/conftest.py b/plugins/quetz_runexports/tests/conftest.py index 1cfd3823..9398de0f 100644 --- a/plugins/quetz_runexports/tests/conftest.py +++ b/plugins/quetz_runexports/tests/conftest.py @@ -2,11 +2,12 @@ import uuid from pytest import fixture -from quetz_runexports import db_models +from sqlalchemy.orm import Session +from plugins.quetz_runexports.quetz_runexports.db_models import PackageVersionMetadata from quetz import rest_models from quetz.dao import Dao -from quetz.db_models import User +from quetz.db_models import User, PackageVersion, Package, Channel pytest_plugins = "quetz.testing.fixtures" @@ -17,7 +18,7 @@ def dao(db) -> Dao: @fixture -def user(db): +def user(db: Session) -> User: user = User(id=uuid.uuid4().bytes, username="bartosz") db.add(user) db.commit() @@ -25,7 +26,7 @@ def user(db): @fixture -def channel(dao, user, db): +def channel(dao: Dao, user: User, db: Session): channel_data = rest_models.Channel( name="test-mirror-channel", private=False, @@ -42,7 +43,7 @@ def channel(dao, user, db): @fixture -def package(dao, user, channel, db): +def package(dao: Dao, user: User, channel: Channel, db: Session) -> Package: new_package_data = rest_models.Package(name="test-package") package = dao.create_package( @@ -59,7 +60,9 @@ def package(dao, user, channel, db): @fixture -def package_version(user, channel, db, dao, package): +def package_version( + dao: Dao, user: User, channel: Channel, db: Session, package: Package +) -> PackageVersion: # create package version that will added to local repodata package_format = "tarbz2" package_info = '{"size": 5000, "subdirs":["noarch"]}' @@ -85,8 +88,10 @@ def package_version(user, channel, db, dao, package): @fixture -def package_runexports(package_version, db): - meta = db_models.PackageVersionMetadata( +def package_runexports( + db: Session, package_version: PackageVersion +) -> PackageVersionMetadata: + meta = PackageVersionMetadata( version_id=package_version.id, data=json.dumps({"weak": ["somepackage > 3.0"]}), ) diff --git a/plugins/quetz_runexports/tests/test_quetz_runexports.py b/plugins/quetz_runexports/tests/test_quetz_runexports.py index 178eb191..c4ea6501 100644 --- a/plugins/quetz_runexports/tests/test_quetz_runexports.py +++ b/plugins/quetz_runexports/tests/test_quetz_runexports.py @@ -5,8 +5,13 @@ import pytest from quetz_runexports import db_models +from sqlalchemy.orm import Session +from starlette.testclient import TestClient +from plugins.quetz_runexports.quetz_runexports.db_models import PackageVersionMetadata from quetz.condainfo import CondaInfo +from quetz.config import Config +from quetz.db_models import PackageVersion, Package, Channel pytest_plugins = "quetz.testing.fixtures" @@ -17,13 +22,12 @@ def plugins(): def test_run_exports_endpoint( - client, - channel, - package, - package_version, - package_runexports, - db, - session_maker, + client: TestClient, + channel: Channel, + package: Package, + package_version: PackageVersion, + package_runexports: PackageVersionMetadata, + db: Session, ): filename = package_version.filename platform = package_version.platform @@ -36,7 +40,10 @@ def test_run_exports_endpoint( def test_endpoint_without_metadata( - client, channel, package, package_version, db, session_maker + client: TestClient, + channel: Channel, + package: Package, + package_version: PackageVersion, ): filename = package_version.filename platform = package_version.platform @@ -47,7 +54,9 @@ def test_endpoint_without_metadata( assert response.status_code == 404 -def test_post_add_package_version(package_version, config, db, session_maker): +def test_post_add_package_version( + package_version: PackageVersion, config: Config, db: Session +): filename = "test-package-0.1-0.tar.bz2" with tempfile.SpooledTemporaryFile(mode="wb") as target: diff --git a/quetz/testing/fixtures.py b/quetz/testing/fixtures.py index deba4dba..0f042d7b 100644 --- a/quetz/testing/fixtures.py +++ b/quetz/testing/fixtures.py @@ -1,7 +1,7 @@ import os import shutil import tempfile -from typing import List, Iterator, Callable +from typing import List, Iterator import pytest import sqlalchemy.orm @@ -14,7 +14,6 @@ from quetz.dao import Dao from quetz.database import get_engine, get_session_maker from quetz.db_models import Base -from sqlalchemy.orm import Session def pytest_configure(config): @@ -144,25 +143,12 @@ def session_maker( trans.rollback() -@pytest.fixture -def session_maker_expire_on_commit( - session_maker: sqlalchemy.orm.sessionmaker, -) -> Callable[[], sqlalchemy.orm.sessionmaker]: - def maker(*args, **kwargs) -> sqlalchemy.orm.Session: - session = session_maker() - session.expire_on_commit = True - return session - - return maker - - @pytest.fixture def db( - session_maker_expire_on_commit: sqlalchemy.orm.sessionmaker, + session_maker: sqlalchemy.orm.sessionmaker, ) -> Iterator[sqlalchemy.orm.Session]: - session = session_maker_expire_on_commit() - yield session - session.close() + with session_maker() as db: + yield db @pytest.fixture @@ -282,7 +268,7 @@ def root_endpoint(): @pytest.fixture -def client(app): +def client(app) -> TestClient: client = TestClient(app) return client diff --git a/quetz/tests/test_cli.py b/quetz/tests/test_cli.py index fe8f8e3f..7d678112 100644 --- a/quetz/tests/test_cli.py +++ b/quetz/tests/test_cli.py @@ -72,7 +72,7 @@ def test_init_db(db, config, config_dir, mocker): [("admins", "owner"), ("maintainers", "maintainer"), ("members", "member")], ) def test_create_user_from_config( - session_maker_expire_on_commit, + session_maker, config, config_dir, user_group, @@ -80,7 +80,7 @@ def test_create_user_from_config( mocker, user_with_identity, ): - user = get_user(session_maker_expire_on_commit, config_dir) + user = get_user(session_maker, config_dir) assert user assert user.role == expected_role @@ -89,21 +89,21 @@ def test_create_user_from_config( @pytest.mark.parametrize("user_group", [None]) def test_set_user_roles_no_user( - session_maker_expire_on_commit, + session_maker, config, config_dir, user_group, mocker: MockerFixture, ): - user = get_user(session_maker_expire_on_commit, config_dir) + user = get_user(session_maker, config_dir) assert user is None def test_set_user_roles_user_exists( - session_maker_expire_on_commit, config, config_dir, user, mocker, user_with_identity + session_maker, config, config_dir, user, mocker, user_with_identity ): - user = get_user(session_maker_expire_on_commit, config_dir) + user = get_user(session_maker, config_dir) assert user assert user.role == "owner" @@ -113,7 +113,7 @@ def test_set_user_roles_user_exists( @pytest.mark.parametrize("default_role", [None, "member"]) @pytest.mark.parametrize("current_role", ["owner", "member", "maintainer"]) def test_set_user_roles_user_has_role( - session_maker_expire_on_commit: sqlalchemy.orm.sessionmaker, + session_maker: sqlalchemy.orm.sessionmaker, config: Config, config_dir: str, user: User, @@ -123,14 +123,14 @@ def test_set_user_roles_user_has_role( default_role: str | None, ): # Arrange: Assign `current_role` to the user before we call the CLI - with session_maker_expire_on_commit() as db: + with session_maker() as db: user = db.query(User).filter(User.username == "bartosz").one_or_none() user.role = current_role assert user.role == current_role db.commit() # Act: Call the CLI - user = get_user(session_maker_expire_on_commit, config_dir) + user = get_user(session_maker, config_dir) assert user # Assert: role shouldn't be changed unless it's default role @@ -143,21 +143,21 @@ def test_set_user_roles_user_has_role( @pytest.mark.parametrize("config_extra", ['[users]\nadmins = ["dummy:alice"]\n']) def test_init_db_create_test_users( - session_maker_expire_on_commit: Callable[[], sqlalchemy.orm.Session], + session_maker: Callable[[], sqlalchemy.orm.Session], config, mocker, config_dir, ): _run_migrations: MagicMock = mocker.patch("quetz.cli._run_migrations") - with mock.patch("quetz.cli.get_session", session_maker_expire_on_commit): + with mock.patch("quetz.cli.get_session", session_maker): cli.create( Path(config_dir) / "new-deployment", copy_conf="config.toml", dev=True, ) - with session_maker_expire_on_commit() as db: + with session_maker() as db: user = db.query(User).filter(User.username == "alice").one_or_none() assert user.role == "owner" diff --git a/quetz/tests/test_dao.py b/quetz/tests/test_dao.py index 5ccaef7c..31f57a75 100644 --- a/quetz/tests/test_dao.py +++ b/quetz/tests/test_dao.py @@ -3,6 +3,7 @@ import pytest from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session from sqlalchemy.orm.exc import ObjectDeletedError from quetz import errors, rest_models @@ -170,7 +171,7 @@ def test_update_channel_size(dao, channel, db, package_version): def test_increment_download_count( - dao: Dao, channel, db, package_version, session_maker + dao: Dao, channel: Channel, db: Session, package_version: PackageVersion ): assert package_version.download_count == 0 now = datetime.datetime(2020, 10, 1, 10, 1, 10) From 72f08a107b39160eb1a5fe8dd9722f63cc8df3e5 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Fri, 12 Jul 2024 16:43:28 +0200 Subject: [PATCH 09/24] fix --- quetz/database.py | 4 +- quetz/tasks/workers.py | 136 ++++++++++++++++-------------------- quetz/tests/test_workers.py | 10 +-- 3 files changed, 69 insertions(+), 81 deletions(-) diff --git a/quetz/database.py b/quetz/database.py index 9810b12d..48e05487 100644 --- a/quetz/database.py +++ b/quetz/database.py @@ -70,8 +70,8 @@ def get_session_maker( def get_session(config: Config | None) -> Session: """Get a database session. - ea - Important note: this function is mocked during tests! + + Important note: this function is mocked during tests! """ if config is None: diff --git a/quetz/tasks/workers.py b/quetz/tasks/workers.py index 98dbacd3..2b98b61a 100644 --- a/quetz/tasks/workers.py +++ b/quetz/tasks/workers.py @@ -137,87 +137,75 @@ def job_wrapper( logger = logging.getLogger("quetz.worker") pkgstore = kwargs.pop("pkgstore", None) - db = kwargs.pop("db", None) - dao = kwargs.pop("dao", None) auth = kwargs.pop("auth", None) session = kwargs.pop("session", None) - if db: - close_session = False - elif dao: - db = dao.db - close_session = False - else: - db = get_session(config) - close_session = True - - user_id: Optional[str] - if task_id: - task = db.query(Task).filter(Task.id == task_id).one_or_none() - if not task: - raise KeyError(f"Task '{task_id}' not found") - # take extra arguments from job definition - if task.job.extra_args: - job_extra_args = json.loads(task.job.extra_args) - kwargs.update(job_extra_args) - if task.job.owner_id: - user_id = str(uuid.UUID(bytes=task.job.owner_id)) + with get_session(config) as db: + user_id: Optional[str] + if task_id: + task = db.query(Task).filter(Task.id == task_id).one_or_none() + if not task: + raise KeyError(f"Task '{task_id}' not found") + # take extra arguments from job definition + if task.job.extra_args: + job_extra_args = json.loads(task.job.extra_args) + kwargs.update(job_extra_args) + if task.job.owner_id: + user_id = str(uuid.UUID(bytes=task.job.owner_id)) + else: + user_id = None else: + task = None user_id = None - else: - task = None - user_id = None - - if not pkgstore: - pkgstore = config.get_package_store() - - dao = Dao(db) - - if not auth: - browser_session: Dict[str, str] = {} - api_key = None - if user_id: - browser_session["user_id"] = user_id - auth = Rules(api_key, browser_session, db) - if not session: - session = get_remote_session() - - if task: - task.status = TaskStatus.running - task.job.status = JobStatus.running - db.commit() - - callable_f: Callable = pickle.loads(func) if isinstance(func, bytes) else func - - extra_kwargs = prepare_arguments( - callable_f, - dao=dao, - auth=auth, - session=session, - config=config, - pkgstore=pkgstore, - user_id=user_id, - ) - - kwargs.update(extra_kwargs) - - try: - callable_f(**kwargs) - except Exception as exc: + + if not pkgstore: + pkgstore = config.get_package_store() + + dao = Dao(db) + + if not auth: + browser_session: Dict[str, str] = {} + api_key = None + if user_id: + browser_session["user_id"] = user_id + auth = Rules(api_key, browser_session, db) + if not session: + session = get_remote_session() + if task: - task.status = TaskStatus.failed - logger.error( - f"exception occurred when evaluating function {callable_f.__name__}:{exc}" + task.status = TaskStatus.running + task.job.status = JobStatus.running + db.commit() + + callable_f: Callable = pickle.loads(func) if isinstance(func, bytes) else func + + extra_kwargs = prepare_arguments( + callable_f, + dao=dao, + auth=auth, + session=session, + config=config, + pkgstore=pkgstore, + user_id=user_id, ) - if exc_passthrou: - raise exc - else: - if task: - task.status = TaskStatus.success - finally: - db.commit() - if close_session: - db.close() + + kwargs.update(extra_kwargs) + + try: + callable_f(**kwargs) + except Exception as exc: + if task: + task.status = TaskStatus.failed + logger.error( + f"exception occurred when evaluating function {callable_f.__name__}:{exc}" + ) + if exc_passthrou: + raise exc + else: + if task: + task.status = TaskStatus.success + finally: + db.commit() class AbstractWorker: diff --git a/quetz/tests/test_workers.py b/quetz/tests/test_workers.py index c109cbb2..b8a1353f 100644 --- a/quetz/tests/test_workers.py +++ b/quetz/tests/test_workers.py @@ -132,11 +132,11 @@ def db_cleanup(config): from quetz.database import get_session - db = get_session(config.sqlalchemy_database_url) - user = db.query(User).one_or_none() - if user: - db.delete(user) - db.commit() + with get_session(config) as db: + user = db.query(User).one_or_none() + if user: + db.delete(user) + db.commit() @pytest.mark.asyncio From 1e3fc11a8a3ec561ced56ea3bc4cf89227f28a95 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Fri, 12 Jul 2024 16:58:22 +0200 Subject: [PATCH 10/24] fix --- .../tests/test_main.py | 193 +++++++++--------- 1 file changed, 97 insertions(+), 96 deletions(-) diff --git a/plugins/quetz_repodata_patching/tests/test_main.py b/plugins/quetz_repodata_patching/tests/test_main.py index dce0428d..8be5b873 100644 --- a/plugins/quetz_repodata_patching/tests/test_main.py +++ b/plugins/quetz_repodata_patching/tests/test_main.py @@ -4,21 +4,24 @@ import tarfile import time import uuid -from contextlib import contextmanager +from io import BytesIO from unittest import mock from zipfile import ZipFile import pytest import zstandard +from sqlalchemy.orm import sessionmaker, Session -import quetz -from quetz.db_models import Package, Profile, User +from quetz.config import Config +from quetz.dao import Dao +from quetz.db_models import Package, Profile, User, PackageVersion +from quetz.pkgstores import PackageStore from quetz.rest_models import Channel from quetz.tasks import indexing @pytest.fixture -def user(db): +def user(db: Session) -> User: user = User(id=uuid.uuid4().bytes, username="bartosz") profile = Profile(name="Bartosz", avatar_url="http:///avatar", user=user) db.add(user) @@ -28,22 +31,22 @@ def user(db): @pytest.fixture -def channel_name(): +def channel_name() -> str: return "my-channel" @pytest.fixture -def package_name(): +def package_name() -> str: return "mytestpackage" @pytest.fixture -def package_format(): +def package_format() -> str: return "tarbz2" @pytest.fixture -def package_file_name(package_name, package_format): +def package_file_name(package_name: str, package_format: str) -> str: if package_format == "tarbz2": return f"{package_name}-0.1-0.tar.bz2" elif package_format == "conda": @@ -51,27 +54,27 @@ def package_file_name(package_name, package_format): @pytest.fixture -def channel(dao: "quetz.dao.Dao", channel_name, user): +def channel(dao: Dao, channel_name: str, user: User) -> Channel: channel_data = Channel(name=channel_name, private=False) channel = dao.create_channel(channel_data, user.id, "owner") return channel @pytest.fixture -def package_subdir(): +def package_subdir() -> str: return "noarch" @pytest.fixture def package_version( - dao: "quetz.dao.Dao", - user, - channel, - package_name, - db, - package_file_name, - package_format, - package_subdir, + dao: Dao, + user: User, + channel: Channel, + package_name: str, + db: Session, + package_file_name: str, + package_format: str, + package_subdir: str, ): channel_data = json.dumps({"subdirs": [package_subdir]}) package_data = Package(name=package_name) @@ -101,13 +104,13 @@ def package_version( @pytest.fixture -def repodata_name(channel): +def repodata_name(channel: Channel) -> str: package_name = f"{channel.name}-repodata-patches" return package_name @pytest.fixture -def repodata_file_name(repodata_name, archive_format): +def repodata_file_name(repodata_name: str, archive_format: str) -> str: version = "0.1" build_str = "0" ext = "tar.bz2" if archive_format == "tarbz2" else "conda" @@ -115,17 +118,17 @@ def repodata_file_name(repodata_name, archive_format): @pytest.fixture -def revoke_instructions(): +def revoke_instructions() -> list[str]: return [] @pytest.fixture -def remove_instructions(): +def remove_instructions() -> list[str]: return [] @pytest.fixture -def patched_package_name(package_file_name): +def patched_package_name(package_file_name: str) -> str: "name of the package in patch_instructions" # by default the name of the package in patch_instructions is the same # as the name of the dummy package @@ -135,7 +138,11 @@ def patched_package_name(package_file_name): @pytest.fixture -def patch_content(patched_package_name, revoke_instructions, remove_instructions): +def patch_content( + patched_package_name: str, + revoke_instructions: list[str], + remove_instructions: list[str], +) -> dict: d = {} package_file_name = patched_package_name @@ -154,19 +161,22 @@ def patch_content(patched_package_name, revoke_instructions, remove_instructions @pytest.fixture -def archive_format(): +def archive_format() -> str: return "tarbz2" @pytest.fixture() -def patches_subdir(): +def patches_subdir() -> str: return "noarch" @pytest.fixture -def repodata_archive(repodata_file_name, patch_content, archive_format, patches_subdir): - from io import BytesIO - +def repodata_archive( + repodata_file_name: str, + patch_content: dict, + archive_format: str, + patches_subdir: str, +) -> BytesIO: patch_instructions = json.dumps(patch_content).encode("ascii") def mk_tarfile(patch_instructions, compr=None): @@ -208,16 +218,20 @@ def mk_tarfile(patch_instructions, compr=None): @pytest.fixture def package_repodata_patches( - dao: "quetz.dao.Dao", - user, - channel, - db, - pkgstore, - repodata_name, - repodata_file_name, - repodata_archive, - archive_format, -): + session_maker: sessionmaker, + pkgstore: PackageStore, + package_version: PackageVersion, + channel_name: str, + package_file_name: str, + dao: Dao, + db: Session, + user: User, + channel: Channel, + repodata_name: str, + repodata_file_name: str, + repodata_archive: BytesIO, + archive_format: str, +) -> PackageVersion: package_name = repodata_name package_data = Package(name=package_name) @@ -244,7 +258,7 @@ def package_repodata_patches( @pytest.fixture -def pkgstore(config): +def pkgstore(config: Config) -> PackageStore: pkgstore = config.get_package_store() return pkgstore @@ -282,25 +296,22 @@ def pkgstore(config): ], ) def test_post_package_indexing( - pkgstore, - dao, - package_version, - channel_name, - package_repodata_patches, - db, - package_file_name, - repodata_stem, - compressed_repodata, - revoke_instructions, - remove_instructions, - package_format, - patched_package_name, + session_maker: sessionmaker, + pkgstore: PackageStore, + package_version: PackageVersion, + package_repodata_patches: PackageVersion, + channel_name: str, + package_file_name: str, + dao: Dao, + db: Session, + repodata_stem: str, + compressed_repodata: bool, + revoke_instructions: list[str], + remove_instructions: list[str], + package_format: str, + patched_package_name: str, ): - @contextmanager - def get_db(): - yield db - - with mock.patch("quetz_repodata_patching.main.get_session", get_db): + with mock.patch("quetz_repodata_patching.main.get_session", session_maker): indexing.update_indexes(dao, pkgstore, channel_name) ext = "json.bz2" if compressed_repodata else "json" @@ -365,20 +376,17 @@ def get_db(): ], ) def test_index_html( - pkgstore, - package_version, - package_repodata_patches, - channel_name, - package_file_name, - dao, - db, - remove_instructions, + session_maker: sessionmaker, + pkgstore: PackageStore, + package_version: PackageVersion, + package_repodata_patches: PackageVersion, + channel_name: str, + package_file_name: str, + dao: Dao, + db: Session, + remove_instructions: list[str], ): - @contextmanager - def get_db(): - yield db - - with mock.patch("quetz_repodata_patching.main.get_session", get_db): + with mock.patch("quetz_repodata_patching.main.get_session", session_maker): indexing.update_indexes(dao, pkgstore, channel_name) index_path = os.path.join( @@ -405,21 +413,18 @@ def get_db(): @pytest.mark.parametrize("package_subdir", ["linux-64", "noarch"]) @pytest.mark.parametrize("patches_subdir", ["linux-64", "noarch"]) def test_patches_for_subdir( - pkgstore, - package_version, - channel_name, - package_file_name, - package_repodata_patches, - dao, - db, - package_subdir, - patches_subdir, + pkgstore: PackageStore, + package_version: PackageVersion, + package_repodata_patches: PackageVersion, + channel_name: str, + package_file_name: str, + dao: Dao, + db: Session, + package_subdir: str, + patches_subdir: str, + session_maker: sessionmaker, ): - @contextmanager - def get_db(): - yield db - - with mock.patch("quetz_repodata_patching.main.get_session", get_db): + with mock.patch("quetz_repodata_patching.main.get_session", session_maker): indexing.update_indexes(dao, pkgstore, channel_name) index_path = os.path.join( @@ -460,18 +465,14 @@ def get_db(): def test_no_repodata_patches_package( - pkgstore, - package_version, - channel_name, - package_file_name, - dao, - db, + pkgstore: PackageStore, + package_version: PackageVersion, + channel_name: str, + package_file_name: str, + dao: Dao, + session_maker: sessionmaker, ): - @contextmanager - def get_db(): - yield db - - with mock.patch("quetz_repodata_patching.main.get_session", get_db): + with mock.patch("quetz_repodata_patching.main.get_session", session_maker): indexing.update_indexes(dao, pkgstore, channel_name) index_path = os.path.join( From 1b570a07a0e87b997d35133e48687ac07154790a Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Mon, 15 Jul 2024 10:00:51 +0200 Subject: [PATCH 11/24] fix --- quetz/dao.py | 24 +++++++++++++----------- quetz/database.py | 4 +++- quetz/tasks/workers.py | 1 + quetz/tests/conftest.py | 9 ++++++--- quetz/tests/test_auth.py | 4 ++-- quetz/tests/test_dao.py | 8 ++++---- 6 files changed, 29 insertions(+), 21 deletions(-) diff --git a/quetz/dao.py b/quetz/dao.py index 7455ccff..d5f851da 100644 --- a/quetz/dao.py +++ b/quetz/dao.py @@ -20,6 +20,7 @@ from quetz import channel_data, errors, rest_models, versionorder from quetz.database_extensions import version_match from quetz.utils import apply_custom_query +from .condainfo import CondaInfo from .db_models import ( ApiKey, @@ -33,6 +34,7 @@ PackageVersion, Profile, User, + PackageFormatEnum, ) from .jobs.models import Job, JobStatus, Task, TaskStatus from .metrics.db_models import ( @@ -809,17 +811,17 @@ def get_api_key(self, key): def create_version( self, - channel_name, - package_name, - package_format, - platform, - version, - build_number, - build_string, - filename, - info, - uploader_id, - size, + channel_name: str, + package_name: str, + package_format: PackageFormatEnum, + platform: str, + version: str, + build_number: int, + build_string: str, + filename: str, + info: CondaInfo, + uploader_id: bytes, + size: int, upsert: bool = False, ): # hold a lock on the package diff --git a/quetz/database.py b/quetz/database.py index 48e05487..f0d18015 100644 --- a/quetz/database.py +++ b/quetz/database.py @@ -65,7 +65,9 @@ def on_close(dbapi_conn, conn_record): def get_session_maker( bind: sqlalchemy.engine.Engine | sqlalchemy.engine.Connection, ) -> Callable[[], sessionmaker]: - return sessionmaker(autocommit=False, autoflush=True, bind=bind) + return sessionmaker( + autocommit=False, autoflush=True, bind=bind, expire_on_commit=False + ) def get_session(config: Config | None) -> Session: diff --git a/quetz/tasks/workers.py b/quetz/tasks/workers.py index 2b98b61a..d1b9bdfa 100644 --- a/quetz/tasks/workers.py +++ b/quetz/tasks/workers.py @@ -140,6 +140,7 @@ def job_wrapper( auth = kwargs.pop("auth", None) session = kwargs.pop("session", None) + kwargs.pop("db", None) with get_session(config) as db: user_id: Optional[str] if task_id: diff --git a/quetz/tests/conftest.py b/quetz/tests/conftest.py index 7905d4f9..a907598f 100644 --- a/quetz/tests/conftest.py +++ b/quetz/tests/conftest.py @@ -45,13 +45,16 @@ def user(db, user_without_profile): ) db.add(profile) db.commit() + profile_name = profile.name + profile_avatar_url = profile.avatar_url + profile_user_id = user_without_profile.id yield user_without_profile db.query(Profile).filter( - Profile.name == profile.name, - Profile.avatar_url == profile.avatar_url, - Profile.user_id == user_without_profile.id, + Profile.name == profile_name, + Profile.avatar_url == profile_avatar_url, + Profile.user_id == profile_user_id, ).delete() db.commit() diff --git a/quetz/tests/test_auth.py b/quetz/tests/test_auth.py index 6a8c6a0c..977e6cb8 100644 --- a/quetz/tests/test_auth.py +++ b/quetz/tests/test_auth.py @@ -71,7 +71,7 @@ def __init__(self, db): package_format="tarbz2", platform="noarch", version="0.0.1", - build_number="0", + build_number=0, build_string="", filename="filename.tar.bz2", info="{}", @@ -85,7 +85,7 @@ def __init__(self, db): package_format="tarbz2", platform="noarch", version="0.0.1", - build_number="0", + build_number=0, build_string="", filename="filename2.tar.bz2", info="{}", diff --git a/quetz/tests/test_dao.py b/quetz/tests/test_dao.py index 31f57a75..788ca54e 100644 --- a/quetz/tests/test_dao.py +++ b/quetz/tests/test_dao.py @@ -46,7 +46,7 @@ def package_version(dao, package, user): package_format="tarbz2", platform="noarch", version="0.0.1", - build_number="0", + build_number=0, build_string="", filename="filename.tar.bz2", info="{}", @@ -84,7 +84,7 @@ def test_create_version(dao, package, channel_name, package_name, db, user): package_format="tarbz2", platform="noarch", version="0.0.1", - build_number="0", + build_number=0, build_string="", filename="filename.tar.bz2", info="{}", @@ -114,7 +114,7 @@ def test_create_version(dao, package, channel_name, package_name, db, user): package_format="tarbz2", platform="noarch", version="0.0.1", - build_number="0", + build_number=0, build_string="", filename="filename-2.tar.bz2", info="{}", @@ -130,7 +130,7 @@ def test_create_version(dao, package, channel_name, package_name, db, user): package_format="tarbz2", platform="noarch", version="0.0.1", - build_number="0", + build_number=0, build_string="", filename="filename-2.tar.bz2", info='{"version": "x.y.z"}', From daaf43c2f6f78438102c9547646161596c9f5964 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Mon, 15 Jul 2024 10:15:25 +0200 Subject: [PATCH 12/24] fix --- plugins/quetz_conda_suggest/tests/conftest.py | 2 +- .../tests/test_main.py | 2 +- quetz/dao.py | 6 ++-- quetz/tasks/mirror.py | 28 +++++++++---------- quetz/tests/api/conftest.py | 2 +- quetz/tests/test_mirror.py | 2 +- 6 files changed, 22 insertions(+), 20 deletions(-) diff --git a/plugins/quetz_conda_suggest/tests/conftest.py b/plugins/quetz_conda_suggest/tests/conftest.py index 666ab01a..d6436a1f 100644 --- a/plugins/quetz_conda_suggest/tests/conftest.py +++ b/plugins/quetz_conda_suggest/tests/conftest.py @@ -77,7 +77,7 @@ def package_version(user, channel, db, dao, package): package_format, "linux-64", "0.1", - "0", + 0, "0", "test-package-0.1-0.tar.bz2", package_info, diff --git a/plugins/quetz_repodata_patching/tests/test_main.py b/plugins/quetz_repodata_patching/tests/test_main.py index 8be5b873..39f91765 100644 --- a/plugins/quetz_repodata_patching/tests/test_main.py +++ b/plugins/quetz_repodata_patching/tests/test_main.py @@ -92,7 +92,7 @@ def package_version( package_format, package_subdir, "0.1", - "0", + 0, "0", package_file_name, package_info, diff --git a/quetz/dao.py b/quetz/dao.py index d5f851da..19d9ff2a 100644 --- a/quetz/dao.py +++ b/quetz/dao.py @@ -20,7 +20,6 @@ from quetz import channel_data, errors, rest_models, versionorder from quetz.database_extensions import version_match from quetz.utils import apply_custom_query -from .condainfo import CondaInfo from .db_models import ( ApiKey, @@ -819,11 +818,14 @@ def create_version( build_number: int, build_string: str, filename: str, - info: CondaInfo, + info: str, uploader_id: bytes, size: int, upsert: bool = False, ): + if not isinstance(build_number, int): + raise TypeError("build_number should be an integer") + # hold a lock on the package package = ( # noqa self.db.query(Package) diff --git a/quetz/tasks/mirror.py b/quetz/tasks/mirror.py index ab65e145..3bd1d2e4 100644 --- a/quetz/tasks/mirror.py +++ b/quetz/tasks/mirror.py @@ -19,7 +19,7 @@ from quetz.condainfo import CondaInfo, get_subdir_compat from quetz.config import Config from quetz.dao import Dao -from quetz.db_models import PackageVersion +from quetz.db_models import PackageVersion, PackageFormatEnum from quetz.errors import DBError from quetz.pkgstores import PackageStore from quetz.tasks import indexing @@ -489,26 +489,26 @@ def create_version_from_metadata( dao.create_package(channel_name, package_info, user_id, "owner") if package_file_name.endswith(".conda"): - pkg_format = "conda" + pkg_format = PackageFormatEnum.conda elif package_file_name.endswith(".tar.bz2"): - pkg_format = "tarbz2" + pkg_format = PackageFormatEnum.tarbz2 else: raise ValueError( f"Unknown package format for package {package_file_name}" f"in channel {channel_name}" ) version = dao.create_version( - channel_name, - package_name, - pkg_format, - get_subdir_compat(package_data), - package_data["version"], - int(package_data["build_number"]), - package_data["build"], - package_file_name, - json.dumps(package_data), - user_id, - package_data["size"], + channel_name=channel_name, + package_name=package_name, + package_format=pkg_format, + platform=get_subdir_compat(package_data), + version=package_data["version"], + build_number=int(package_data["build_number"]), + build_string=package_data["build"], + filename=package_file_name, + info=json.dumps(package_data), + uploader_id=user_id, + size=package_data["size"], ) return version diff --git a/quetz/tests/api/conftest.py b/quetz/tests/api/conftest.py index b5ec184c..a0b667c4 100644 --- a/quetz/tests/api/conftest.py +++ b/quetz/tests/api/conftest.py @@ -48,7 +48,7 @@ def private_package_version( package_format, platform, "0.1", - "0", + 0, "", str(filename), package_info, diff --git a/quetz/tests/test_mirror.py b/quetz/tests/test_mirror.py index e4a95485..cde2f540 100644 --- a/quetz/tests/test_mirror.py +++ b/quetz/tests/test_mirror.py @@ -282,7 +282,7 @@ def package_version(user, mirror_channel, db, dao): package_format, "noarch", "0.1", - "0", + 0, "", "test-package-0.1-0.tar.bz2", package_info, From c2c28f5f6456a9943ad7a400de93c83bcae449f1 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Mon, 15 Jul 2024 10:21:22 +0200 Subject: [PATCH 13/24] fix --- quetz/testing/fixtures.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/quetz/testing/fixtures.py b/quetz/testing/fixtures.py index 0f042d7b..6493666a 100644 --- a/quetz/testing/fixtures.py +++ b/quetz/testing/fixtures.py @@ -137,7 +137,18 @@ def session_maker( trans = sql_connection.begin() sql_connection.name = "sqlite-test" - yield get_session_maker(sql_connection) + + session_maker = get_session_maker(sql_connection) + + def wrapper(*args, **kwargs): + """ + Wrapper function that accepts and ignores args / kwargs + to allow for mocking of database.get_session, which accepts + a Config object in the real implementation. + """ + return session_maker() + + yield wrapper if trans is not None: trans.rollback() From 9f62c3b452fb198e023faad504d02061e48cbc8c Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Mon, 15 Jul 2024 10:57:46 +0200 Subject: [PATCH 14/24] fix --- quetz/dao.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/quetz/dao.py b/quetz/dao.py index 19d9ff2a..24235ca4 100644 --- a/quetz/dao.py +++ b/quetz/dao.py @@ -910,16 +910,11 @@ def create_version( ) elif upsert: - existing_versions.update( - { - "filename": filename, - "info": info, - "uploader_id": uploader_id, - "time_modified": datetime.utcnow(), - "size": size, - }, - synchronize_session="evaluate", - ) + package_version.filename = filename + package_version.info = info + package_version.uploader_id = uploader_id + package_version.time_modified = datetime.utcnow() + package_version.size = size else: raise IntegrityError("duplicate package version", "", "") From 301943a7813be33cb4b4516bb348d37b9edcb9ac Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Mon, 15 Jul 2024 11:47:13 +0200 Subject: [PATCH 15/24] fix --- quetz/dao.py | 53 +++++++++++++++++++++-------------------- quetz/tests/test_dao.py | 14 ++++++++++- 2 files changed, 40 insertions(+), 27 deletions(-) diff --git a/quetz/dao.py b/quetz/dao.py index 24235ca4..91b13f24 100644 --- a/quetz/dao.py +++ b/quetz/dao.py @@ -9,8 +9,11 @@ from itertools import groupby from typing import TYPE_CHECKING, Dict, List, Optional -from sqlalchemy import and_, func, insert, or_ +import sqlalchemy +from sqlalchemy import and_, func, or_ from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.dialects.sqlite import insert as sqlite_insert + from sqlalchemy.exc import IntegrityError, NoResultFound # type: ignore from sqlalchemy.ext.compiler import compiles from sqlalchemy.orm import Query, Session, aliased, joinedload @@ -127,28 +130,19 @@ def upsert_pg(element, compiler, **kw): @compiles(Upsert, "sqlite") -def upsert_sql(element, compiler, **kw): - # on_conflict_do_update does exist in sqlite - # but it was ported to sqlalchemy only in version 1.4 - # which was not released at the time of implementing this - # so we treat it with raw SQL syntax - # sqlite ref: https://www.sqlite.org/lang_upsert.html - # sqlalchemy 1.4 ref: https://docs.sqlalchemy.org/en/14/dialects/sqlite.html#insert-on-conflict-upsert # noqa - +def upsert_sqlite(element, compiler, **kw): index_elements = element.index_elements values = element.values column = element.column incr = element.incr table = element.table - - stmt = insert(table).values(values) - raw_sql = compiler.process(stmt) - upsert_stmt = ( - f"ON CONFLICT ({','.join(index_elements)}) " - f"DO UPDATE SET {column.name}={column.name}+{incr}" + stmt = sqlite_insert(table).values(values) + stmt = stmt.on_conflict_do_update( + index_elements=index_elements, + set_={column.name: column + incr}, ) - return raw_sql + " " + upsert_stmt + return compiler.visit_insert(stmt) def get_paginated_result(query: Query, skip: int, limit: int): @@ -472,9 +466,11 @@ def delete_channel_mirror(self, channel_name: str, mirror_id: str): self.db.commit() def update_channel(self, channel_name, data: dict): - self.db.query(Channel).filter(Channel.name == channel_name).update( - data, synchronize_session=False - ) + if data.get("name", channel_name) != channel_name: + raise ValueError("channel_name cannot be changed") + full_data = dict(**data, name=channel_name) + self.db.execute(sqlalchemy.update(Channel), [full_data]) + self.db.commit() def delete_channel(self, channel_name): @@ -1183,12 +1179,17 @@ def incr_download_count( incr: int = 1, ): metric_name = "download" - - self.db.query(PackageVersion).filter( - PackageVersion.channel_name == channel - ).filter(PackageVersion.filename == filename).filter( - PackageVersion.platform == platform - ).update({PackageVersion.download_count: PackageVersion.download_count + incr}) + self.db.execute( + sqlalchemy.update(PackageVersion) + .where( + PackageVersion.channel_name == channel, + PackageVersion.filename == filename, + PackageVersion.platform == platform, + ) + .values( + download_count=PackageVersion.download_count + incr, + ) + ) if timestamp is None: timestamp = datetime.utcnow() @@ -1215,7 +1216,7 @@ def incr_download_count( "period", "timestamp", ] - + logging.getLogger("sqlalchemy").setLevel(logging.DEBUG) stmt = Upsert( PackageVersionMetric.__table__, all_values, diff --git a/quetz/tests/test_dao.py b/quetz/tests/test_dao.py index 788ca54e..c032d5fd 100644 --- a/quetz/tests/test_dao.py +++ b/quetz/tests/test_dao.py @@ -173,33 +173,43 @@ def test_update_channel_size(dao, channel, db, package_version): def test_increment_download_count( dao: Dao, channel: Channel, db: Session, package_version: PackageVersion ): + # Arrange: Create new package version that was never downloaded assert package_version.download_count == 0 now = datetime.datetime(2020, 10, 1, 10, 1, 10) + + # Act: Increment download count dao.incr_download_count( channel.name, package_version.filename, package_version.platform, timestamp=now ) + # Assert: Download count is incremented in PackageVersionMetric table download_counts = db.query(PackageVersionMetric).all() for m in download_counts: assert m.count == 1 - assert len(download_counts) == len(IntervalType) + # Assert: Download count is incremented on the PackageVersion object itself db.refresh(package_version) assert package_version.download_count == 1 + # Act: Increment download count again dao.incr_download_count( channel.name, package_version.filename, package_version.platform, timestamp=now ) + + # Assert: Download count is incremented in PackageVersionMetric table download_counts = db.query(PackageVersionMetric).all() for m in download_counts: assert m.count == 2 assert len(download_counts) == len(IntervalType) + # Assert: Download count is incremented on the PackageVersion object itself db.refresh(package_version) assert package_version.download_count == 2 + # Act: Increment download count again, + # but this time with a time stamp shifted by one day dao.incr_download_count( channel.name, package_version.filename, @@ -207,6 +217,8 @@ def test_increment_download_count( timestamp=now + datetime.timedelta(days=1), ) + # Assert + # This time, two new metrics are created (intervals H and D) download_counts = db.query(PackageVersionMetric).all() assert len(download_counts) == len(IntervalType) + 2 From f22e5b31fdc6b9e94a9f3d5043251617609b5f8f Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Mon, 15 Jul 2024 14:45:52 +0200 Subject: [PATCH 16/24] fix --- quetz/database.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/quetz/database.py b/quetz/database.py index f0d18015..0ba97cf8 100644 --- a/quetz/database.py +++ b/quetz/database.py @@ -66,7 +66,7 @@ def get_session_maker( bind: sqlalchemy.engine.Engine | sqlalchemy.engine.Connection, ) -> Callable[[], sessionmaker]: return sessionmaker( - autocommit=False, autoflush=True, bind=bind, expire_on_commit=False + autocommit=False, autoflush=True, bind=bind, expire_on_commit=True ) From 66e5b3462309eaef14c1ea8dc45234dff02ca92f Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Mon, 15 Jul 2024 14:57:08 +0200 Subject: [PATCH 17/24] fix --- quetz/testing/fixtures.py | 13 ++++--------- quetz/tests/test_mirror.py | 7 ++++--- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/quetz/testing/fixtures.py b/quetz/testing/fixtures.py index 6493666a..d7d27a19 100644 --- a/quetz/testing/fixtures.py +++ b/quetz/testing/fixtures.py @@ -2,6 +2,7 @@ import shutil import tempfile from typing import List, Iterator +from unittest import mock import pytest import sqlalchemy.orm @@ -158,8 +159,9 @@ def wrapper(*args, **kwargs): def db( session_maker: sqlalchemy.orm.sessionmaker, ) -> Iterator[sqlalchemy.orm.Session]: - with session_maker() as db: - yield db + with mock.patch("quetz.database.get_session", session_maker): + with session_maker() as db: + yield db @pytest.fixture @@ -257,13 +259,6 @@ def app(config, db, mocker): from quetz.deps import get_db from quetz.main import app - # mocking is required for some functions that do not use fastapi - # dependency injection (mainly non-request functions) - def get_session_mock(*args, **kwargs): - return db - - mocker.patch("quetz.database.get_session", get_session_mock) - # overriding dependency works with all requests handlers that # depend on quetz.deps.get_db app.dependency_overrides[get_db] = lambda: db diff --git a/quetz/tests/test_mirror.py b/quetz/tests/test_mirror.py index cde2f540..38e1043e 100644 --- a/quetz/tests/test_mirror.py +++ b/quetz/tests/test_mirror.py @@ -4,6 +4,7 @@ import uuid from io import BytesIO from pathlib import Path +from unittest import mock from unittest.mock import MagicMock from urllib.parse import urlparse @@ -820,7 +821,7 @@ def test_add_and_register_mirror(auth_client, dummy_session_mock): ] ], ) -def test_wrong_package_format(client, dummy_repo, owner, job_supervisor): +def test_wrong_package_format(session_maker, client, dummy_repo, owner, job_supervisor): response = client.get("/api/dummylogin/bartosz") assert response.status_code == 200 @@ -837,8 +838,8 @@ def test_wrong_package_format(client, dummy_repo, owner, job_supervisor): ) assert response.status_code == 201 - - job_supervisor.run_once() + with mock.patch("quetz.database.get_session", session_maker): + job_supervisor.run_once() assert dummy_repo == [ "http://mirror3_host/channeldata.json", From e2f494ff71b10a74f938b2c186511838eab0e9da Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Mon, 15 Jul 2024 15:15:50 +0200 Subject: [PATCH 18/24] fix --- quetz/tasks/workers.py | 2 +- quetz/tests/test_jobs.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/quetz/tasks/workers.py b/quetz/tasks/workers.py index d1b9bdfa..22aab922 100644 --- a/quetz/tasks/workers.py +++ b/quetz/tasks/workers.py @@ -139,7 +139,7 @@ def job_wrapper( pkgstore = kwargs.pop("pkgstore", None) auth = kwargs.pop("auth", None) session = kwargs.pop("session", None) - + kwargs.pop("dao", None) kwargs.pop("db", None) with get_session(config) as db: user_id: Optional[str] diff --git a/quetz/tests/test_jobs.py b/quetz/tests/test_jobs.py index b79ae37d..8261ba95 100644 --- a/quetz/tests/test_jobs.py +++ b/quetz/tests/test_jobs.py @@ -866,9 +866,13 @@ def sync_supervisor(db, dao, config): @pytest.fixture def mock_action(mocker): - func = mocker.Mock() + m = mocker.Mock() + + def func(*args, **kwargs): + m(*args, **kwargs) + mocker.patch("quetz.jobs.handlers.JOB_HANDLERS", {"test_action": func}) - return func + return m def test_update_job_status(sync_supervisor, db, action_job): From 6561e496d0af503e4febb92a96f601f758957707 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Mon, 15 Jul 2024 16:07:50 +0200 Subject: [PATCH 19/24] fix? --- quetz/testing/fixtures.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/quetz/testing/fixtures.py b/quetz/testing/fixtures.py index d7d27a19..9603de15 100644 --- a/quetz/testing/fixtures.py +++ b/quetz/testing/fixtures.py @@ -93,14 +93,14 @@ def sql_connection(engine): @pytest.fixture def alembic_config(database_url, sql_connection): alembic_config = _alembic_config(database_url) - alembic_config.attributes["connection"] = sql_connection + alembic_config.attributes["engine"] = sql_connection.engine return alembic_config @pytest.fixture def create_tables(alembic_config, engine, use_migrations): if use_migrations: - alembic_upgrade(alembic_config, "heads", sql=False) + alembic_upgrade(alembic_config, "head", sql=False) else: Base.metadata.create_all(engine) From 6f463802f924c8c9829239f5aff9ef1fbecaaf32 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Mon, 15 Jul 2024 17:38:29 +0200 Subject: [PATCH 20/24] Revert "fix?" This reverts commit 6561e496d0af503e4febb92a96f601f758957707. --- quetz/testing/fixtures.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/quetz/testing/fixtures.py b/quetz/testing/fixtures.py index 9603de15..d7d27a19 100644 --- a/quetz/testing/fixtures.py +++ b/quetz/testing/fixtures.py @@ -93,14 +93,14 @@ def sql_connection(engine): @pytest.fixture def alembic_config(database_url, sql_connection): alembic_config = _alembic_config(database_url) - alembic_config.attributes["engine"] = sql_connection.engine + alembic_config.attributes["connection"] = sql_connection return alembic_config @pytest.fixture def create_tables(alembic_config, engine, use_migrations): if use_migrations: - alembic_upgrade(alembic_config, "head", sql=False) + alembic_upgrade(alembic_config, "heads", sql=False) else: Base.metadata.create_all(engine) From 8cb09d45aab732e1d45e189db82f80cb102b4e65 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Mon, 15 Jul 2024 18:20:24 +0200 Subject: [PATCH 21/24] ?! --- quetz/testing/fixtures.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/quetz/testing/fixtures.py b/quetz/testing/fixtures.py index d7d27a19..fa4537fe 100644 --- a/quetz/testing/fixtures.py +++ b/quetz/testing/fixtures.py @@ -17,17 +17,17 @@ from quetz.db_models import Base -def pytest_configure(config): - pytest.quetz_variables = { - var: value for var, value in os.environ.items() if var.startswith("QUETZ_") - } - for var in pytest.quetz_variables: - del os.environ[var] +# def pytest_configure(config): +# pytest.quetz_variables = { +# var: value for var, value in os.environ.items() if var.startswith("QUETZ_") +# } +# for var in pytest.quetz_variables: +# del os.environ[var] -def pytest_unconfigure(config): - for var, value in pytest.quetz_variables.items(): - os.environ[var] = value +# def pytest_unconfigure(config): +# for var, value in pytest.quetz_variables.items(): +# os.environ[var] = value @pytest.fixture From 5e25393f98cac51617a13a50d4489d9ae4cb61e5 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Tue, 16 Jul 2024 09:53:56 +0200 Subject: [PATCH 22/24] ?! --- .github/workflows/ci.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index c83a0007..6926830d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -15,7 +15,7 @@ env: jobs: test_quetz: # timeout for the whole job - timeout-minutes: 10 + timeout-minutes: 15 runs-on: ${{ matrix.os }} strategy: fail-fast: false @@ -79,7 +79,7 @@ jobs: - name: Testing server shell: bash -l -eo pipefail {0} # timeout for the step - timeout-minutes: 5 + timeout-minutes: 15 env: TEST_DB_BACKEND: ${{ matrix.test_database }} QUETZ_TEST_DBINIT: ${{ matrix.db_init }} From 020c2bff2ca797c0b40f58148667d2412211503d Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Mon, 22 Jul 2024 10:38:57 +0200 Subject: [PATCH 23/24] ?! --- quetz/testing/fixtures.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/quetz/testing/fixtures.py b/quetz/testing/fixtures.py index fa4537fe..d7d27a19 100644 --- a/quetz/testing/fixtures.py +++ b/quetz/testing/fixtures.py @@ -17,17 +17,17 @@ from quetz.db_models import Base -# def pytest_configure(config): -# pytest.quetz_variables = { -# var: value for var, value in os.environ.items() if var.startswith("QUETZ_") -# } -# for var in pytest.quetz_variables: -# del os.environ[var] +def pytest_configure(config): + pytest.quetz_variables = { + var: value for var, value in os.environ.items() if var.startswith("QUETZ_") + } + for var in pytest.quetz_variables: + del os.environ[var] -# def pytest_unconfigure(config): -# for var, value in pytest.quetz_variables.items(): -# os.environ[var] = value +def pytest_unconfigure(config): + for var, value in pytest.quetz_variables.items(): + os.environ[var] = value @pytest.fixture From a04dc699ed1ec303735fea5c35089d9fcc813251 Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Mon, 22 Jul 2024 11:30:39 +0200 Subject: [PATCH 24/24] fix --- quetz/database.py | 4 +--- quetz/testing/fixtures.py | 8 +++++--- quetz/tests/api/test_users.py | 12 ++++++------ quetz/tests/test_workers.py | 21 +++++++++++++-------- 4 files changed, 25 insertions(+), 20 deletions(-) diff --git a/quetz/database.py b/quetz/database.py index 0ba97cf8..b4809cb7 100644 --- a/quetz/database.py +++ b/quetz/database.py @@ -65,9 +65,7 @@ def on_close(dbapi_conn, conn_record): def get_session_maker( bind: sqlalchemy.engine.Engine | sqlalchemy.engine.Connection, ) -> Callable[[], sessionmaker]: - return sessionmaker( - autocommit=False, autoflush=True, bind=bind, expire_on_commit=True - ) + return sessionmaker(autocommit=False, bind=bind) def get_session(config: Config | None) -> Session: diff --git a/quetz/testing/fixtures.py b/quetz/testing/fixtures.py index d7d27a19..0e0a4a78 100644 --- a/quetz/testing/fixtures.py +++ b/quetz/testing/fixtures.py @@ -19,7 +19,9 @@ def pytest_configure(config): pytest.quetz_variables = { - var: value for var, value in os.environ.items() if var.startswith("QUETZ_") + var: value + for var, value in os.environ.items() + if var.startswith("QUETZ_") and not var.startswith("QUETZ_TEST") } for var in pytest.quetz_variables: del os.environ[var] @@ -158,10 +160,10 @@ def wrapper(*args, **kwargs): @pytest.fixture def db( session_maker: sqlalchemy.orm.sessionmaker, -) -> Iterator[sqlalchemy.orm.Session]: +) -> sqlalchemy.orm.Session: with mock.patch("quetz.database.get_session", session_maker): with session_maker() as db: - yield db + return db @pytest.fixture diff --git a/quetz/tests/api/test_users.py b/quetz/tests/api/test_users.py index 4f3473bd..9a643039 100644 --- a/quetz/tests/api/test_users.py +++ b/quetz/tests/api/test_users.py @@ -176,7 +176,6 @@ def test_delete_user_permission( other_user, auth_client, db, user_role, target_user, expected_status, user, api_keys ): response = auth_client.delete(f"/api/users/{target_user}") - deleted_user = db.query(User).filter(User.username == target_user).one_or_none() if expected_status == 200: @@ -196,11 +195,12 @@ def test_delete_user_permission( assert deleted_user.api_keys_user # check if other users were not accidently removed - existing_user = db.query(User).filter(User.username != target_user).one_or_none() - assert existing_user - assert existing_user.profile - assert existing_user.api_keys_owner - assert existing_user.api_keys_user + existing_users = db.query(User).filter(User.username != target_user).all() + assert existing_users + for existing_user in existing_users: + assert existing_user.profile + assert existing_user.api_keys_owner + assert existing_user.api_keys_user @pytest.mark.parametrize("user_role", ["owner"]) diff --git a/quetz/tests/test_workers.py b/quetz/tests/test_workers.py index b8a1353f..c5436eea 100644 --- a/quetz/tests/test_workers.py +++ b/quetz/tests/test_workers.py @@ -8,6 +8,7 @@ from quetz.authorization import Rules from quetz.dao import Dao +from quetz.database import get_session_maker from quetz.db_models import User from quetz.tasks.workers import RQManager, SubprocessWorker, ThreadingWorker @@ -124,19 +125,23 @@ def function_with_dao(dao: Dao): @pytest.fixture -def db_cleanup(config): +def db_cleanup(engine): # we can't use the db fixture for cleaning up because # it automatically rollsback all operations yield - from quetz.database import get_session - - with get_session(config) as db: - user = db.query(User).one_or_none() - if user: - db.delete(user) - db.commit() + with engine.connect() as con: + session_maker = get_session_maker(con) + with session_maker() as db: + user = db.query(User).one_or_none() + if user: + db.delete(user) + db.commit() + + with session_maker() as db: + user = db.query(User).one_or_none() + assert user is None @pytest.mark.asyncio