Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreasAlbertQC committed Jul 15, 2024
1 parent 9f62c3b commit 301943a
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 27 deletions.
53 changes: 27 additions & 26 deletions quetz/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
from itertools import groupby
from typing import TYPE_CHECKING, Dict, List, Optional

from sqlalchemy import and_, func, insert, or_
import sqlalchemy
from sqlalchemy import and_, func, or_
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

from sqlalchemy.exc import IntegrityError, NoResultFound # type: ignore
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Query, Session, aliased, joinedload
Expand Down Expand Up @@ -127,28 +130,19 @@ def upsert_pg(element, compiler, **kw):


@compiles(Upsert, "sqlite")
def upsert_sql(element, compiler, **kw):
# on_conflict_do_update does exist in sqlite
# but it was ported to sqlalchemy only in version 1.4
# which was not released at the time of implementing this
# so we treat it with raw SQL syntax
# sqlite ref: https://www.sqlite.org/lang_upsert.html
# sqlalchemy 1.4 ref: https://docs.sqlalchemy.org/en/14/dialects/sqlite.html#insert-on-conflict-upsert # noqa

def upsert_sqlite(element, compiler, **kw):
index_elements = element.index_elements
values = element.values
column = element.column
incr = element.incr
table = element.table

stmt = insert(table).values(values)
raw_sql = compiler.process(stmt)
upsert_stmt = (
f"ON CONFLICT ({','.join(index_elements)}) "
f"DO UPDATE SET {column.name}={column.name}+{incr}"
stmt = sqlite_insert(table).values(values)
stmt = stmt.on_conflict_do_update(
index_elements=index_elements,
set_={column.name: column + incr},
)

return raw_sql + " " + upsert_stmt
return compiler.visit_insert(stmt)


def get_paginated_result(query: Query, skip: int, limit: int):
Expand Down Expand Up @@ -472,9 +466,11 @@ def delete_channel_mirror(self, channel_name: str, mirror_id: str):
self.db.commit()

def update_channel(self, channel_name, data: dict):
self.db.query(Channel).filter(Channel.name == channel_name).update(
data, synchronize_session=False
)
if data.get("name", channel_name) != channel_name:
raise ValueError("channel_name cannot be changed")
full_data = dict(**data, name=channel_name)
self.db.execute(sqlalchemy.update(Channel), [full_data])

self.db.commit()

def delete_channel(self, channel_name):
Expand Down Expand Up @@ -1183,12 +1179,17 @@ def incr_download_count(
incr: int = 1,
):
metric_name = "download"

self.db.query(PackageVersion).filter(
PackageVersion.channel_name == channel
).filter(PackageVersion.filename == filename).filter(
PackageVersion.platform == platform
).update({PackageVersion.download_count: PackageVersion.download_count + incr})
self.db.execute(
sqlalchemy.update(PackageVersion)
.where(
PackageVersion.channel_name == channel,
PackageVersion.filename == filename,
PackageVersion.platform == platform,
)
.values(
download_count=PackageVersion.download_count + incr,
)
)

if timestamp is None:
timestamp = datetime.utcnow()
Expand All @@ -1215,7 +1216,7 @@ def incr_download_count(
"period",
"timestamp",
]

logging.getLogger("sqlalchemy").setLevel(logging.DEBUG)
stmt = Upsert(
PackageVersionMetric.__table__,
all_values,
Expand Down
14 changes: 13 additions & 1 deletion quetz/tests/test_dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,40 +173,52 @@ def test_update_channel_size(dao, channel, db, package_version):
def test_increment_download_count(
dao: Dao, channel: Channel, db: Session, package_version: PackageVersion
):
# Arrange: Create new package version that was never downloaded
assert package_version.download_count == 0
now = datetime.datetime(2020, 10, 1, 10, 1, 10)

# Act: Increment download count
dao.incr_download_count(
channel.name, package_version.filename, package_version.platform, timestamp=now
)

# Assert: Download count is incremented in PackageVersionMetric table
download_counts = db.query(PackageVersionMetric).all()
for m in download_counts:
assert m.count == 1

assert len(download_counts) == len(IntervalType)

# Assert: Download count is incremented on the PackageVersion object itself
db.refresh(package_version)
assert package_version.download_count == 1

# Act: Increment download count again
dao.incr_download_count(
channel.name, package_version.filename, package_version.platform, timestamp=now
)

# Assert: Download count is incremented in PackageVersionMetric table
download_counts = db.query(PackageVersionMetric).all()
for m in download_counts:
assert m.count == 2

Check failure on line 203 in quetz/tests/test_dao.py

View workflow job for this annotation

GitHub Actions / test_quetz (ubuntu-latest, sqlite, create-tables)

test_increment_download_count assert 1 == 2 + where 1 = PackageVersionMetric(metric_name=download, period=H, timestamp=2020-10-01 10:00:00,count=1).count

Check failure on line 203 in quetz/tests/test_dao.py

View workflow job for this annotation

GitHub Actions / test_quetz (ubuntu-latest, postgres, create-tables)

test_increment_download_count assert 1 == 2 + where 1 = PackageVersionMetric(metric_name=download, period=H, timestamp=2020-10-01 10:00:00,count=1).count

Check failure on line 203 in quetz/tests/test_dao.py

View workflow job for this annotation

GitHub Actions / test_quetz (ubuntu-latest, sqlite, use-migrations)

test_increment_download_count assert 1 == 2 + where 1 = PackageVersionMetric(metric_name=download, period=H, timestamp=2020-10-01 10:00:00,count=1).count

Check failure on line 203 in quetz/tests/test_dao.py

View workflow job for this annotation

GitHub Actions / test_quetz (ubuntu-latest, postgres, use-migrations)

test_increment_download_count assert 1 == 2 + where 1 = PackageVersionMetric(metric_name=download, period=H, timestamp=2020-10-01 10:00:00,count=1).count

assert len(download_counts) == len(IntervalType)

# Assert: Download count is incremented on the PackageVersion object itself
db.refresh(package_version)
assert package_version.download_count == 2

# Act: Increment download count again,
# but this time with a time stamp shifted by one day
dao.incr_download_count(
channel.name,
package_version.filename,
package_version.platform,
timestamp=now + datetime.timedelta(days=1),
)

# Assert
# This time, two new metrics are created (intervals H and D)
download_counts = db.query(PackageVersionMetric).all()
assert len(download_counts) == len(IntervalType) + 2

Expand Down

0 comments on commit 301943a

Please sign in to comment.