Skip to content

Commit

Permalink
Merge branch 'main' into jan_06_cache_config
Browse files Browse the repository at this point in the history
  • Loading branch information
JerrySentry authored Jan 10, 2025
2 parents 98f08fc + a0c8267 commit 3ac41b8
Show file tree
Hide file tree
Showing 32 changed files with 257 additions and 166 deletions.
3 changes: 2 additions & 1 deletion api/internal/commit/serializers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Dict, List

import shared.reports.api_report_service as report_service
from rest_framework import serializers
Expand Down Expand Up @@ -32,7 +33,7 @@ class Meta:
class CommitWithFileLevelReportSerializer(CommitSerializer):
report = serializers.SerializerMethodField()

def get_report(self, commit: Commit):
def get_report(self, commit: Commit) -> Dict[str, List[Dict] | Dict] | None:
report = report_service.build_report_from_commit(commit)
if report is None:
return None
Expand Down
10 changes: 6 additions & 4 deletions api/internal/feature/views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
import pickle
from typing import Any, Dict, List

from rest_framework import status
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView
from shared.django_apps.rollouts.models import FeatureFlag
Expand All @@ -20,15 +22,15 @@ class FeaturesView(APIView):
skip_feature_cache = get_config("setup", "skip_feature_cache", default=False)
timeout = 300

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.redis = get_redis_connection()
super().__init__(*args, **kwargs)

def get_many_from_redis(self, keys):
def get_many_from_redis(self, keys: List) -> Dict[str, Any]:
ret = self.redis.mget(keys)
return {k: pickle.loads(v) for k, v in zip(keys, ret) if v is not None}

def set_many_to_redis(self, data):
def set_many_to_redis(self, data: Dict[str, Any]) -> None:
pipeline = self.redis.pipeline()
pipeline.mset({k: pickle.dumps(v) for k, v in data.items()})

Expand All @@ -38,7 +40,7 @@ def set_many_to_redis(self, data):
pipeline.expire(key, self.timeout)
pipeline.execute()

def post(self, request):
def post(self, request: Request) -> Response:
serializer = FeatureRequestSerializer(data=request.data)
if serializer.is_valid():
flag_evaluations = {}
Expand Down
41 changes: 21 additions & 20 deletions api/internal/owner/serializers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from datetime import datetime
from typing import Any, Dict

from dateutil.relativedelta import relativedelta
from django.conf import settings
Expand Down Expand Up @@ -37,7 +38,7 @@ class Meta:

read_only_fields = fields

def get_stats(self, obj):
def get_stats(self, obj: Owner) -> str | None:
if obj.cache and "stats" in obj.cache:
return obj.cache["stats"]

Expand All @@ -50,7 +51,7 @@ class StripeLineItemSerializer(serializers.Serializer):
plan_name = serializers.SerializerMethodField()
quantity = serializers.IntegerField()

def get_plan_name(self, line_item):
def get_plan_name(self, line_item: Dict[str, str]) -> str | None:
plan = line_item.get("plan")
if plan:
return plan.get("name")
Expand Down Expand Up @@ -85,7 +86,7 @@ class StripeDiscountSerializer(serializers.Serializer):
duration_in_months = serializers.IntegerField(source="coupon.duration_in_months")
expires = serializers.SerializerMethodField()

def get_expires(self, customer):
def get_expires(self, customer: Dict[str, Dict]) -> int | None:
coupon = customer.get("coupon")
if coupon:
months = coupon.get("duration_in_months")
Expand Down Expand Up @@ -121,7 +122,7 @@ class PlanSerializer(serializers.Serializer):
benefits = serializers.JSONField(read_only=True)
quantity = serializers.IntegerField(required=False)

def validate_value(self, value):
def validate_value(self, value: str) -> str:
current_org = self.context["view"].owner
current_owner = self.context["request"].current_owner

Expand All @@ -136,11 +137,11 @@ def validate_value(self, value):
extra=dict(owner_id=current_owner.pk, plan=value),
)
raise serializers.ValidationError(
f"Invalid value for plan: {value}; " f"must be one of {plan_values}"
f"Invalid value for plan: {value}; must be one of {plan_values}"
)
return value

