From c5eb962aa4d5273734b885bcb1508dd18872d98b Mon Sep 17 00:00:00 2001 From: Adrian Galvan Date: Mon, 13 Jan 2025 17:52:16 -0800 Subject: [PATCH] Opening new session to complete task activities --- src/fides/api/task/execute_request_tasks.py | 123 ++++++++++++++------ 1 file changed, 88 insertions(+), 35 deletions(-) diff --git a/src/fides/api/task/execute_request_tasks.py b/src/fides/api/task/execute_request_tasks.py index c8a6139866..674252e712 100644 --- a/src/fides/api/task/execute_request_tasks.py +++ b/src/fides/api/task/execute_request_tasks.py @@ -2,7 +2,15 @@ from celery.app.task import Task from loguru import logger +from sqlalchemy.exc import OperationalError from sqlalchemy.orm import Query, Session +from tenacity import ( + RetryCallState, + retry, + retry_if_exception_type, + stop_after_attempt, + wait_exponential, +) from fides.api.common_exceptions import ( PrivacyRequestCanceled, @@ -35,14 +43,13 @@ # DSR 3.0 task functions -def run_prerequisite_task_checks( +def get_privacy_request_and_task( session: Session, privacy_request_id: str, privacy_request_task_id: str -) -> Tuple[PrivacyRequest, RequestTask, Query]: +) -> Tuple[PrivacyRequest, RequestTask]: """ - Upfront checks that run as soon as the RequestTask is executed by the worker. - - Returns resources for use in executing a task + Retrieves and validates a privacy request and its associated task """ + privacy_request: Optional[PrivacyRequest] = PrivacyRequest.get( db=session, object_id=privacy_request_id ) @@ -65,6 +72,22 @@ def run_prerequisite_task_checks( f"Request Task with id {privacy_request_task_id} not found for privacy request {privacy_request_id}" ) + return privacy_request, request_task + + +def run_prerequisite_task_checks( + session: Session, privacy_request_id: str, privacy_request_task_id: str +) -> Tuple[PrivacyRequest, RequestTask, Query]: + """ + Upfront checks that run as soon as the RequestTask is executed by the worker. + + Returns resources for use in executing a task + """ + + privacy_request, request_task = get_privacy_request_and_task( + session, privacy_request_id, privacy_request_task_id + ) + assert request_task # For mypy upstream_results: Query = request_task.upstream_tasks_objects(session) @@ -146,6 +169,43 @@ def can_run_task_body( return True +def log_retry_attempt(retry_state: RetryCallState) -> None: + """Log queue_downstream_tasks retry attempts.""" + + logger.warning( + "queue_downstream_tasks attempt {} failed. Retrying in {} seconds...", + retry_state.attempt_number, + retry_state.next_action.sleep, # type: ignore[union-attr] + ) + + +@retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=1), + retry=retry_if_exception_type(OperationalError), + before_sleep=log_retry_attempt, +) +def queue_downstream_tasks_with_retries( + database_task: DatabaseTask, + privacy_request_id: str, + privacy_request_task_id: str, + current_step: CurrentStep, + privacy_request_proceed: bool, +) -> None: + with database_task.get_new_session() as session: + privacy_request, request_task = get_privacy_request_and_task( + session, privacy_request_id, privacy_request_task_id + ) + log_task_complete(request_task) + queue_downstream_tasks( + session, + request_task, + privacy_request, + current_step, + privacy_request_proceed, + ) + + def queue_downstream_tasks( session: Session, request_task: RequestTask, @@ -233,16 +293,15 @@ def run_access_node( ] # Run the main access function graph_task.access_request(*upstream_access_data) - log_task_complete(request_task) - - with self.get_new_session() as session: - queue_downstream_tasks( - session, - request_task, - privacy_request, - CurrentStep.upload_access, - privacy_request_proceed, - ) + logger.info(f"Session ID - After get access data: {id(session)}") + + queue_downstream_tasks_with_retries( + self, + privacy_request_id, + privacy_request_task_id, + CurrentStep.upload_access, + privacy_request_proceed, + ) @celery_app.task(base=DatabaseTask, bind=True) @@ -285,16 +344,13 @@ def run_erasure_node( # Run the main erasure function! graph_task.erasure_request(retrieved_data) - log_task_complete(request_task) - - with self.get_new_session() as session: - queue_downstream_tasks( - session, - request_task, - privacy_request, - CurrentStep.finalize_erasure, - privacy_request_proceed, - ) + queue_downstream_tasks_with_retries( + self, + privacy_request_id, + privacy_request_task_id, + CurrentStep.finalize_erasure, + privacy_request_proceed, + ) @celery_app.task(base=DatabaseTask, bind=True) @@ -339,16 +395,13 @@ def run_consent_node( graph_task.consent_request(access_data[0] if access_data else {}) - log_task_complete(request_task) - - with self.get_new_session() as session: - queue_downstream_tasks( - session, - request_task, - privacy_request, - CurrentStep.finalize_consent, - privacy_request_proceed, - ) + queue_downstream_tasks_with_retries( + self, + privacy_request_id, + privacy_request_task_id, + CurrentStep.finalize_consent, + privacy_request_proceed, + ) def logger_method(request_task: RequestTask) -> Callable: