From 47a0933a7641ea685ce21d9c6f9a0ef3ba6c8b62 Mon Sep 17 00:00:00 2001
From: JerrySentry <142266253+JerrySentry@users.noreply.github.com>
Date: Fri, 10 Jan 2025 12:50:45 -0500
Subject: [PATCH] API Improvements: Fix typings for various files pt5 (#1087)

---
 api/internal/feature/views.py                  |  10 +-
 api/internal/owner/serializers.py              |  39 +++----
 .../owner/interactors/set_yaml_on_owner.py     |   2 +-
 .../commands/set_trial_status_values.py        |   3 +-
 .../pull/interactors/fetch_pull_request.py     |   4 +-
 .../interactors/erase_repository.py            |   4 +-
 graphql_api/actions/flags.py                   |   2 +-
 graphql_api/types/branch/branch.py             |   7 +-
 .../save_okta_config/save_okta_config.py       |   7 +-
 graphs/mixins.py                               |   5 +-
 upload/helpers.py                              | 102 +++++++++++-------
 upload/tokenless/appveyor.py                   |   7 +-
 upload/tokenless/cirrus.py                     |   5 +-
 upload/views/base.py                           |   7 +-
 utils/test_utils.py                            |  14 +--
 15 files changed, 131 insertions(+), 87 deletions(-)

diff --git a/api/internal/feature/views.py b/api/internal/feature/views.py
index 2c08c01718..146a9be409 100644
--- a/api/internal/feature/views.py
+++ b/api/internal/feature/views.py
@@ -1,7 +1,9 @@
 import logging
 import pickle
+from typing import Any, Dict, List
 
 from rest_framework import status
+from rest_framework.request import Request
 from rest_framework.response import Response
 from rest_framework.views import APIView
 from shared.django_apps.rollouts.models import FeatureFlag
@@ -20,15 +22,15 @@ class FeaturesView(APIView):
     skip_feature_cache = get_config("setup", "skip_feature_cache", default=False)
     timeout = 300
 
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
         self.redis = get_redis_connection()
         super().__init__(*args, **kwargs)
 
-    def get_many_from_redis(self, keys):
+    def get_many_from_redis(self, keys: List) -> Dict[str, Any]:
         ret = self.redis.mget(keys)
         return {k: pickle.loads(v) for k, v in zip(keys, ret) if v is not None}
 
-    def set_many_to_redis(self, data):
+    def set_many_to_redis(self, data: Dict[str, Any]) -> None:
         pipeline = self.redis.pipeline()
         pipeline.mset({k: pickle.dumps(v) for k, v in data.items()})
 
@@ -38,7 +40,7 @@ def set_many_to_redis(self, data):
             pipeline.expire(key, self.timeout)
         pipeline.execute()
 
-    def post(self, request):
+    def post(self, request: Request) -> Response:
         serializer = FeatureRequestSerializer(data=request.data)
         if serializer.is_valid():
             flag_evaluations = {}
diff --git a/api/internal/owner/serializers.py b/api/internal/owner/serializers.py
index 01023194f2..3d21367a2c 100644
--- a/api/internal/owner/serializers.py
+++ b/api/internal/owner/serializers.py
@@ -1,5 +1,6 @@
 import logging
 from datetime import datetime
+from typing import Any, Dict
 
 from dateutil.relativedelta import relativedelta
 from django.conf import settings
@@ -37,7 +38,7 @@ class Meta:
 
         read_only_fields = fields
 
-    def get_stats(self, obj):
+    def get_stats(self, obj: Owner) -> str | None:
         if obj.cache and "stats" in obj.cache:
             return obj.cache["stats"]
 
@@ -50,7 +51,7 @@ class StripeLineItemSerializer(serializers.Serializer):
     plan_name = serializers.SerializerMethodField()
     quantity = serializers.IntegerField()
 
-    def get_plan_name(self, line_item):
+    def get_plan_name(self, line_item: Dict[str, str]) -> str | None:
         plan = line_item.get("plan")
         if plan:
             return plan.get("name")
@@ -85,7 +86,7 @@ class StripeDiscountSerializer(serializers.Serializer):
     duration_in_months = serializers.IntegerField(source="coupon.duration_in_months")
     expires = serializers.SerializerMethodField()
 
-    def get_expires(self, customer):
+    def get_expires(self, customer: Dict[str, Dict]) -> int | None:
         coupon = customer.get("coupon")
         if coupon:
             months = coupon.get("duration_in_months")
@@ -121,7 +122,7 @@ class PlanSerializer(serializers.Serializer):
     benefits = serializers.JSONField(read_only=True)
     quantity = serializers.IntegerField(required=False)
 
-    def validate_value(self, value):
+    def validate_value(self, value: str) -> str:
         current_org = self.context["view"].owner
         current_owner = self.context["request"].current_owner
 
@@ -140,7 +141,7 @@ def validate_value(self, value):
             )
         return value
 
