Skip to content

Commit

Permalink
API Improvements: Fix typings for various files pt5 (#1087)
Browse files Browse the repository at this point in the history
  • Loading branch information
JerrySentry authored Jan 10, 2025
1 parent 8660eab commit 47a0933
Show file tree
Hide file tree
Showing 15 changed files with 131 additions and 87 deletions.
10 changes: 6 additions & 4 deletions api/internal/feature/views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
import pickle
from typing import Any, Dict, List

from rest_framework import status
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView
from shared.django_apps.rollouts.models import FeatureFlag
Expand All @@ -20,15 +22,15 @@ class FeaturesView(APIView):
skip_feature_cache = get_config("setup", "skip_feature_cache", default=False)
timeout = 300

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.redis = get_redis_connection()
super().__init__(*args, **kwargs)

def get_many_from_redis(self, keys):
def get_many_from_redis(self, keys: List) -> Dict[str, Any]:
ret = self.redis.mget(keys)
return {k: pickle.loads(v) for k, v in zip(keys, ret) if v is not None}

def set_many_to_redis(self, data):
def set_many_to_redis(self, data: Dict[str, Any]) -> None:
pipeline = self.redis.pipeline()
pipeline.mset({k: pickle.dumps(v) for k, v in data.items()})

Expand All @@ -38,7 +40,7 @@ def set_many_to_redis(self, data):
pipeline.expire(key, self.timeout)
pipeline.execute()

def post(self, request):
def post(self, request: Request) -> Response:
serializer = FeatureRequestSerializer(data=request.data)
if serializer.is_valid():
flag_evaluations = {}
Expand Down
39 changes: 20 additions & 19 deletions api/internal/owner/serializers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from datetime import datetime
from typing import Any, Dict

from dateutil.relativedelta import relativedelta
from django.conf import settings
Expand Down Expand Up @@ -37,7 +38,7 @@ class Meta:

read_only_fields = fields

def get_stats(self, obj):
def get_stats(self, obj: Owner) -> str | None:
if obj.cache and "stats" in obj.cache:
return obj.cache["stats"]

Expand All @@ -50,7 +51,7 @@ class StripeLineItemSerializer(serializers.Serializer):
plan_name = serializers.SerializerMethodField()
quantity = serializers.IntegerField()

def get_plan_name(self, line_item):
def get_plan_name(self, line_item: Dict[str, str]) -> str | None:
plan = line_item.get("plan")
if plan:
return plan.get("name")
Expand Down Expand Up @@ -85,7 +86,7 @@ class StripeDiscountSerializer(serializers.Serializer):
duration_in_months = serializers.IntegerField(source="coupon.duration_in_months")
expires = serializers.SerializerMethodField()

def get_expires(self, customer):
def get_expires(self, customer: Dict[str, Dict]) -> int | None:
coupon = customer.get("coupon")
if coupon:
months = coupon.get("duration_in_months")
Expand Down Expand Up @@ -121,7 +122,7 @@ class PlanSerializer(serializers.Serializer):
benefits = serializers.JSONField(read_only=True)
quantity = serializers.IntegerField(required=False)

def validate_value(self, value):
def validate_value(self, value: str) -> str:
current_org = self.context["view"].owner
current_owner = self.context["request"].current_owner

Expand All @@ -140,7 +141,7 @@ def validate_value(self, value):
)
return value