def validate(self, plan):
def validate(self, plan: Dict[str, Any]) -> Dict[str, Any]:
current_org = self.context["view"].owner
if current_org.account:
raise serializers.ValidationError(
Expand Down Expand Up @@ -206,7 +207,7 @@ class StripeScheduledPhaseSerializer(serializers.Serializer):
plan = serializers.SerializerMethodField()
quantity = serializers.SerializerMethodField()

def get_plan(self, phase):
def get_plan(self, phase: Dict[str, Any]) -> str:
plan_id = phase["items"][0]["plan"]
stripe_plan_dict = settings.STRIPE_PLAN_IDS
plan_name = list(stripe_plan_dict.keys())[
Expand All @@ -215,15 +216,15 @@ def get_plan(self, phase):
marketing_plan_name = PAID_PLANS[plan_name].billing_rate
return marketing_plan_name

def get_quantity(self, phase):
def get_quantity(self, phase: Dict[str, Any]) -> int:
return phase["items"][0]["quantity"]


class ScheduleDetailSerializer(serializers.Serializer):
id = serializers.CharField()
scheduled_phase = serializers.SerializerMethodField()

def get_scheduled_phase(self, schedule):
def get_scheduled_phase(self, schedule: Dict[str, Any]) -> Dict[str, Any] | None:
if len(schedule["phases"]) > 1:
return StripeScheduledPhaseSerializer(schedule["phases"][-1]).data
else:
Expand Down Expand Up @@ -291,44 +292,44 @@ class Meta:
"uses_invoice",
)

def _get_billing(self):
def _get_billing(self) -> BillingService:
current_owner = self.context["request"].current_owner
return BillingService(requesting_user=current_owner)

def get_subscription_detail(self, owner):
def get_subscription_detail(self, owner: Owner) -> Dict[str, Any] | None:
subscription_detail = self._get_billing().get_subscription(owner)
if subscription_detail:
return SubscriptionDetailSerializer(subscription_detail).data

def get_schedule_detail(self, owner):
def get_schedule_detail(self, owner: Owner) -> Dict[str, Any] | None:
schedule_detail = self._get_billing().get_schedule(owner)
if schedule_detail:
return ScheduleDetailSerializer(schedule_detail).data

def get_checkout_session_id(self, _):
def get_checkout_session_id(self, _: Any) -> str:
return self.context.get("checkout_session_id")

def get_activated_student_count(self, owner):
def get_activated_student_count(self, owner: Owner) -> int:
if owner.account:
return owner.account.activated_student_count
return owner.activated_student_count

def get_activated_user_count(self, owner):
def get_activated_user_count(self, owner: Owner) -> int:
if owner.account:
return owner.account.activated_user_count
return owner.activated_user_count

def get_delinquent(self, owner):
def get_delinquent(self, owner: Owner) -> bool:
if owner.account:
return owner.account.is_delinquent
return owner.delinquent

def get_uses_invoice(self, owner):
def get_uses_invoice(self, owner: Owner) -> bool:
if owner.account:
return owner.account.invoice_billing.filter(is_active=True).exists()
return owner.uses_invoice

def update(self, instance, validated_data):
def update(self, instance: Owner, validated_data: Dict[str, Any]) -> object:
if "pretty_plan" in validated_data:
desired_plan = validated_data.pop("pretty_plan")
checkout_session_id_or_none = self._get_billing().update_plan(
Expand Down Expand Up @@ -367,7 +368,7 @@ class Meta:
"last_pull_timestamp",
)

def update(self, instance, validated_data):
def update(self, instance: Owner, validated_data: Dict[str, Any]) -> object:
owner = self.context["view"].owner

if "activated" in validated_data:
Expand All @@ -391,7 +392,7 @@ def update(self, instance, validated_data):
# Re-fetch from DB to set activated and admin fields
return self.context["view"].get_object()

def get_last_pull_timestamp(self, obj):
def get_last_pull_timestamp(self, obj: Owner) -> str | None:
# this field comes from an annotation that may not always be applied to the queryset
if hasattr(obj, "last_pull_timestamp"):
return obj.last_pull_timestamp
10 changes: 5 additions & 5 deletions codecov_auth/commands/owner/interactors/save_terms_agreement.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional
from typing import Any, Optional

from django.utils import timezone

Expand All @@ -20,7 +20,7 @@ class TermsAgreementInput:
class SaveTermsAgreementInteractor(BaseInteractor):
requires_service = False

def validate(self, input: TermsAgreementInput):
def validate(self, input: TermsAgreementInput) -> None:
valid_customer_intents = ["Business", "BUSINESS", "Personal", "PERSONAL"]
if (
input.customer_intent
Expand All @@ -30,7 +30,7 @@ def validate(self, input: TermsAgreementInput):
if not self.current_user.is_authenticated:
raise Unauthenticated()

def update_terms_agreement(self, input: TermsAgreementInput):
def update_terms_agreement(self, input: TermsAgreementInput) -> None:
self.current_user.terms_agreement = input.terms_agreement
self.current_user.terms_agreement_at = timezone.now()
self.current_user.customer_intent = input.customer_intent
Expand All @@ -44,14 +44,14 @@ def update_terms_agreement(self, input: TermsAgreementInput):
if input.marketing_consent:
self.send_data_to_marketo()

def send_data_to_marketo(self):
def send_data_to_marketo(self) -> None:
event_data = {
"email": self.current_user.email,
}
AnalyticsService().opt_in_email(self.current_user.id, event_data)

@sync_to_async
def execute(self, input):
def execute(self, input: Any) -> None:
typed_input = TermsAgreementInput(
business_email=input.get("business_email"),
terms_agreement=input.get("terms_agreement"),
Expand Down
6 changes: 4 additions & 2 deletions codecov_auth/commands/owner/interactors/set_yaml_on_owner.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,9 @@ def convert_yaml_to_dict(self, yaml_input: str) -> Optional[dict]:
except yaml.scanner.ScannerError as e:
line = e.problem_mark.line
column = e.problem_mark.column
message = f"Syntax error at line {line+1}, column {column+1}: {e.problem}"
message = (
f"Syntax error at line {line + 1}, column {column + 1}: {e.problem}"
)
raise ValidationError(message)
if not yaml_dict:
return None
Expand All @@ -52,7 +54,7 @@ def convert_yaml_to_dict(self, yaml_input: str) -> Optional[dict]:
message = f"Error at {str(e.error_location)}: {e.error_message}"
raise ValidationError(message)

def yaml_side_effects(self, old_yaml: dict, new_yaml: dict):
def yaml_side_effects(self, old_yaml: dict | None, new_yaml: dict | None) -> None:
old_yaml_branch = old_yaml and old_yaml.get("codecov", {}).get("branch")
new_yaml_branch = new_yaml and new_yaml.get("codecov", {}).get("branch")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def test_user_is_part_of_org_and_yaml_is_good(self):
"codecov": {
"require_ci_to_pass": True,
},
"to_string": "\n" "codecov:\n" " require_ci_to_pass: yes\n",
"to_string": "\ncodecov:\n require_ci_to_pass: yes\n",
}

async def test_user_is_part_of_org_and_yaml_has_quotes(self):
Expand All @@ -109,7 +109,7 @@ async def test_user_is_part_of_org_and_yaml_has_quotes(self):
"codecov": {
"bot": "codecov",
},
"to_string": "\n" "codecov:\n" " bot: 'codecov'\n",
"to_string": "\ncodecov:\n bot: 'codecov'\n",
}

async def test_user_is_part_of_org_and_yaml_is_empty(self):
Expand Down
3 changes: 2 additions & 1 deletion codecov_auth/management/commands/set_trial_status_values.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any

from django.core.management.base import BaseCommand, CommandParser
from django.db.models import Q
Expand All @@ -20,7 +21,7 @@ class Command(BaseCommand):
def add_arguments(self, parser: CommandParser) -> None:
parser.add_argument("trial_status_type", type=str)

def handle(self, *args, **options) -> None:
def handle(self, *args: Any, **options: Any) -> None:
trial_status_type = options.get("trial_status_type", {})

# NOT_STARTED
Expand Down
4 changes: 2 additions & 2 deletions codecov_auth/signals.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Type
from typing import Any, Dict, Optional, Type, cast

from django.db.models.signals import post_save
from django.dispatch import receiver
Expand Down Expand Up @@ -38,7 +38,7 @@ def update_owner(
"""
Shelter tracks a limited set of Owner fields - only update if those fields have changed.
"""
created: bool = kwargs["created"]
created: bool = cast(bool, kwargs["created"])
tracked_fields = [
"upload_token_required_for_public_repos",
"username",
Expand Down
4 changes: 2 additions & 2 deletions codecov_auth/views/sentry.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def _perform_login(self, request: HttpRequest) -> HttpResponse:
# user has not connected any owners yet
return redirect(f"{settings.CODECOV_DASHBOARD_URL}/sync")

def _login_user(self, request: HttpRequest, user_data: dict):
def _login_user(self, request: HttpRequest, user_data: dict) -> User:
sentry_id = user_data["user"]["id"]
user_name = user_data["user"].get("name")
user_email = user_data["user"].get("email")
Expand Down Expand Up @@ -177,7 +177,7 @@ def _login_user(self, request: HttpRequest, user_data: dict):
login(request, current_user)
return current_user

def get(self, request):
def get(self, request: HttpRequest) -> HttpResponse:
if request.GET.get("code"):
return self._perform_login(request)
else:
Expand Down
6 changes: 4 additions & 2 deletions core/commands/commit/interactors/get_file_content.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import logging
from typing import Any, Coroutine

from codecov.commands.base import BaseInteractor
from core.models import Commit
from services.repo_providers import RepoProviderService

log = logging.getLogger(__name__)


class GetFileContentInteractor(BaseInteractor):
async def get_file_from_service(self, commit, path):
async def get_file_from_service(self, commit: Commit, path: str) -> str | None:
try:
repository_service = await RepoProviderService().async_get_adapter(
owner=self.current_owner, repo=commit.repository
Expand All @@ -27,5 +29,5 @@ async def get_file_from_service(self, commit, path):
)
return None

def execute(self, commit, path):
def execute(self, commit: Commit, path: str) -> Coroutine[Any, Any, str | None]:
return self.get_file_from_service(commit, path)
4 changes: 2 additions & 2 deletions core/commands/pull/interactors/fetch_pull_request.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime, timedelta

from shared.django_apps.core.models import Pull
from shared.django_apps.core.models import Pull, Repository

from codecov.commands.base import BaseInteractor
from codecov.db import sync_to_async
Expand All @@ -17,7 +17,7 @@ def _should_sync_pull(self, pull: Pull | None) -> bool:
)

@sync_to_async
def execute(self, repository, id):
def execute(self, repository: Repository, id: int) -> Pull:
pull = repository.pull_requests.filter(pullid=id).first()
if self._should_sync_pull(pull):
TaskService().pulls_sync(repository.repoid, id)
Expand Down
4 changes: 2 additions & 2 deletions core/commands/repository/interactors/erase_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class EraseRepositoryInteractor(BaseInteractor):
def validate_owner(self, owner: Owner):
def validate_owner(self, owner: Owner) -> None:
if not current_user_part_of_org(self.current_owner, owner):
raise Unauthorized()

Expand All @@ -23,7 +23,7 @@ def validate_owner(self, owner: Owner):
raise Unauthorized()

@sync_to_async
def execute(self, repo_name: str, owner: Owner):
def execute(self, repo_name: str, owner: Owner) -> None:
self.validate_owner(owner)
repo = Repository.objects.filter(author_id=owner.pk, name=repo_name).first()
if not repo:
Expand Down
Loading

0 comments on commit 3ac41b8

Please sign in to comment.