diff --git a/open_prices/api/prices/filters.py b/open_prices/api/prices/filters.py index c8835a34..f0e00108 100644 --- a/open_prices/api/prices/filters.py +++ b/open_prices/api/prices/filters.py @@ -7,6 +7,12 @@ class PriceFilter(django_filters.FilterSet): product_id__isnull = django_filters.BooleanFilter( field_name="product_id", lookup_expr="isnull" ) + product_labels_tags__contains = django_filters.CharFilter( + field_name="product__labels_tags", lookup_expr="icontains" + ) + product_categories_tags__contains = django_filters.CharFilter( + field_name="product__categories_tags", lookup_expr="icontains" + ) labels_tags__contains = django_filters.CharFilter( field_name="labels_tags", lookup_expr="icontains" ) @@ -52,4 +58,5 @@ class Meta: "date", "proof_id", "owner", + "proof__type", ] diff --git a/open_prices/api/prices/serializers.py b/open_prices/api/prices/serializers.py index ba8dea2f..5aad9d4a 100644 --- a/open_prices/api/prices/serializers.py +++ b/open_prices/api/prices/serializers.py @@ -64,3 +64,26 @@ class PriceStatsSerializer(serializers.Serializer): price__min = serializers.DecimalField(max_digits=10, decimal_places=2) price__max = serializers.DecimalField(max_digits=10, decimal_places=2) price__avg = serializers.DecimalField(max_digits=10, decimal_places=2) + price__sum = serializers.DecimalField(max_digits=10, decimal_places=2) + + +class GroupedPriceStatsQuerySerializer(serializers.Serializer): + group_by = serializers.CharField( + required=True, help_text="Field by which to group the statistics" + ) + order_by = serializers.CharField( + required=False, help_text="Field by which to order the results" + ) + + +class GroupedPriceStatsSerializer(PriceStatsSerializer): + # Override representation to dynamically include the group field + def to_representation(self, instance): + representation = super().to_representation(instance) + + # Add the grouping field dynamically + for key in instance: + if key not in representation: # It's likely the group field + representation[key] = instance[key] + + return representation diff --git a/open_prices/api/prices/views.py b/open_prices/api/prices/views.py index 2d1d79ed..98e6b11b 100644 --- a/open_prices/api/prices/views.py +++ b/open_prices/api/prices/views.py @@ -1,5 +1,6 @@ +from django.core.exceptions import FieldError from django_filters.rest_framework import DjangoFilterBackend -from drf_spectacular.utils import extend_schema +from drf_spectacular.utils import OpenApiParameter, extend_schema from rest_framework import filters, mixins, status, viewsets from rest_framework.decorators import action from rest_framework.permissions import IsAuthenticatedOrReadOnly @@ -8,6 +9,8 @@ from open_prices.api.prices.filters import PriceFilter from open_prices.api.prices.serializers import ( + GroupedPriceStatsQuerySerializer, + GroupedPriceStatsSerializer, PriceCreateSerializer, PriceFullSerializer, PriceStatsSerializer, @@ -79,3 +82,41 @@ def create(self, request: Request, *args, **kwargs): def stats(self, request: Request) -> Response: qs = self.filter_queryset(self.get_queryset()) return Response(qs.calculate_stats(), status=200) + + @extend_schema( + request=GroupedPriceStatsQuerySerializer, + responses=GroupedPriceStatsSerializer(many=True), + filters=True, + parameters=[ + OpenApiParameter( + name="group_by", + description="Field by which to group the statistics", + required=True, + type=str, + location=OpenApiParameter.QUERY, + ) + ], + ) + @action(detail=False, methods=["GET"]) + def grouped_stats(self, request: Request) -> Response: + qs = self.filter_queryset(self.get_queryset()) + + # Validate and parse query parameters using the serializer + serializer = GroupedPriceStatsQuerySerializer(data=request.query_params) + serializer.is_valid(raise_exception=True) + group_by = serializer.validated_data.get("group_by") + order_by = serializer.validated_data.get("order_by", None) + + try: + data = qs.calculate_grouped_stats(group_by, order_by) + except FieldError: + return Response( + {"detail": f"Invalid group_by field: {group_by}"}, + status=status.HTTP_400_BAD_REQUEST, + ) + + # Apply pagination + paginator = self.paginator # Use the default pagination class + paginated_data = paginator.paginate_queryset(data, request, view=self) + + return paginator.get_paginated_response(paginated_data) diff --git a/open_prices/prices/models.py b/open_prices/prices/models.py index a58e2d29..e811ceb1 100644 --- a/open_prices/prices/models.py +++ b/open_prices/prices/models.py @@ -3,8 +3,8 @@ from django.core.validators import MinValueValidator, ValidationError from django.db import models -from django.db.models import Avg, Count, F, Max, Min, signals -from django.db.models.functions import Cast +from django.db.models import Avg, Count, F, Max, Min, Sum, signals +from django.db.models.functions import Cast, TruncMonth, TruncWeek, TruncYear from django.dispatch import receiver from django.utils import timezone from openfoodfacts.taxonomy import ( @@ -54,6 +54,38 @@ def calculate_stats(self): Avg("price"), output_field=models.DecimalField(max_digits=10, decimal_places=2), ), + price__sum=Sum("price"), + ) + + def calculate_grouped_stats(self, group_by, order_by): + group_by_list = group_by.split(",") + if ( + "month" in group_by_list + or "year" in group_by_list + or "week" in group_by_list + ): + queryset = self.annotate( + month=TruncMonth("date"), year=TruncYear("date"), week=TruncWeek("date") + ) + else: + queryset = self + if order_by: + order_by_list = [order_by] + else: + order_by_list = group_by_list + return ( + queryset.values(*group_by_list) + .annotate( + price__count=Count("pk"), + price__min=Min("price"), + price__max=Max("price"), + price__avg=Cast( + Avg("price"), + output_field=models.DecimalField(max_digits=10, decimal_places=2), + ), + price__sum=Sum("price"), + ) + .order_by(*order_by_list) )