-    def validate(self, plan):
+    def validate(self, plan: Dict[str, Any]) -> Dict[str, Any]:
         current_org = self.context["view"].owner
         if current_org.account:
             raise serializers.ValidationError(
@@ -206,7 +207,7 @@ class StripeScheduledPhaseSerializer(serializers.Serializer):
     plan = serializers.SerializerMethodField()
     quantity = serializers.SerializerMethodField()
 
-    def get_plan(self, phase):
+    def get_plan(self, phase: Dict[str, Any]) -> str:
         plan_id = phase["items"][0]["plan"]
         stripe_plan_dict = settings.STRIPE_PLAN_IDS
         plan_name = list(stripe_plan_dict.keys())[
@@ -215,7 +216,7 @@ def get_plan(self, phase):
         marketing_plan_name = PAID_PLANS[plan_name].billing_rate
         return marketing_plan_name
 
-    def get_quantity(self, phase):
+    def get_quantity(self, phase: Dict[str, Any]) -> int:
         return phase["items"][0]["quantity"]
 
 
@@ -223,7 +224,7 @@ class ScheduleDetailSerializer(serializers.Serializer):
     id = serializers.CharField()
     scheduled_phase = serializers.SerializerMethodField()
 
-    def get_scheduled_phase(self, schedule):
+    def get_scheduled_phase(self, schedule: Dict[str, Any]) -> Dict[str, Any] | None:
         if len(schedule["phases"]) > 1:
             return StripeScheduledPhaseSerializer(schedule["phases"][-1]).data
         else:
@@ -291,44 +292,44 @@ class Meta:
             "uses_invoice",
         )
 
-    def _get_billing(self):
+    def _get_billing(self) -> BillingService:
         current_owner = self.context["request"].current_owner
         return BillingService(requesting_user=current_owner)
 
-    def get_subscription_detail(self, owner):
+    def get_subscription_detail(self, owner: Owner) -> Dict[str, Any] | None:
         subscription_detail = self._get_billing().get_subscription(owner)
         if subscription_detail:
             return SubscriptionDetailSerializer(subscription_detail).data
 
-    def get_schedule_detail(self, owner):
+    def get_schedule_detail(self, owner: Owner) -> Dict[str, Any] | None:
         schedule_detail = self._get_billing().get_schedule(owner)
         if schedule_detail:
             return ScheduleDetailSerializer(schedule_detail).data
 
-    def get_checkout_session_id(self, _):
+    def get_checkout_session_id(self, _: Any) -> str:
         return self.context.get("checkout_session_id")
 
-    def get_activated_student_count(self, owner):
+    def get_activated_student_count(self, owner: Owner) -> int:
         if owner.account:
             return owner.account.activated_student_count
         return owner.activated_student_count
 
-    def get_activated_user_count(self, owner):
+    def get_activated_user_count(self, owner: Owner) -> int:
         if owner.account:
             return owner.account.activated_user_count
         return owner.activated_user_count
 
-    def get_delinquent(self, owner):
+    def get_delinquent(self, owner: Owner) -> bool:
         if owner.account:
             return owner.account.is_delinquent
         return owner.delinquent
 
-    def get_uses_invoice(self, owner):
+    def get_uses_invoice(self, owner: Owner) -> bool:
         if owner.account:
             return owner.account.invoice_billing.filter(is_active=True).exists()
         return owner.uses_invoice
 
-    def update(self, instance, validated_data):
+    def update(self, instance: Owner, validated_data: Dict[str, Any]) -> object:
         if "pretty_plan" in validated_data:
             desired_plan = validated_data.pop("pretty_plan")
             checkout_session_id_or_none = self._get_billing().update_plan(
@@ -367,7 +368,7 @@ class Meta:
             "last_pull_timestamp",
         )
 
-    def update(self, instance, validated_data):
+    def update(self, instance: Owner, validated_data: Dict[str, Any]) -> object:
         owner = self.context["view"].owner
 
         if "activated" in validated_data:
@@ -391,7 +392,7 @@ def update(self, instance, validated_data):
         # Re-fetch from DB to set activated and admin fields
         return self.context["view"].get_object()
 
-    def get_last_pull_timestamp(self, obj):
+    def get_last_pull_timestamp(self, obj: Owner) -> str | None:
         # this field comes from an annotation that may not always be applied to the queryset
         if hasattr(obj, "last_pull_timestamp"):
             return obj.last_pull_timestamp
diff --git a/codecov_auth/commands/owner/interactors/set_yaml_on_owner.py b/codecov_auth/commands/owner/interactors/set_yaml_on_owner.py
index 236ecf604e..566d003936 100644
--- a/codecov_auth/commands/owner/interactors/set_yaml_on_owner.py
+++ b/codecov_auth/commands/owner/interactors/set_yaml_on_owner.py
@@ -54,7 +54,7 @@ def convert_yaml_to_dict(self, yaml_input: str) -> Optional[dict]:
             message = f"Error at {str(e.error_location)}: {e.error_message}"
             raise ValidationError(message)
 
-    def yaml_side_effects(self, old_yaml: dict, new_yaml: dict):
+    def yaml_side_effects(self, old_yaml: dict | None, new_yaml: dict | None) -> None:
         old_yaml_branch = old_yaml and old_yaml.get("codecov", {}).get("branch")
         new_yaml_branch = new_yaml and new_yaml.get("codecov", {}).get("branch")
 
diff --git a/codecov_auth/management/commands/set_trial_status_values.py b/codecov_auth/management/commands/set_trial_status_values.py
index 46f3e29d63..7872bbce57 100644
--- a/codecov_auth/management/commands/set_trial_status_values.py
+++ b/codecov_auth/management/commands/set_trial_status_values.py
@@ -1,4 +1,5 @@
 from datetime import datetime
+from typing import Any
 
 from django.core.management.base import BaseCommand, CommandParser
 from django.db.models import Q
@@ -20,7 +21,7 @@ class Command(BaseCommand):
     def add_arguments(self, parser: CommandParser) -> None:
         parser.add_argument("trial_status_type", type=str)
 
-    def handle(self, *args, **options) -> None:
+    def handle(self, *args: Any, **options: Any) -> None:
         trial_status_type = options.get("trial_status_type", {})
 
         # NOT_STARTED
diff --git a/core/commands/pull/interactors/fetch_pull_request.py b/core/commands/pull/interactors/fetch_pull_request.py
index deb6824e2c..a5f0dafc85 100644
--- a/core/commands/pull/interactors/fetch_pull_request.py
+++ b/core/commands/pull/interactors/fetch_pull_request.py
@@ -1,6 +1,6 @@
 from datetime import datetime, timedelta
 
-from shared.django_apps.core.models import Pull
+from shared.django_apps.core.models import Pull, Repository
 
 from codecov.commands.base import BaseInteractor
 from codecov.db import sync_to_async
@@ -17,7 +17,7 @@ def _should_sync_pull(self, pull: Pull | None) -> bool:
         )
 
     @sync_to_async
