From 301943a7813be33cb4b4516bb348d37b9edcb9ac Mon Sep 17 00:00:00 2001 From: Andreas Albert Date: Mon, 15 Jul 2024 11:47:13 +0200 Subject: [PATCH] fix --- quetz/dao.py | 53 +++++++++++++++++++++-------------------- quetz/tests/test_dao.py | 14 ++++++++++- 2 files changed, 40 insertions(+), 27 deletions(-) diff --git a/quetz/dao.py b/quetz/dao.py index 24235ca4..91b13f24 100644 --- a/quetz/dao.py +++ b/quetz/dao.py @@ -9,8 +9,11 @@ from itertools import groupby from typing import TYPE_CHECKING, Dict, List, Optional -from sqlalchemy import and_, func, insert, or_ +import sqlalchemy +from sqlalchemy import and_, func, or_ from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.dialects.sqlite import insert as sqlite_insert + from sqlalchemy.exc import IntegrityError, NoResultFound # type: ignore from sqlalchemy.ext.compiler import compiles from sqlalchemy.orm import Query, Session, aliased, joinedload @@ -127,28 +130,19 @@ def upsert_pg(element, compiler, **kw): @compiles(Upsert, "sqlite") -def upsert_sql(element, compiler, **kw): - # on_conflict_do_update does exist in sqlite - # but it was ported to sqlalchemy only in version 1.4 - # which was not released at the time of implementing this - # so we treat it with raw SQL syntax - # sqlite ref: https://www.sqlite.org/lang_upsert.html - # sqlalchemy 1.4 ref: https://docs.sqlalchemy.org/en/14/dialects/sqlite.html#insert-on-conflict-upsert # noqa - +def upsert_sqlite(element, compiler, **kw): index_elements = element.index_elements values = element.values column = element.column incr = element.incr table = element.table - - stmt = insert(table).values(values) - raw_sql = compiler.process(stmt) - upsert_stmt = ( - f"ON CONFLICT ({','.join(index_elements)}) " - f"DO UPDATE SET {column.name}={column.name}+{incr}" + stmt = sqlite_insert(table).values(values) + stmt = stmt.on_conflict_do_update( + index_elements=index_elements, + set_={column.name: column + incr}, ) - return raw_sql + " " + upsert_stmt + return compiler.visit_insert(stmt) def get_paginated_result(query: Query, skip: int, limit: int): @@ -472,9 +466,11 @@ def delete_channel_mirror(self, channel_name: str, mirror_id: str): self.db.commit() def update_channel(self, channel_name, data: dict): - self.db.query(Channel).filter(Channel.name == channel_name).update( - data, synchronize_session=False - ) + if data.get("name", channel_name) != channel_name: + raise ValueError("channel_name cannot be changed") + full_data = dict(**data, name=channel_name) + self.db.execute(sqlalchemy.update(Channel), [full_data]) + self.db.commit() def delete_channel(self, channel_name): @@ -1183,12 +1179,17 @@ def incr_download_count( incr: int = 1, ): metric_name = "download" - - self.db.query(PackageVersion).filter( - PackageVersion.channel_name == channel - ).filter(PackageVersion.filename == filename).filter( - PackageVersion.platform == platform - ).update({PackageVersion.download_count: PackageVersion.download_count + incr}) + self.db.execute( + sqlalchemy.update(PackageVersion) + .where( + PackageVersion.channel_name == channel, + PackageVersion.filename == filename, + PackageVersion.platform == platform, + ) + .values( + download_count=PackageVersion.download_count + incr, + ) + ) if timestamp is None: timestamp = datetime.utcnow() @@ -1215,7 +1216,7 @@ def incr_download_count( "period", "timestamp", ] - + logging.getLogger("sqlalchemy").setLevel(logging.DEBUG) stmt = Upsert( PackageVersionMetric.__table__, all_values, diff --git a/quetz/tests/test_dao.py b/quetz/tests/test_dao.py index 788ca54e..c032d5fd 100644 --- a/quetz/tests/test_dao.py +++ b/quetz/tests/test_dao.py @@ -173,33 +173,43 @@ def test_update_channel_size(dao, channel, db, package_version): def test_increment_download_count( dao: Dao, channel: Channel, db: Session, package_version: PackageVersion ): + # Arrange: Create new package version that was never downloaded assert package_version.download_count == 0 now = datetime.datetime(2020, 10, 1, 10, 1, 10) + + # Act: Increment download count dao.incr_download_count( channel.name, package_version.filename, package_version.platform, timestamp=now ) + # Assert: Download count is incremented in PackageVersionMetric table download_counts = db.query(PackageVersionMetric).all() for m in download_counts: assert m.count == 1 - assert len(download_counts) == len(IntervalType) + # Assert: Download count is incremented on the PackageVersion object itself db.refresh(package_version) assert package_version.download_count == 1 + # Act: Increment download count again dao.incr_download_count( channel.name, package_version.filename, package_version.platform, timestamp=now ) + + # Assert: Download count is incremented in PackageVersionMetric table download_counts = db.query(PackageVersionMetric).all() for m in download_counts: assert m.count == 2 assert len(download_counts) == len(IntervalType) + # Assert: Download count is incremented on the PackageVersion object itself db.refresh(package_version) assert package_version.download_count == 2 + # Act: Increment download count again, + # but this time with a time stamp shifted by one day dao.incr_download_count( channel.name, package_version.filename, @@ -207,6 +217,8 @@ def test_increment_download_count( timestamp=now + datetime.timedelta(days=1), ) + # Assert + # This time, two new metrics are created (intervals H and D) download_counts = db.query(PackageVersionMetric).all() assert len(download_counts) == len(IntervalType) + 2