feat: PushJobAgent can handle special JobWrappers
aldbr committed Dec 22, 2023
1 parent 9f510fe commit 3482261
Showing 7 changed files with 689 additions and 246 deletions.
268 changes: 246 additions & 22 deletions src/DIRAC/WorkloadManagementSystem/Agent/PushJobAgent.py
@@ -9,21 +9,35 @@
"""

import hashlib
import json
import os
import random
import sys
from collections import defaultdict
import time

from DIRAC import S_OK, S_ERROR, gConfig
from DIRAC.Core.Security.X509Chain import X509Chain
from DIRAC.Core.Utilities.ObjectLoader import ObjectLoader
from DIRAC.Core.Utilities import DErrno
from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations
from DIRAC.ConfigurationSystem.Client.Helpers.Registry import getDNForUsername
from DIRAC.FrameworkSystem.Client.ProxyManagerClient import gProxyManager
from DIRAC.FrameworkSystem.private.standardLogging.LogLevels import LogLevel
from DIRAC.RequestManagementSystem.Client.Request import Request
from DIRAC.WorkloadManagementSystem.Client import JobStatus
from DIRAC.Resources.Computing.ComputingElement import ComputingElement
from DIRAC.WorkloadManagementSystem.Client import JobStatus, PilotStatus
from DIRAC.WorkloadManagementSystem.JobWrapper.JobWrapperUtilities import (
getJobWrapper,
resolveInputData,
transferInputSandbox,
)
from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import getQueuesResolved
from DIRAC.WorkloadManagementSystem.Service.WMSUtilities import getGridEnv
from DIRAC.WorkloadManagementSystem.Agent.JobAgent import JobAgent
from DIRAC.WorkloadManagementSystem.Utilities.Utils import createLightJobWrapper
from DIRAC.WorkloadManagementSystem.private.ConfigHelper import findGenericPilotCredentials

MAX_JOBS_MANAGED = 100
@@ -50,10 +64,24 @@ def __init__(self, agentName, loadName, baseAgentName=False, properties=None):
self.failedQueueCycleFactor = 10
self.failedQueues = defaultdict(int)

        # Choose the submission policy
        # - Workflow: the agent submits a workflow to a PoolCE; the workflow is responsible for interacting with the remote site
        # - JobWrapper: the agent submits a JobWrapper directly to the remote site, which is responsible for the remote execution
        self.submissionPolicy = "Workflow"
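        # Illustrative CS snippet (assumed layout) for selecting the policy in this
        # agent's configuration section:
        #   PushJobAgent
        #   {
        #     SubmissionPolicy = JobWrapper
        #   }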

        # cleanTask is used to clean up the task on the remote site
self.cleanTask = True

def initialize(self):
"""Sets default parameters and creates CE instance"""
super().initialize()

# Get the submission policy
        # Initialized here because it cannot be modified dynamically while the agent is running
self.submissionPolicy = self.am_getOption("SubmissionPolicy", self.submissionPolicy)
if self.submissionPolicy not in ["Workflow", "JobWrapper"]:
return S_ERROR("SubmissionPolicy must be either Workflow or JobWrapper")

result = self._initializeComputingElement("Pool")
if not result["OK"]:
return result
@@ -87,6 +115,7 @@ def beginExecution(self):
self.computingElement.setParameters({"NumberOfProcessors": self.maxJobsToSubmit})

self.failedQueueCycleFactor = self.am_getOption("FailedQueueCycleFactor", self.failedQueueCycleFactor)
self.cleanTask = self.am_getOption("CleanTask", self.cleanTask)

# Get target queues from the configuration
siteNames = None
@@ -137,11 +166,12 @@ def execute(self):
if not result["OK"] or result["Value"]:
return result

# Check errors that could have occurred during job submission and/or execution
        # Statuses are handled internally and are therefore not checked outside of the method
result = self._checkSubmittedJobs()
if not result["OK"]:
return result
if self.submissionPolicy == "Workflow":
# Check errors that could have occurred during job submission and/or execution
            # Statuses are handled internally and are therefore not checked outside of the method
result = self._checkSubmittedJobs()
if not result["OK"]:
return result

for queueName, queueDictionary in queueDictItems:
# Make sure there is no problem with the queue before trying to submit
@@ -251,21 +281,41 @@

# Submit the job to the CE
self.log.debug(f"Before self._submitJob() ({self.ceName}CE)")
resultSubmission = self._submitJob(
jobID=jobID,
jobParams=params,
resourceParams=ceDict,
optimizerParams=optimizerParams,
proxyChain=proxyChain,
processors=submissionParams["processors"],
wholeNode=submissionParams["wholeNode"],
maxNumberOfProcessors=submissionParams["maxNumberOfProcessors"],
mpTag=submissionParams["mpTag"],
)
            if not resultSubmission["OK"]:
result = self._rescheduleFailedJob(jobID, resultSubmission["Message"])
self.failedQueues[queueName] += 1
break
if self.submissionPolicy == "Workflow":
resultSubmission = self._submitJob(
jobID=jobID,
jobParams=params,
resourceParams=ceDict,
optimizerParams=optimizerParams,
proxyChain=proxyChain,
processors=submissionParams["processors"],
wholeNode=submissionParams["wholeNode"],
maxNumberOfProcessors=submissionParams["maxNumberOfProcessors"],
mpTag=submissionParams["mpTag"],
)
                if not resultSubmission["OK"]:
result = self._rescheduleFailedJob(jobID, resultSubmission["Message"])
self.failedQueues[queueName] += 1
break
else:
resultSubmission = self._submitJobWrapper(
jobID=jobID,
ce=ce,
jobParams=params,
resourceParams=ceDict,
optimizerParams=optimizerParams,
proxyChain=proxyChain,
processors=submissionParams["processors"],
)
                if not resultSubmission["OK"]:
self.failedQueues[queueName] += 1
break

# Check status of the submitted jobs
result = self._checkSubmittedJobWrapper(ce)
if not result["OK"]:
self.failedQueues[queueName] += 1
break
self.log.debug(f"After {self.ceName}CE submitJob()")

# Committing the JobReport before evaluating the result of job submission
@@ -368,7 +418,8 @@ def _setCEDict(self, ceDict):
ceDict["ReleaseProject"] = project

# Add a RemoteExecution entry, which can be used in the next stages
ceDict["RemoteExecution"] = True
if self.submissionPolicy == "Workflow":
ceDict["RemoteExecution"] = True

def _checkMatchingIssues(self, jobRequest):
"""Check the source of the matching issue
@@ -384,3 +435,176 @@
self.log.notice("Failed to get jobs", jobRequest["Message"])

return S_OK()

def _submitJobWrapper(
self,
jobID: str,
ce: ComputingElement,
jobParams: dict,
resourceParams: dict,
optimizerParams: dict,
proxyChain: X509Chain,
processors: int,
):
"""Submit a JobWrapper to the remote site
:param jobID: job ID
:param ce: ComputingElement instance
:param jobParams: job parameters
:param resourceParams: resource parameters
:param optimizerParams: optimizer parameters
:param proxyChain: proxy chain
:param processors: number of processors
:return: S_OK
"""
# Add the number of requested processors to the job environment
if "ExecutionEnvironment" in jobParams:
if isinstance(jobParams["ExecutionEnvironment"], str):
jobParams["ExecutionEnvironment"] = jobParams["ExecutionEnvironment"].split(";")
jobParams.setdefault("ExecutionEnvironment", []).append("DIRAC_JOB_PROCESSORS=%d" % processors)
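        # e.g. an "ExecutionEnvironment" string "VAR1=val1;VAR2=val2" is normalised to
        # ["VAR1=val1", "VAR2=val2"] before "DIRAC_JOB_PROCESSORS=<n>" is appended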

# Prepare the job for submission
        self.log.verbose("Getting a JobWrapper")
arguments = {"Job": jobParams, "CE": resourceParams, "Optimizer": optimizerParams}
job = getJobWrapper(jobID, arguments, self.jobReport)
if not job:
return S_ERROR(f"Cannot get a JobWrapper instance for job {jobID}")

if "InputSandbox" in jobParams:
self.log.verbose("Getting the inputSandbox of the job")
if not transferInputSandbox(job, jobParams["InputSandbox"], self.jobReport):
return S_ERROR(f"Cannot get input sandbox of job {jobID}")
self.jobReport.commit()

if "InputData" in jobParams and jobParams["InputData"]:
self.log.verbose("Getting the inputData of the job")
if not resolveInputData(job, self.jobReport):
return S_ERROR(f"Cannot get input data of job {jobID}")
self.jobReport.commit()

# Preprocess the payload
payloadParams = job.preProcess()
self.jobReport.commit()

# Generate a light JobWrapper executor script
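        # The "light" wrapper is expected to run only the already pre-processed payload on the
        # remote resource; sandbox transfer, input data resolution and post-processing remain
        # on the agent side (see above and _checkSubmittedJobWrapper below)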
jobDesc = {
"jobID": jobID,
"jobParams": jobParams,
"resourceParams": resourceParams,
"optimizerParams": optimizerParams,
"payloadParams": payloadParams,
"extraOptions": self.extraOptions,
}
result = createLightJobWrapper(log=self.log, logLevel=self.logLevel, **jobDesc)
if not result["OK"]:
return result
wrapperFile = result["Value"][0]

# Get inputs from the current working directory
inputs = os.listdir(".")
inputs.remove(os.path.basename(wrapperFile))
self.log.verbose("The executable will be sent along with the following inputs:", ",".join(inputs))

# Request the whole directory as output
outputs = ["/"]

self.jobReport.setJobStatus(minorStatus="Submitting To CE")
self.log.info("Submitting JobWrapper", f"{os.path.basename(wrapperFile)} to {self.ceName}CE")

# Pass proxy to the CE
proxy = proxyChain.dumpAllToString()
if not proxy["OK"]:
self.log.error("Invalid proxy", proxy)
return S_ERROR("Payload Proxy Not Found")
ce.setProxy(proxy["Value"])

result = ce.submitJob(
executableFile=wrapperFile,
proxy=None,
inputs=inputs,
outputs=outputs,
)
if not result["OK"]:
self._rescheduleFailedJob(jobID, result["Message"])
return result

taskID = result["Value"][0]
stamp = result["PilotStampDict"][taskID]
self.log.info("Job being submitted", f"(DIRAC JobID: {jobID}; Task ID: {taskID})")

self.submissionDict[f"{taskID}:::{stamp}"] = job
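        # The "<taskID>:::<stamp>" key keeps the pilot stamp attached to the task reference,
        # so the pair is available when querying the CE about this task later on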
time.sleep(self.jobSubmissionDelay)
return S_OK()

def _checkOutputIntegrity(self, workingDirectory: str):
"""Make sure that output files are not corrupted.
:param workingDirectory: path of the outputs
"""
checkSumOutput = os.path.join(workingDirectory, "checksums.json")
if not os.path.exists(checkSumOutput):
return S_ERROR(f"Cannot guarantee the integrity of the outputs: {checkSumOutput} unavailable")

with open(checkSumOutput) as f:
checksums = json.load(f)
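        # Assumed layout: a flat mapping of relative output paths to md5 hex digests,
        # e.g. {"payloadResults.json": "5d41402abc4b2a76b9719d911017c592"}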

        # For each output file, compute the md5 checksum and compare it to the expected one
        for output, checksum in checksums.items():
            md5Hash = hashlib.md5()
localOutput = os.path.join(workingDirectory, output)
if not os.path.exists(localOutput):
return S_ERROR(f"{localOutput} was expected but not found")

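            # Hash the file in chunks (a multiple of the digest block size) so that
            # large outputs do not have to be loaded into memory at once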
            with open(localOutput, "rb") as f:
                while chunk := f.read(128 * md5Hash.block_size):
                    md5Hash.update(chunk)
            if checksum != md5Hash.hexdigest():
                return S_ERROR(f"{localOutput} is corrupted")

return S_OK()

def _checkSubmittedJobWrapper(self, ce: ComputingElement):
"""Check the status of the submitted tasks.
If the task is finished, get the output and post process the job.
Finally, remove from the submission dictionary.
:return: S_OK/S_ERROR
"""
if not (result := ce.getJobStatus(self.submissionDict.keys()))["OK"]:
self.log.error("Failed to get job status", result["Message"])
return result

        for taskID, status in result["Value"].items():
if status not in PilotStatus.PILOT_FINAL_STATES:
continue

self.log.info("Job execution finished", f"(DIRAC taskID: {taskID}; Status: {status})")

# Get the output of the job
self.log.info(f"Getting the outputs of taskID {taskID}")
if not (result := ce.getJobOutput(taskID, os.path.abspath(".")))["OK"]:
self.log.error("Failed to get the output of taskID", f"{taskID}: {result['Message']}")
return result

# Make sure the output is correct
self.log.info(f"Checking the integrity of the outputs of {taskID}")
if not (result := self._checkOutputIntegrity("."))["OK"]:
return result
self.log.info("The output has been retrieved and declared complete")

job = self.submissionDict[taskID]

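            # payloadResults.json is expected among the retrieved outputs: it carries the
            # results of the remote payload execution and feeds the local post-processing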
with open("payloadResults.json") as f:
payloadResults = json.load(f)
job.postProcess(**payloadResults)

            # Clean up the job on the remote resource
            if self.cleanTask:
                if not (result := ce.cleanJob(taskID))["OK"]:
                    self.log.warn("Failed to clean the output remotely", result["Message"])
                else:
                    self.log.info(f"TaskID {taskID} has been remotely removed")

# Remove the job from the submission dictionary
del self.submissionDict[taskID]
return S_OK()
@@ -7,6 +7,9 @@

# DIRAC Components
from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations
from DIRAC.Core.Utilities.ReturnValues import S_OK
from DIRAC.Resources.Computing import ComputingElement
from DIRAC.WorkloadManagementSystem import JobWrapper
from DIRAC.WorkloadManagementSystem.Agent.PushJobAgent import PushJobAgent

from DIRAC import gLogger, S_ERROR