-    def execute(self, repository, id):
+    def execute(self, repository: Repository, id: int) -> Pull:
         pull = repository.pull_requests.filter(pullid=id).first()
         if self._should_sync_pull(pull):
             TaskService().pulls_sync(repository.repoid, id)
diff --git a/core/commands/repository/interactors/erase_repository.py b/core/commands/repository/interactors/erase_repository.py
index 332051ce6d..5a3423807a 100644
--- a/core/commands/repository/interactors/erase_repository.py
+++ b/core/commands/repository/interactors/erase_repository.py
@@ -11,7 +11,7 @@
 
 
 class EraseRepositoryInteractor(BaseInteractor):
-    def validate_owner(self, owner: Owner):
+    def validate_owner(self, owner: Owner) -> None:
         if not current_user_part_of_org(self.current_owner, owner):
             raise Unauthorized()
 
@@ -23,7 +23,7 @@ def validate_owner(self, owner: Owner):
             raise Unauthorized()
 
     @sync_to_async
-    def execute(self, repo_name: str, owner: Owner):
+    def execute(self, repo_name: str, owner: Owner) -> None:
         self.validate_owner(owner)
         repo = Repository.objects.filter(author_id=owner.pk, name=repo_name).first()
         if not repo:
diff --git a/graphql_api/actions/flags.py b/graphql_api/actions/flags.py
index 92c18f32ef..de7f18981e 100644
--- a/graphql_api/actions/flags.py
+++ b/graphql_api/actions/flags.py
@@ -10,7 +10,7 @@
 from timeseries.models import Interval, MeasurementName
 
 