def validate(self, plan):
def validate(self, plan: Dict[str, Any]) -> Dict[str, Any]:
current_org = self.context["view"].owner
if current_org.account:
raise serializers.ValidationError(
Expand Down Expand Up @@ -206,7 +207,7 @@ class StripeScheduledPhaseSerializer(serializers.Serializer):
plan = serializers.SerializerMethodField()
quantity = serializers.SerializerMethodField()

def get_plan(self, phase):
def get_plan(self, phase: Dict[str, Any]) -> str:
plan_id = phase["items"][0]["plan"]
stripe_plan_dict = settings.STRIPE_PLAN_IDS
plan_name = list(stripe_plan_dict.keys())[
Expand All @@ -215,15 +216,15 @@ def get_plan(self, phase):
marketing_plan_name = PAID_PLANS[plan_name].billing_rate
return marketing_plan_name

def get_quantity(self, phase):
def get_quantity(self, phase: Dict[str, Any]) -> int:
return phase["items"][0]["quantity"]


class ScheduleDetailSerializer(serializers.Serializer):
id = serializers.CharField()
scheduled_phase = serializers.SerializerMethodField()

def get_scheduled_phase(self, schedule):
def get_scheduled_phase(self, schedule: Dict[str, Any]) -> Dict[str, Any] | None:
if len(schedule["phases"]) > 1:
return StripeScheduledPhaseSerializer(schedule["phases"][-1]).data
else:
Expand Down Expand Up @@ -291,44 +292,44 @@ class Meta:
"uses_invoice",
)

def _get_billing(self):
def _get_billing(self) -> BillingService:
current_owner = self.context["request"].current_owner
return BillingService(requesting_user=current_owner)

def get_subscription_detail(self, owner):
def get_subscription_detail(self, owner: Owner) -> Dict[str, Any] | None:
subscription_detail = self._get_billing().get_subscription(owner)
if subscription_detail:
return SubscriptionDetailSerializer(subscription_detail).data

def get_schedule_detail(self, owner):
def get_schedule_detail(self, owner: Owner) -> Dict[str, Any] | None:
schedule_detail = self._get_billing().get_schedule(owner)
if schedule_detail:
return ScheduleDetailSerializer(schedule_detail).data

def get_checkout_session_id(self, _):
def get_checkout_session_id(self, _: Any) -> str:
return self.context.get("checkout_session_id")

def get_activated_student_count(self, owner):
def get_activated_student_count(self, owner: Owner) -> int:
if owner.account:
return owner.account.activated_student_count
return owner.activated_student_count

def get_activated_user_count(self, owner):
def get_activated_user_count(self, owner: Owner) -> int:
if owner.account:
return owner.account.activated_user_count
return owner.activated_user_count

def get_delinquent(self, owner):
def get_delinquent(self, owner: Owner) -> bool:
if owner.account:
return owner.account.is_delinquent
return owner.delinquent

def get_uses_invoice(self, owner):
def get_uses_invoice(self, owner: Owner) -> bool:
if owner.account:
return owner.account.invoice_billing.filter(is_active=True).exists()
return owner.uses_invoice

def update(self, instance, validated_data):
def update(self, instance: Owner, validated_data: Dict[str, Any]) -> object:
if "pretty_plan" in validated_data:
desired_plan = validated_data.pop("pretty_plan")
checkout_session_id_or_none = self._get_billing().update_plan(
Expand Down Expand Up @@ -367,7 +368,7 @@ class Meta:
"last_pull_timestamp",
)

def update(self, instance, validated_data):
def update(self, instance: Owner, validated_data: Dict[str, Any]) -> object:
owner = self.context["view"].owner

if "activated" in validated_data:
Expand All @@ -391,7 +392,7 @@ def update(self, instance, validated_data):
# Re-fetch from DB to set activated and admin fields
return self.context["view"].get_object()

def get_last_pull_timestamp(self, obj):
def get_last_pull_timestamp(self, obj: Owner) -> str | None:
# this field comes from an annotation that may not always be applied to the queryset
if hasattr(obj, "last_pull_timestamp"):
return obj.last_pull_timestamp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def convert_yaml_to_dict(self, yaml_input: str) -> Optional[dict]:
message = f"Error at {str(e.error_location)}: {e.error_message}"
raise ValidationError(message)

def yaml_side_effects(self, old_yaml: dict, new_yaml: dict):
def yaml_side_effects(self, old_yaml: dict | None, new_yaml: dict | None) -> None:
old_yaml_branch = old_yaml and old_yaml.get("codecov", {}).get("branch")
new_yaml_branch = new_yaml and new_yaml.get("codecov", {}).get("branch")

Expand Down
3 changes: 2 additions & 1 deletion codecov_auth/management/commands/set_trial_status_values.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any

from django.core.management.base import BaseCommand, CommandParser
from django.db.models import Q
Expand All @@ -20,7 +21,7 @@ class Command(BaseCommand):
def add_arguments(self, parser: CommandParser) -> None:
parser.add_argument("trial_status_type", type=str)

def handle(self, *args, **options) -> None:
def handle(self, *args: Any, **options: Any) -> None:
trial_status_type = options.get("trial_status_type", {})

# NOT_STARTED
Expand Down
4 changes: 2 additions & 2 deletions core/commands/pull/interactors/fetch_pull_request.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime, timedelta

from shared.django_apps.core.models import Pull
from shared.django_apps.core.models import Pull, Repository

from codecov.commands.base import BaseInteractor
from codecov.db import sync_to_async
Expand All @@ -17,7 +17,7 @@ def _should_sync_pull(self, pull: Pull | None) -> bool:
)

