Skip to content

Commit

Permalink
fix(wms): correctly log the pilot job reference during the matching p…
Browse files Browse the repository at this point in the history
…rocess
  • Loading branch information
aldbr committed Dec 3, 2024
1 parent 6b374aa commit de9d072
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 90 deletions.
173 changes: 84 additions & 89 deletions src/DIRAC/WorkloadManagementSystem/Client/Matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,19 @@
"""
import time

from DIRAC import gLogger, convertToPy3VersionNumber

from DIRAC.Core.Utilities.PrettyPrint import printDict
from DIRAC.Core.Security import Properties
from DIRAC import convertToPy3VersionNumber, gLogger
from DIRAC.ConfigurationSystem.Client.Helpers import Registry
from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations
from DIRAC.WorkloadManagementSystem.Client import JobStatus
from DIRAC.Core.Security import Properties
from DIRAC.Core.Utilities.PrettyPrint import printDict
from DIRAC.ResourceStatusSystem.Client.SiteStatus import SiteStatus
from DIRAC.WorkloadManagementSystem.Client import JobStatus, PilotStatus
from DIRAC.WorkloadManagementSystem.Client.Limiter import Limiter
from DIRAC.WorkloadManagementSystem.Client import PilotStatus
from DIRAC.WorkloadManagementSystem.DB.TaskQueueDB import TaskQueueDB, singleValueDefFields, multiValueMatchFields
from DIRAC.WorkloadManagementSystem.DB.PilotAgentsDB import PilotAgentsDB
from DIRAC.WorkloadManagementSystem.DB.JobDB import JobDB
from DIRAC.WorkloadManagementSystem.DB.JobLoggingDB import JobLoggingDB
from DIRAC.ResourceStatusSystem.Client.SiteStatus import SiteStatus
from DIRAC.WorkloadManagementSystem.DB.PilotAgentsDB import PilotAgentsDB
from DIRAC.WorkloadManagementSystem.DB.TaskQueueDB import TaskQueueDB, multiValueMatchFields, singleValueDefFields
from DIRAC.WorkloadManagementSystem.Utilities.ContextVars import setPilotRefLogger


class PilotVersionError(Exception):
Expand Down Expand Up @@ -52,11 +51,7 @@ def __init__(self, pilotAgentsDB=None, jobDB=None, tqDB=None, jlDB=None, opsHelp
self.opsHelper = Operations()

if pilotRef:
self.log = gLogger.getSubLogger(f"[{pilotRef}]Matcher")
self.pilotAgentsDB.log = gLogger.getSubLogger(f"[{pilotRef}]Matcher")
self.jobDB.log = gLogger.getSubLogger(f"[{pilotRef}]Matcher")
self.tqDB.log = gLogger.getSubLogger(f"[{pilotRef}]Matcher")
self.jlDB.log = gLogger.getSubLogger(f"[{pilotRef}]Matcher")
self.log = gLogger.getLocalSubLogger(f"[{pilotRef}]Matcher")
else:
self.log = gLogger.getSubLogger("Matcher")

Expand All @@ -66,86 +61,86 @@ def __init__(self, pilotAgentsDB=None, jobDB=None, tqDB=None, jlDB=None, opsHelp

def selectJob(self, resourceDescription, credDict):
"""Main job selection function to find the highest priority job matching the resource capacity"""
with setPilotRefLogger(self.log):
startTime = time.time()

resourceDict = self._getResourceDict(resourceDescription, credDict)

# Make a nice print of the resource matching parameters
toPrintDict = dict(resourceDict)
if "MaxRAM" in resourceDescription:
toPrintDict["MaxRAM"] = resourceDescription["MaxRAM"]
if "NumberOfProcessors" in resourceDescription:
toPrintDict["NumberOfProcessors"] = resourceDescription["NumberOfProcessors"]
toPrintDict["Tag"] = []
if "Tag" in resourceDict:
for tag in resourceDict["Tag"]:
if not tag.endswith("GB") and not tag.endswith("Processors"):
toPrintDict["Tag"].append(tag)
if not toPrintDict["Tag"]:
toPrintDict.pop("Tag")
self.log.info("Resource description for matching", printDict(toPrintDict))

negativeCond = self.limiter.getNegativeCondForSite(resourceDict["Site"], resourceDict.get("GridCE"))
result = self.tqDB.matchAndGetJob(resourceDict, negativeCond=negativeCond)

startTime = time.time()

resourceDict = self._getResourceDict(resourceDescription, credDict)

# Make a nice print of the resource matching parameters
toPrintDict = dict(resourceDict)
if "MaxRAM" in resourceDescription:
toPrintDict["MaxRAM"] = resourceDescription["MaxRAM"]
if "NumberOfProcessors" in resourceDescription:
toPrintDict["NumberOfProcessors"] = resourceDescription["NumberOfProcessors"]
toPrintDict["Tag"] = []
if "Tag" in resourceDict:
for tag in resourceDict["Tag"]:
if not tag.endswith("GB") and not tag.endswith("Processors"):
toPrintDict["Tag"].append(tag)
if not toPrintDict["Tag"]:
toPrintDict.pop("Tag")
self.log.info("Resource description for matching", printDict(toPrintDict))

negativeCond = self.limiter.getNegativeCondForSite(resourceDict["Site"], resourceDict.get("GridCE"))
result = self.tqDB.matchAndGetJob(resourceDict, negativeCond=negativeCond)

if not result["OK"]:
raise RuntimeError(result["Message"])
result = result["Value"]
if not result["matchFound"]:
self.log.info("No match found")
return {}

jobID = result["jobId"]
resAtt = self.jobDB.getJobAttributes(jobID, ["OwnerDN", "OwnerGroup", "Status"])
if not resAtt["OK"]:
raise RuntimeError("Could not retrieve job attributes")
if not resAtt["Value"]:
raise RuntimeError("No attributes returned for job")
if not resAtt["Value"]["Status"] == "Waiting":
self.log.error("Job matched by the TQ is not in Waiting state", str(jobID))
result = self.tqDB.deleteJob(jobID)
if not result["OK"]:
raise RuntimeError(result["Message"])
raise RuntimeError(f"Job {str(jobID)} is not in Waiting state")
result = result["Value"]
if not result["matchFound"]:
self.log.info("No match found")
return {}

jobID = result["jobId"]
resAtt = self.jobDB.getJobAttributes(jobID, ["Status"])
if not resAtt["OK"]:
raise RuntimeError("Could not retrieve job attributes")
if not resAtt["Value"]:
raise RuntimeError("No attributes returned for job")
if not resAtt["Value"]["Status"] == "Waiting":
self.log.error("Job matched by the TQ is not in Waiting state", str(jobID))
result = self.tqDB.deleteJob(jobID)
if not result["OK"]:
raise RuntimeError(result["Message"])
raise RuntimeError(f"Job {str(jobID)} is not in Waiting state")

self._reportStatus(resourceDict, jobID)
self._reportStatus(resourceDict, jobID)

result = self.jobDB.getJobJDL(jobID)
if not result["OK"]:
raise RuntimeError("Failed to get the job JDL")

resultDict = {}
resultDict["JDL"] = result["Value"]
resultDict["JobID"] = jobID

matchTime = time.time() - startTime
self.log.verbose("Match time", f"[{str(matchTime)}]")

# Get some extra stuff into the response returned
resOpt = self.jobDB.getJobOptParameters(jobID)
if resOpt["OK"]:
for key, value in resOpt["Value"].items():
resultDict[key] = value
resAtt = self.jobDB.getJobAttributes(jobID, ["OwnerDN", "OwnerGroup"])
if not resAtt["OK"]:
raise RuntimeError("Could not retrieve job attributes")
if not resAtt["Value"]:
raise RuntimeError("No attributes returned for job")

if self.opsHelper.getValue("JobScheduling/CheckMatchingDelay", True):
self.limiter.updateDelayCounters(resourceDict["Site"], jobID)

pilotInfoReportedFlag = resourceDict.get("PilotInfoReportedFlag", False)
if not pilotInfoReportedFlag:
self._updatePilotInfo(resourceDict)
self._updatePilotJobMapping(resourceDict, jobID)

resultDict["DN"] = resAtt["Value"]["OwnerDN"]
resultDict["Group"] = resAtt["Value"]["OwnerGroup"]
resultDict["PilotInfoReportedFlag"] = True

return resultDict
result = self.jobDB.getJobJDL(jobID)
if not result["OK"]:
raise RuntimeError("Failed to get the job JDL")

resultDict = {}
resultDict["JDL"] = result["Value"]
resultDict["JobID"] = jobID

matchTime = time.time() - startTime
self.log.verbose("Match time", f"[{str(matchTime)}]")

# Get some extra stuff into the response returned
resOpt = self.jobDB.getJobOptParameters(jobID)
if resOpt["OK"]:
for key, value in resOpt["Value"].items():
resultDict[key] = value
resAtt = self.jobDB.getJobAttributes(jobID, ["Owner", "OwnerGroup"])
if not resAtt["OK"]:
raise RuntimeError("Could not retrieve job attributes")
if not resAtt["Value"]:
raise RuntimeError("No attributes returned for job")

if self.opsHelper.getValue("JobScheduling/CheckMatchingDelay", True):
self.limiter.updateDelayCounters(resourceDict["Site"], jobID)

pilotInfoReportedFlag = resourceDict.get("PilotInfoReportedFlag", False)
if not pilotInfoReportedFlag:
self._updatePilotInfo(resourceDict)
self._updatePilotJobMapping(resourceDict, jobID)

resultDict["Owner"] = resAtt["Value"]["Owner"]
resultDict["Group"] = resAtt["Value"]["OwnerGroup"]
resultDict["PilotInfoReportedFlag"] = True

return resultDict

def _getResourceDict(self, resourceDescription, credDict):
"""from resourceDescription to resourceDict (just various mods)"""
Expand Down
11 changes: 11 additions & 0 deletions src/DIRAC/WorkloadManagementSystem/DB/JobDB.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
extractJDL,
fixJDL,
)
from DIRAC.WorkloadManagementSystem.Utilities.ContextVars import pilotRefLogger


class JobDB(DB):
Expand All @@ -42,6 +43,8 @@ def __init__(self, parentLogger=None):

DB.__init__(self, "JobDB", "WorkloadManagement/JobDB", parentLogger=parentLogger)

self._defaultLogger = self.log

# data member to check if __init__ went through without error
self.__initialized = False
self.maxRescheduling = self.getCSOption("MaxRescheduling", 3)
Expand All @@ -64,6 +67,14 @@ def __init__(self, parentLogger=None):
self.log.info("==================================================")
self.__initialized = True

@property
def log(self):
return pilotRefLogger.get() or self._defaultLogger

@log.setter
def log(self, value):
self._defaultLogger = value

def isValid(self):
"""Check if correctly initialised"""
return self.__initialized
Expand Down
10 changes: 10 additions & 0 deletions src/DIRAC/WorkloadManagementSystem/DB/JobLoggingDB.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from DIRAC import S_ERROR, S_OK
from DIRAC.Core.Base.DB import DB
from DIRAC.Core.Utilities import TimeUtilities
from DIRAC.WorkloadManagementSystem.Utilities.ContextVars import pilotRefLogger

MAGIC_EPOC_NUMBER = 1270000000

Expand All @@ -24,6 +25,15 @@ def __init__(self, parentLogger=None):
"""Standard Constructor"""

DB.__init__(self, "JobLoggingDB", "WorkloadManagement/JobLoggingDB", parentLogger=parentLogger)
self._defaultLogger = self.log

@property
def log(self):
return pilotRefLogger.get() or self._defaultLogger

@log.setter
def log(self, value):
self._defaultLogger = value

#############################################################################
def addLoggingRecord(
Expand Down
10 changes: 10 additions & 0 deletions src/DIRAC/WorkloadManagementSystem/DB/PilotAgentsDB.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,23 @@
from DIRAC.Core.Utilities.MySQL import _quotedList
from DIRAC.ResourceStatusSystem.Client.SiteStatus import SiteStatus
from DIRAC.WorkloadManagementSystem.Client import PilotStatus
from DIRAC.WorkloadManagementSystem.Utilities.ContextVars import pilotRefLogger


class PilotAgentsDB(DB):
def __init__(self, parentLogger=None):
super().__init__("PilotAgentsDB", "WorkloadManagement/PilotAgentsDB", parentLogger=parentLogger)
self._defaultLogger = self.log
self.lock = threading.Lock()

@property
def log(self):
return pilotRefLogger.get() or self._defaultLogger

@log.setter
def log(self, value):
self._defaultLogger = value

##########################################################################################
def addPilotReferences(self, pilotRef, ownerGroup, gridType="DIRAC", pilotStampDict={}):
"""Add a new pilot job reference"""
Expand Down
10 changes: 10 additions & 0 deletions src/DIRAC/WorkloadManagementSystem/DB/TaskQueueDB.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from DIRAC.ConfigurationSystem.Client.Helpers.Operations import Operations
from DIRAC.ConfigurationSystem.Client.Helpers import Registry
from DIRAC.WorkloadManagementSystem.private.SharesCorrector import SharesCorrector
from DIRAC.WorkloadManagementSystem.Utilities.ContextVars import pilotRefLogger

DEFAULT_GROUP_SHARE = 1000
TQ_MIN_SHARE = 0.001
Expand All @@ -37,6 +38,7 @@ class TaskQueueDB(DB):

def __init__(self, parentLogger=None):
DB.__init__(self, "TaskQueueDB", "WorkloadManagement/TaskQueueDB", parentLogger=parentLogger)
self._defaultLogger = self.log
self.__maxJobsInTQ = 5000
self.__defaultCPUSegments = [
6 * 60,
Expand Down Expand Up @@ -64,6 +66,14 @@ def __init__(self, parentLogger=None):
if not result["OK"]:
raise Exception(f"Can't create tables: {result['Message']}")

@property
def log(self):
return pilotRefLogger.get() or self._defaultLogger

@log.setter
def log(self, value):
self._defaultLogger = value

def enableAllTaskQueues(self):
"""Enable all Task queues"""
return self.updateFields("tq_TaskQueues", updateDict={"Enabled": "1"})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def export_requestJob(self, resourceDescription):

resourceDescription["Setup"] = self.serviceInfoDict["clientSetup"]
credDict = self.getRemoteCredentials()
pilotRef = resourceDescription.get("PilotReference", "Unknown")
pilotRef = resourceDescription.get("PilotReference")

try:
opsHelper = Operations(group=credDict["group"])
Expand Down
16 changes: 16 additions & 0 deletions src/DIRAC/WorkloadManagementSystem/Utilities/ContextVars.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
""" Context variables for the Workload Management System """

# Context variable for the logger (adapted to the request of the pilot reference)
import contextvars
from contextlib import contextmanager

pilotRefLogger = contextvars.ContextVar("PilotRefLogger", default=None)


@contextmanager
def setPilotRefLogger(logger_name):
token = pilotRefLogger.set(logger_name)
try:
yield
finally:
pilotRefLogger.reset(token)

0 comments on commit de9d072

Please sign in to comment.