Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(Price Stats): Stats grouped by #675

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions open_prices/api/prices/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ class PriceFilter(django_filters.FilterSet):
product_id__isnull = django_filters.BooleanFilter(
field_name="product_id", lookup_expr="isnull"
)
product_labels_tags__contains = django_filters.CharFilter(
field_name="product__labels_tags", lookup_expr="icontains"
)
product_categories_tags__contains = django_filters.CharFilter(
field_name="product__categories_tags", lookup_expr="icontains"
)
labels_tags__contains = django_filters.CharFilter(
field_name="labels_tags", lookup_expr="icontains"
)
Expand Down Expand Up @@ -52,4 +58,5 @@ class Meta:
"date",
"proof_id",
"owner",
"proof__type",
]
23 changes: 23 additions & 0 deletions open_prices/api/prices/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,26 @@ class PriceStatsSerializer(serializers.Serializer):
price__min = serializers.DecimalField(max_digits=10, decimal_places=2)
price__max = serializers.DecimalField(max_digits=10, decimal_places=2)
price__avg = serializers.DecimalField(max_digits=10, decimal_places=2)
price__sum = serializers.DecimalField(max_digits=10, decimal_places=2)


class GroupedPriceStatsQuerySerializer(serializers.Serializer):
group_by = serializers.CharField(
required=True, help_text="Field by which to group the statistics"
)
order_by = serializers.CharField(
required=False, help_text="Field by which to order the results"
)


class GroupedPriceStatsSerializer(PriceStatsSerializer):
# Override representation to dynamically include the group field
def to_representation(self, instance):
representation = super().to_representation(instance)

# Add the grouping field dynamically
for key in instance:
if key not in representation: # It's likely the group field
representation[key] = instance[key]

return representation
43 changes: 42 additions & 1 deletion open_prices/api/prices/views.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from django.core.exceptions import FieldError
from django_filters.rest_framework import DjangoFilterBackend
from drf_spectacular.utils import extend_schema
from drf_spectacular.utils import OpenApiParameter, extend_schema
from rest_framework import filters, mixins, status, viewsets
from rest_framework.decorators import action
from rest_framework.permissions import IsAuthenticatedOrReadOnly
Expand All @@ -8,6 +9,8 @@

from open_prices.api.prices.filters import PriceFilter
from open_prices.api.prices.serializers import (
GroupedPriceStatsQuerySerializer,
GroupedPriceStatsSerializer,
PriceCreateSerializer,
PriceFullSerializer,
PriceStatsSerializer,
Expand Down Expand Up @@ -79,3 +82,41 @@ def create(self, request: Request, *args, **kwargs):
def stats(self, request: Request) -> Response:
qs = self.filter_queryset(self.get_queryset())
return Response(qs.calculate_stats(), status=200)

@extend_schema(
request=GroupedPriceStatsQuerySerializer,
responses=GroupedPriceStatsSerializer(many=True),
filters=True,
parameters=[
OpenApiParameter(
name="group_by",
description="Field by which to group the statistics",
required=True,
type=str,
location=OpenApiParameter.QUERY,
)
],
)
@action(detail=False, methods=["GET"])
def grouped_stats(self, request: Request) -> Response:
qs = self.filter_queryset(self.get_queryset())

# Validate and parse query parameters using the serializer
serializer = GroupedPriceStatsQuerySerializer(data=request.query_params)
serializer.is_valid(raise_exception=True)
group_by = serializer.validated_data.get("group_by")
order_by = serializer.validated_data.get("order_by", None)

try:
data = qs.calculate_grouped_stats(group_by, order_by)
except FieldError:
return Response(
{"detail": f"Invalid group_by field: {group_by}"},
status=status.HTTP_400_BAD_REQUEST,
)

# Apply pagination
paginator = self.paginator # Use the default pagination class
paginated_data = paginator.paginate_queryset(data, request, view=self)

return paginator.get_paginated_response(paginated_data)
36 changes: 34 additions & 2 deletions open_prices/prices/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@

from django.core.validators import MinValueValidator, ValidationError
from django.db import models
from django.db.models import Avg, Count, F, Max, Min, signals
from django.db.models.functions import Cast
from django.db.models import Avg, Count, F, Max, Min, Sum, signals
from django.db.models.functions import Cast, TruncMonth, TruncWeek, TruncYear
from django.dispatch import receiver
from django.utils import timezone
from openfoodfacts.taxonomy import (
Expand Down Expand Up @@ -54,6 +54,38 @@ def calculate_stats(self):
Avg("price"),
output_field=models.DecimalField(max_digits=10, decimal_places=2),
),
price__sum=Sum("price"),
)

def calculate_grouped_stats(self, group_by, order_by):
group_by_list = group_by.split(",")
if (
"month" in group_by_list
or "year" in group_by_list
or "week" in group_by_list
):
queryset = self.annotate(
month=TruncMonth("date"), year=TruncYear("date"), week=TruncWeek("date")
)
else:
queryset = self
if order_by:
order_by_list = [order_by]
else:
order_by_list = group_by_list
return (
queryset.values(*group_by_list)
.annotate(
price__count=Count("pk"),
price__min=Min("price"),
price__max=Max("price"),
price__avg=Cast(
Avg("price"),
output_field=models.DecimalField(max_digits=10, decimal_places=2),
),
price__sum=Sum("price"),
)
.order_by(*order_by_list)
)


Expand Down
Loading