@sync_to_async
def execute(self, repository, id):
def execute(self, repository: Repository, id: int) -> Pull:
pull = repository.pull_requests.filter(pullid=id).first()
if self._should_sync_pull(pull):
TaskService().pulls_sync(repository.repoid, id)
Expand Down
4 changes: 2 additions & 2 deletions core/commands/repository/interactors/erase_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class EraseRepositoryInteractor(BaseInteractor):
def validate_owner(self, owner: Owner):
def validate_owner(self, owner: Owner) -> None:
if not current_user_part_of_org(self.current_owner, owner):
raise Unauthorized()

Expand All @@ -23,7 +23,7 @@ def validate_owner(self, owner: Owner):
raise Unauthorized()

@sync_to_async
def execute(self, repo_name: str, owner: Owner):
def execute(self, repo_name: str, owner: Owner) -> None:
self.validate_owner(owner)
repo = Repository.objects.filter(author_id=owner.pk, name=repo_name).first()
if not repo:
Expand Down
2 changes: 1 addition & 1 deletion graphql_api/actions/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from timeseries.models import Interval, MeasurementName


def flags_for_repo(repository: Repository, filters: Mapping = None) -> QuerySet:
def flags_for_repo(repository: Repository, filters: Mapping = {}) -> QuerySet:
queryset = RepositoryFlag.objects.filter(
repository=repository,
deleted__isnot=True,
Expand Down
7 changes: 5 additions & 2 deletions graphql_api/types/branch/branch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional

from ariadne import ObjectType
from graphql import GraphQLResolveInfo

from core.models import Branch, Commit
from graphql_api.dataloader.commit import CommitLoader
Expand All @@ -9,13 +10,15 @@


@branch_bindable.field("headSha")
def resolve_head_sha(branch: Branch, info) -> str:
def resolve_head_sha(branch: Branch, info: GraphQLResolveInfo) -> str:
head = branch.head
return head


@branch_bindable.field("head")
async def resolve_head_commit(branch: Branch, info) -> Optional[Commit]:
async def resolve_head_commit(
branch: Branch, info: GraphQLResolveInfo
) -> Optional[Commit]:
head = branch.head
if head:
loader = CommitLoader.loader(info, branch.repository_id)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Any, Dict

from ariadne import UnionType
from graphql import GraphQLResolveInfo

from graphql_api.helpers.mutation import (
require_authenticated,
Expand All @@ -9,7 +12,9 @@

@wrap_error_handling_mutation
@require_authenticated
async def resolve_save_okta_config(_, info, input):
async def resolve_save_okta_config(
_: Any, info: GraphQLResolveInfo, input: Dict[str, Any]
) -> None:
command = info.context["executor"].get_command("owner")
return await command.save_okta_config(input)

Expand Down
5 changes: 4 additions & 1 deletion graphs/mixins.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Any

from django.http import HttpResponse
from rest_framework import status
from rest_framework.request import Request
from rest_framework.response import Response


class GraphBadgeAPIMixin(object):
def get(self, request, *args, **kwargs):
def get(self, request: Request, *args: Any, **kwargs: Any) -> Response:
ext = self.kwargs.get("ext")
if ext not in self.extensions:
return Response(
Expand Down
Loading

0 comments on commit 47a0933

Please sign in to comment.