-def flags_for_repo(repository: Repository, filters: Mapping = None) -> QuerySet:
+def flags_for_repo(repository: Repository, filters: Mapping = {}) -> QuerySet:
     queryset = RepositoryFlag.objects.filter(
         repository=repository,
         deleted__isnot=True,
diff --git a/graphql_api/types/branch/branch.py b/graphql_api/types/branch/branch.py
index 5d76428d36..c396cdab65 100644
--- a/graphql_api/types/branch/branch.py
+++ b/graphql_api/types/branch/branch.py
@@ -1,6 +1,7 @@
 from typing import Optional
 
 from ariadne import ObjectType
+from graphql import GraphQLResolveInfo
 
 from core.models import Branch, Commit
 from graphql_api.dataloader.commit import CommitLoader
@@ -9,13 +10,15 @@
 
 
 @branch_bindable.field("headSha")
-def resolve_head_sha(branch: Branch, info) -> str:
+def resolve_head_sha(branch: Branch, info: GraphQLResolveInfo) -> str:
     head = branch.head
     return head
 
 
 @branch_bindable.field("head")
-async def resolve_head_commit(branch: Branch, info) -> Optional[Commit]:
+async def resolve_head_commit(
+    branch: Branch, info: GraphQLResolveInfo
+) -> Optional[Commit]:
     head = branch.head
     if head:
         loader = CommitLoader.loader(info, branch.repository_id)
diff --git a/graphql_api/types/mutation/save_okta_config/save_okta_config.py b/graphql_api/types/mutation/save_okta_config/save_okta_config.py
index 3cbd5708a4..8413f40667 100644
--- a/graphql_api/types/mutation/save_okta_config/save_okta_config.py
+++ b/graphql_api/types/mutation/save_okta_config/save_okta_config.py
@@ -1,4 +1,7 @@
+from typing import Any, Dict
+
 from ariadne import UnionType
+from graphql import GraphQLResolveInfo
 
 from graphql_api.helpers.mutation import (
     require_authenticated,
@@ -9,7 +12,9 @@
 
 @wrap_error_handling_mutation
 @require_authenticated
-async def resolve_save_okta_config(_, info, input):
+async def resolve_save_okta_config(
+    _: Any, info: GraphQLResolveInfo, input: Dict[str, Any]
+) -> None:
     command = info.context["executor"].get_command("owner")
     return await command.save_okta_config(input)
 
diff --git a/graphs/mixins.py b/graphs/mixins.py
index 5e7613ebfb..e1b74563dc 100644
--- a/graphs/mixins.py
+++ b/graphs/mixins.py
@@ -1,10 +1,13 @@
+from typing import Any
+
 from django.http import HttpResponse
 from rest_framework import status
+from rest_framework.request import Request
 from rest_framework.response import Response
 
 
 class GraphBadgeAPIMixin(object):
-    def get(self, request, *args, **kwargs):
+    def get(self, request: Request, *args: Any, **kwargs: Any) -> Response:
         ext = self.kwargs.get("ext")
         if ext not in self.extensions:
             return Response(
diff --git a/upload/helpers.py b/upload/helpers.py
index 428ffc9510..bccd92e034 100644
--- a/upload/helpers.py
+++ b/upload/helpers.py
@@ -1,7 +1,7 @@
 import logging
 import re
 from json import dumps
-from typing import Optional
+from typing import Any, Dict, Optional
 
 import jwt
 from asgiref.sync import async_to_sync
@@ -9,14 +9,18 @@
 from django.conf import settings
 from django.core.exceptions import ObjectDoesNotExist
 from django.db.models import Q
+from django.http import HttpRequest
 from django.utils import timezone
 from jwt import PyJWKClient, PyJWTError
+from redis import Redis
 from rest_framework.exceptions import NotFound, Throttled, ValidationError
 from shared.github import InvalidInstallationError
 from shared.plan.constants import USER_PLAN_REPRESENTATIONS
 from shared.plan.service import PlanService
 from shared.reports.enums import UploadType
+from shared.torngit.base import TorngitBaseAdapter
 from shared.torngit.exceptions import TorngitClientError, TorngitObjectNotFoundError
+from shared.typings.oauth_token_types import OauthConsumerToken
 from shared.upload.utils import query_monthly_coverage_measurements
 
 from codecov_auth.models import (
@@ -51,7 +55,7 @@
 redis = get_redis_connection()
 
 
-def parse_params(data):
+def parse_params(data: Dict[str, Any]) -> Dict[str, Any]:
     """
     This function will validate the input request parameters and do some additional parsing/tranformation of the params.
     """
@@ -230,7 +234,7 @@ def parse_params(data):
     return v.document
 
 
-def get_repo_with_github_actions_oidc_token(token):
+def get_repo_with_github_actions_oidc_token(token: str) -> Repository:
     unverified_contents = jwt.decode(token, options={"verify_signature": False})
     token_issuer = str(unverified_contents.get("iss"))
     if token_issuer == "https://token.actions.githubusercontent.com":
@@ -259,7 +263,7 @@ def get_repo_with_github_actions_oidc_token(token):
     return repository
 
 
-def determine_repo_for_upload(upload_params):
+def determine_repo_for_upload(upload_params: Dict[str, Any]) -> Repository:
     token = upload_params.get("token")
     using_global_token = upload_params.get("using_global_token")
     service = upload_params.get("service")
@@ -309,7 +313,9 @@ def determine_repo_for_upload(upload_params):
     """
 
 
-def determine_upload_branch_to_use(upload_params, repo_default_branch):
+def determine_upload_branch_to_use(
+    upload_params: Dict[str, Any], repo_default_branch: str
+) -> str | None:
     """
     Do processing on the upload request parameters to determine which branch to use for the upload:
     - If no branch or PR were provided, use the default branch for the repository.
@@ -330,7 +336,7 @@ def determine_upload_branch_to_use(upload_params, repo_default_branch):
     return None
 
 
-def determine_upload_pr_to_use(upload_params):
+def determine_upload_pr_to_use(upload_params: Dict[str, Any]) -> str | None:
     """
     Do processing on the upload request parameters to determine which PR to use for the upload:
     - If a branch was provided and the branch name contains "pull" or "pr" followed by digits, extract the digits and use that as the PR number.
@@ -370,7 +376,9 @@ def ghapp_installation_id_to_use(repository: Repository) -> Optional[str]:
     return repository.author.integration_id
 
 
-def try_to_get_best_possible_bot_token(repository):
+def try_to_get_best_possible_bot_token(
+    repository: Repository,
+) -> OauthConsumerToken | Dict:
     ghapp_installation_id = ghapp_installation_id_to_use(repository)
     if ghapp_installation_id is not None:
         try:
@@ -424,11 +432,15 @@ def try_to_get_best_possible_bot_token(repository):
 
 
 @async_to_sync
-async def _get_git_commit_data(adapter, commit, token):
+async def _get_git_commit_data(
+    adapter: TorngitBaseAdapter, commit: str, token: Optional[OauthConsumerToken | Dict]
+) -> Dict[str, Any]:
     return await adapter.get_commit(commit, token)
 
 
-def determine_upload_commit_to_use(upload_params, repository):
+def determine_upload_commit_to_use(
+    upload_params: Dict[str, Any], repository: Repository
+) -> str:
     """
     Do processing on the upload request parameters to determine which commit to use for the upload:
     - If this is a merge commit on github, use the first commit SHA in the merge commit message.
@@ -437,31 +449,28 @@ def determine_upload_commit_to_use(upload_params, repository):
     # Check if this is a merge commit and, if so, use the commitid of the commit being merged into per the merge commit message.
     # See https://docs.codecov.io/docs/merge-commits for more context.
     service = repository.author.service
+    commitid = upload_params.get("commit", "")
     if service.startswith("github") and not upload_params.get(
         "_did_change_merge_commit"
    ):
         token = try_to_get_best_possible_bot_token(repository)
         if token is None:
-            return upload_params.get("commit")
+            return commitid
         # Get the commit message from the git provider and check if it's structured like a merge commit message
         try:
             adapter = RepoProviderService().get_adapter(
                 repository.author, repository, use_ssl=True, token=token
             )
-            git_commit_data = _get_git_commit_data(
-                adapter, upload_params.get("commit"), token
-            )
+            git_commit_data = _get_git_commit_data(adapter, commitid, token)
         except TorngitObjectNotFoundError:
             log.warning(
                 "Unable to fetch commit. Not found",
Not found", - extra=dict(commit=upload_params.get("commit")), + extra=dict(commit=commitid), ) - return upload_params.get("commit") + return commitid except TorngitClientError: - log.warning( - "Unable to fetch commit", extra=dict(commit=upload_params.get("commit")) - ) - return upload_params.get("commit") + log.warning("Unable to fetch commit", extra=dict(commit=commitid)) + return commitid git_commit_message = git_commit_data.get("message", "").strip() is_merge_commit = re.match(r"^Merge\s\w{40}\sinto\s\w{40}$", git_commit_message) @@ -472,7 +481,7 @@ def determine_upload_commit_to_use(upload_params, repository): log.info( "Upload is for a merge commit, updating commit id for upload", extra=dict( - commit=upload_params.get("commit"), + commit=commitid, commit_message=git_commit_message, new_commit=new_commit_id, ), @@ -480,10 +489,17 @@ def determine_upload_commit_to_use(upload_params, repository): return new_commit_id # If it's not a merge commit we'll just use the commitid provided in the upload parameters - return upload_params.get("commit") + return commitid -def insert_commit(commitid, branch, pr, repository, owner, parent_commit_id=None): +def insert_commit( + commitid: str, + branch: str, + pr: int, + repository: Repository, + owner: Owner, + parent_commit_id: Optional[str] = None, +) -> Commit: commit, was_created = Commit.objects.defer("_report").get_or_create( commitid=commitid, repository=repository, @@ -509,7 +525,7 @@ def insert_commit(commitid, branch, pr, repository, owner, parent_commit_id=None return commit -def get_global_tokens(): +def get_global_tokens() -> Dict[str, Any]: """ Enterprise only: check the config to see if global tokens were set for this organization's uploads. @@ -523,7 +539,7 @@ def get_global_tokens(): return tokens -def check_commit_upload_constraints(commit: Commit): +def check_commit_upload_constraints(commit: Commit) -> None: if settings.UPLOAD_THROTTLING_ENABLED and commit.repository.private: owner = _determine_responsible_owner(commit.repository) plan_service = PlanService(current_org=owner) @@ -545,7 +561,9 @@ def check_commit_upload_constraints(commit: Commit): raise Throttled(detail=message) -def validate_upload(upload_params, repository, redis): +def validate_upload( + upload_params: Dict[str, Any], repository: Repository, redis: Redis +) -> None: """ Make sure the upload can proceed and, if so, activate the repository if needed. 
""" @@ -645,7 +663,7 @@ def validate_upload(upload_params, repository, redis): ) -def _determine_responsible_owner(repository): +def _determine_responsible_owner(repository: Repository) -> Owner: owner = repository.author if owner.service == "gitlab": @@ -657,7 +675,9 @@ def _determine_responsible_owner(repository): return owner -def parse_headers(headers, upload_params): +def parse_headers( + headers: Dict[str, Any], upload_params: Dict[str, Any] +) -> Dict[str, Any]: version = upload_params.get("version") # Content disposition header @@ -671,8 +691,8 @@ def parse_headers(headers, upload_params): else: content_type = ( "text/plain" - if headers.get("X_Content_Type") in (None, "text/html") - else headers.get("X_Content_Type") + if headers.get("X_Content_Type", "") in ("", "text/html") + else headers.get("X_Content_Type", "") ) reduced_redundancy = ( False @@ -688,11 +708,11 @@ def parse_headers(headers, upload_params): def dispatch_upload_task( - task_arguments, - repository, - redis, - report_type=CommitReport.ReportType.COVERAGE, -): + task_arguments: Dict[str, Any], + repository: Repository, + redis: Redis, + report_type: Optional[CommitReport.ReportType] = CommitReport.ReportType.COVERAGE, +) -> None: # Store task arguments in redis cache_uploads_eta = get_config(("setup", "cache", "uploads"), default=86400) if report_type == CommitReport.ReportType.COVERAGE: @@ -741,7 +761,7 @@ def dispatch_upload_task( ) -def validate_activated_repo(repository): +def validate_activated_repo(repository: Repository) -> None: if repository.active and not repository.activated: config_url = f"{settings.CODECOV_DASHBOARD_URL}/{repository.author.service}/{repository.author.username}/{repository.name}/config/general" raise ValidationError( @@ -750,7 +770,7 @@ def validate_activated_repo(repository): # headers["User-Agent"] should look something like this: codecov-cli/0.4.7 or codecov-uploader/0.7.1 -def get_agent_from_headers(headers): +def get_agent_from_headers(headers: Dict[str, Any]) -> str: try: return headers["User-Agent"].split("/")[0].split("-")[1] except Exception as e: @@ -763,7 +783,7 @@ def get_agent_from_headers(headers): return "unknown-user-agent" -def get_version_from_headers(headers): +def get_version_from_headers(headers: Dict[str, Any]) -> str: try: return headers["User-Agent"].split("/")[1] except Exception as e: @@ -777,15 +797,15 @@ def get_version_from_headers(headers): def generate_upload_prometheus_metrics_labels( - action, - request, - is_shelter_request, + action: str, + request: HttpRequest, + is_shelter_request: bool, endpoint: Optional[str] = None, repository: Optional[Repository] = None, position: Optional[str] = None, upload_version: Optional[str] = None, include_empty_labels: bool = True, -): +) -> Dict[str, Any]: metrics_tags = dict( agent=get_agent_from_headers(request.headers), version=get_version_from_headers(request.headers), diff --git a/upload/tokenless/appveyor.py b/upload/tokenless/appveyor.py index 84893cc8b1..9b7d56855f 100644 --- a/upload/tokenless/appveyor.py +++ b/upload/tokenless/appveyor.py @@ -1,4 +1,5 @@ import logging +from typing import Any, Dict import requests from requests.exceptions import ConnectionError, HTTPError @@ -10,7 +11,7 @@ class TokenlessAppveyorHandler(BaseTokenlessUploadHandler): - def get_build(self): + def get_build(self) -> Dict[str, Any]: try: build = requests.get( "https://ci.appveyor.com/api/projects/{}/{}/build/{}".format( @@ -39,7 +40,7 @@ def get_build(self): return build.json() - def verify(self): + def verify(self) -> str: if 
         if not self.upload_params.get("job"):
             raise NotFound(
                 'Missing "job" argument. Please upload with the Codecov repository upload token to resolve issue.'
             )
@@ -60,7 +61,7 @@ def verify(self):
         # validate build
         if not any(
             filter(
-                lambda j: j["jobId"] == self.upload_params.get("build")
+                lambda j: j["jobId"] == self.upload_params.get("build", "")  # type: ignore
                 and j.get("finished") is None,
                 build["build"]["jobs"],
             )
diff --git a/upload/tokenless/cirrus.py b/upload/tokenless/cirrus.py
index 1661cb8f8a..9dc5350225 100644
--- a/upload/tokenless/cirrus.py
+++ b/upload/tokenless/cirrus.py
@@ -1,5 +1,6 @@
 import logging
 import time
+from typing import Any, Dict
 
 import requests
 from requests.exceptions import ConnectionError, HTTPError
@@ -11,7 +12,7 @@
 
 
 class TokenlessCirrusHandler(BaseTokenlessUploadHandler):
-    def get_build(self):
+    def get_build(self) -> Dict[str, Any]:
         query = f"""{{
             "query": "query ($buildId: ID!) {{
                 build(id: $buildId) {{
@@ -74,7 +75,7 @@ def get_build(self):
         return build
 
-    def verify(self):
+    def verify(self) -> str:
         if not self.upload_params.get("owner"):
             raise NotFound(
                 'Missing "owner" argument. Please upload with the Codecov repository upload token to resolve this issue.'
             )
diff --git a/upload/views/base.py b/upload/views/base.py
index 57facb7f60..f250379f32 100644
--- a/upload/views/base.py
+++ b/upload/views/base.py
@@ -1,4 +1,5 @@
 import logging
+from typing import Optional
 
 from django.conf import settings
 from rest_framework.exceptions import ValidationError
@@ -60,7 +61,11 @@ def get_commit(self, repo: Repository) -> Commit:
         raise ValidationError("Commit SHA not found")
 
     def get_report(
-        self, commit: Commit, report_type=CommitReport.ReportType.COVERAGE
+        self,
+        commit: Commit,
+        report_type: Optional[
+            CommitReport.ReportType
+        ] = CommitReport.ReportType.COVERAGE,
     ) -> CommitReport:
         report_code = self.kwargs.get("report_code")
         if report_code == "default":
diff --git a/utils/test_utils.py b/utils/test_utils.py
index 43893a7474..6cac27e04e 100644
--- a/utils/test_utils.py
+++ b/utils/test_utils.py
@@ -1,3 +1,5 @@
+from typing import Any
+
 from django.apps import apps
 from django.db import connection
 from django.db.migrations.executor import MigrationExecutor
@@ -13,17 +15,17 @@ class BaseTestCase(object):
 
 
 class ClientMixin:
-    def force_login_owner(self, owner: Owner):
+    def force_login_owner(self, owner: Owner) -> None:
         self.force_login(user=owner.user)
         session = self.session
         session["current_owner_id"] = owner.pk
         session.save()
 
-    def logout(self):
+    def logout(self) -> None:
         session = self.session
         session["current_owner_id"] = None
         session.save()
-        super().logout()
+        super().logout()  # type: ignore
 
 
 class Client(ClientMixin, DjangoClient):
@@ -36,13 +38,13 @@ class APIClient(ClientMixin, DjangoAPIClient):
 
 class TestMigrations(TestCase):
     @property
-    def app(self):
+    def app(self) -> str:
         return apps.get_containing_app_config(type(self).__module__).name
 
     migrate_from = None
     migrate_to = None
 
-    def setUp(self):
+    def setUp(self) -> None:
         assert self.migrate_from and self.migrate_to, (
             "TestCase '{}' must define migrate_from and migrate_to properties".format(
                 type(self).__name__
@@ -65,5 +67,5 @@ def setUp(self):
 
         self.apps = executor.loader.project_state(self.migrate_to).apps
 
-    def setUpBeforeMigration(self, apps):
+    def setUpBeforeMigration(self, apps: Any) -> None:
         pass