From 3761dc7d6fa17de7fa06f402042d44517e47dee3 Mon Sep 17 00:00:00 2001 From: aldbr Date: Tue, 2 Jun 2026 18:48:35 +0200 Subject: [PATCH 1/3] refactor: improve PilotStatusAgent performances and refactor duplicated code --- .../Computing/SSHBatchComputingElement.py | 10 ++ .../Computing/SSHComputingElement.py | 32 +++++ .../Agent/PilotStatusAgent.py | 47 ++++--- .../Agent/PushJobAgent.py | 4 +- .../Agent/SiteDirector.py | 6 +- .../Service/WMSUtilities.py | 51 +++++--- .../Service/test/Test_WMSUtilities.py | 67 ++++++++++ .../Utilities/QueueUtilities.py | 106 +++++++++++++--- .../Utilities/test/Test_QueueUtilities.py | 118 +++++++++++++++++- .../scripts/dirac_admin_debug_ce.py | 1 - .../scripts/dirac_wms_match.py | 2 +- 11 files changed, 386 insertions(+), 58 deletions(-) create mode 100644 src/DIRAC/WorkloadManagementSystem/Service/test/Test_WMSUtilities.py diff --git a/src/DIRAC/Resources/Computing/SSHBatchComputingElement.py b/src/DIRAC/Resources/Computing/SSHBatchComputingElement.py index 52fb59982b5..f1b996836f4 100644 --- a/src/DIRAC/Resources/Computing/SSHBatchComputingElement.py +++ b/src/DIRAC/Resources/Computing/SSHBatchComputingElement.py @@ -158,6 +158,16 @@ def killJob(self, jobIDs): return result + def shutdown(self): + """Close every per-host SSH connection (and gateway), releasing their + Paramiko Transport threads. Called when the CE is evicted/rebuilt by the + :class:`~DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities.QueueCECache`. + """ + for details in self.connections.values(): + self._closeConnection(details.get("connection")) + self.connections = {} + return S_OK() + ############################################################################# def getCEStatus(self): diff --git a/src/DIRAC/Resources/Computing/SSHComputingElement.py b/src/DIRAC/Resources/Computing/SSHComputingElement.py index 0ad353eb4a9..a2e9e3c7fb0 100644 --- a/src/DIRAC/Resources/Computing/SSHComputingElement.py +++ b/src/DIRAC/Resources/Computing/SSHComputingElement.py @@ -514,6 +514,38 @@ def killJob(self, jobIDList): jobIDList = [jobIDList] return self._killJobOnHost(self.connection, jobIDList) + def _closeConnection(self, connection): + """Close an SSH connection *and its gateway* (if any). + + ``fabric.Connection.close()`` closes the connection's own client and SFTP + sessions but NOT a gateway ``Connection``, which is opened separately as a + jump host. The gateway must therefore be closed explicitly, otherwise its + Paramiko ``Transport`` thread (and the resources it pins) leaks. This is a + no-op on a connection that was never opened. + """ + if connection is None: + return + gateway = getattr(connection, "gateway", None) + try: + connection.close() + except Exception as e: # a close failure must not break cache eviction + self.log.warn("Failed to close SSH connection", str(e)) + # The jump host (SSHTunnel) is a separate Connection that close() ignores + if isinstance(gateway, Connection): + try: + gateway.close() + except Exception as e: + self.log.warn("Failed to close SSH gateway connection", str(e)) + + def shutdown(self): + """Close the SSH connection (and its gateway), releasing the Paramiko + Transport thread. Called when the CE is evicted/rebuilt by the + :class:`~DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities.QueueCECache`. + """ + self._closeConnection(self.connection) + self.connection = None + return S_OK() + def _killJobOnHost(self, connection: Connection, jobIDList: list[str]): """Kill the jobs for the given list of job IDs""" batchSystemJobList = [] diff --git a/src/DIRAC/WorkloadManagementSystem/Agent/PilotStatusAgent.py b/src/DIRAC/WorkloadManagementSystem/Agent/PilotStatusAgent.py index 4a6a67c99b6..f4e3b9a7d91 100644 --- a/src/DIRAC/WorkloadManagementSystem/Agent/PilotStatusAgent.py +++ b/src/DIRAC/WorkloadManagementSystem/Agent/PilotStatusAgent.py @@ -9,17 +9,18 @@ """ import datetime -from DIRAC import S_OK, gConfig +from DIRAC import S_OK from DIRAC.AccountingSystem.Client.DataStoreClient import gDataStoreClient from DIRAC.AccountingSystem.Client.Types.Pilot import Pilot as PilotAccounting from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getCESiteMapping from DIRAC.Core.Base.AgentModule import AgentModule from DIRAC.Core.Utilities import TimeUtilities -from DIRAC.Interfaces.API.DiracAdmin import DiracAdmin from DIRAC.WorkloadManagementSystem.Client import PilotStatus from DIRAC.WorkloadManagementSystem.Client.PilotManagerClient import PilotManagerClient from DIRAC.WorkloadManagementSystem.DB.JobDB import JobDB from DIRAC.WorkloadManagementSystem.DB.PilotAgentsDB import PilotAgentsDB +from DIRAC.WorkloadManagementSystem.Service.WMSUtilities import killPilotsInQueues +from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import QueueCECache class PilotStatusAgent(AgentModule): @@ -39,14 +40,15 @@ def __init__(self, *args, **kwargs): self.jobDB = None self.pilotDB = None - self.diracadmin = None + # Cache of ComputingElement instances keyed by queue. + self.ceCache = None ############################################################################# def initialize(self): """Sets defaults""" self.pilotDB = PilotAgentsDB() - self.diracadmin = DiracAdmin() + self.ceCache = QueueCECache() self.jobDB = JobDB() self.clearPilotsDelay = self.am_getOption("ClearPilotsDelay", 30) self.clearAbortedDelay = self.am_getOption("ClearAbortedPilotsDelay", 7) @@ -200,16 +202,33 @@ def _addPilotsAccountingReport(self, pilotsData): return S_OK() def _killPilots(self, acc): - for i in sorted(acc.keys()): - result = self.diracadmin.getPilotInfo(i) - if result["OK"] and i in result["Value"] and "Status" in result["Value"][i]: - ret = self.diracadmin.killPilot(str(i)) - if ret["OK"]: - self.log.info("Successfully deleted", f": {i} (Status : {result['Value'][i]['Status']})") - else: - self.log.error("Failed to delete pilot: ", f"{i} : {ret['Message']}") - else: - self.log.error("Failed to get pilot info", f"{i} : {str(result)}") + """Declare the given pilots killed on their CEs. + + Pilots are grouped per queue and killed with a single call per queue, + reusing a cached CE/connection per queue across cycles (see + :func:`~DIRAC.WorkloadManagementSystem.Service.WMSUtilities.killPilotsInQueues`). + """ + # Group the pilots to kill per queue + pilotsByQueue = {} + for pRef in acc: + pilotDict = acc[pRef] + queueFields = [pilotDict["VO"], pilotDict["GridSite"], pilotDict["DestinationSite"], pilotDict["Queue"]] + # A pilot with an incomplete queue definition cannot be located on a + # CE; skip it rather than letting it abort the whole batch. + if not all(queueFields): + self.log.warn("Cannot determine queue for pilot, skipping kill", f"{pRef} : {queueFields}") + continue + queueKey = "@@@".join(queueFields) + queueData = pilotsByQueue.setdefault(queueKey, {"GridType": pilotDict["GridType"], "PilotList": []}) + queueData["PilotList"].append(pRef) + + if not pilotsByQueue: + return + + # The CEs (and their connections) are reused across cycles via self.ceCache. + result = killPilotsInQueues(pilotsByQueue, ceCache=self.ceCache) + if not result["OK"]: + self.log.error("Failed to kill some pilots", result["Message"]) def _checkJobLastUpdateTime(self, joblist, StalledDays): timeLimitToConsider = datetime.datetime.utcnow() - TimeUtilities.day * StalledDays diff --git a/src/DIRAC/WorkloadManagementSystem/Agent/PushJobAgent.py b/src/DIRAC/WorkloadManagementSystem/Agent/PushJobAgent.py index 73965c3071c..24482430017 100644 --- a/src/DIRAC/WorkloadManagementSystem/Agent/PushJobAgent.py +++ b/src/DIRAC/WorkloadManagementSystem/Agent/PushJobAgent.py @@ -44,7 +44,7 @@ transferInputSandbox, ) from DIRAC.WorkloadManagementSystem.private.ConfigHelper import findGenericPilotCredentials -from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import getQueuesResolved +from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import getQueuesResolved, QueueCECache from DIRAC.WorkloadManagementSystem.Utilities.Utils import createJobWrapper MAX_JOBS_MANAGED = 100 @@ -61,7 +61,7 @@ def __init__(self, agentName, loadName, baseAgentName=False, properties=None): self.firstPass = True self.maxJobsToSubmit = MAX_JOBS_MANAGED self.queueDict = {} - self.queueCECache = {} + self.queueCECache = QueueCECache() self.pilotDN = "" self.vo = "" diff --git a/src/DIRAC/WorkloadManagementSystem/Agent/SiteDirector.py b/src/DIRAC/WorkloadManagementSystem/Agent/SiteDirector.py index 418413adc84..3519ef3b0b5 100644 --- a/src/DIRAC/WorkloadManagementSystem/Agent/SiteDirector.py +++ b/src/DIRAC/WorkloadManagementSystem/Agent/SiteDirector.py @@ -43,7 +43,7 @@ getPilotFilesCompressedEncodedDict, pilotWrapperScript, ) -from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import getQueuesResolved +from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import getQueuesResolved, QueueCECache MAX_PILOTS_TO_SUBMIT = 100 @@ -60,7 +60,7 @@ def __init__(self, *args, **kwargs): self.queueDict = {} # self.queueCECache aims at saving CEs information over the cycles to avoid to create the exact same CEs each cycle - self.queueCECache = {} + self.queueCECache = QueueCECache() self.failedQueues = defaultdict(int) self.maxPilotsToSubmit = MAX_PILOTS_TO_SUBMIT @@ -575,7 +575,7 @@ def _getExecutable(self, queue: str, proxy: X509Chain, jobExecDir: str = "", env # in your machine, the executable files will be in the same place # but it does not matter since they are very temporary - ce = self.queueCECache[queue]["CE"] + ce = self.queueDict[queue]["CE"] workingDirectory = getattr(ce, "workingDirectory", self.workingDirectory) executable = self._writePilotScript( diff --git a/src/DIRAC/WorkloadManagementSystem/Service/WMSUtilities.py b/src/DIRAC/WorkloadManagementSystem/Service/WMSUtilities.py index a8b216bc4c8..067078f6540 100644 --- a/src/DIRAC/WorkloadManagementSystem/Service/WMSUtilities.py +++ b/src/DIRAC/WorkloadManagementSystem/Service/WMSUtilities.py @@ -14,6 +14,7 @@ from DIRAC.FrameworkSystem.Client.TokenManagerClient import gTokenManager from DIRAC.Resources.Computing.ComputingElementFactory import ComputingElementFactory from DIRAC.WorkloadManagementSystem.Client.PilotScopes import PILOT_SCOPES +from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import QueueCECache # List of files to be inserted/retrieved into/from pilot Output Sandbox # first will be defined as StdOut in JDL and the second as StdErr @@ -101,34 +102,52 @@ def getPilotRef(pilotReference, pilotDict): return S_OK(pRef) -def killPilotsInQueues(pilotRefDict): - """kill pilots queue by queue +def killPilotsInQueues(pilotRefDict, ceCache=None): + """Kill pilots queue by queue. - :params dict pilotRefDict: a dict of pilots in queues - """ + Pilots are grouped per queue in ``pilotRefDict`` (key + ``"@@@@@@@@@"``) and killed with a single call per queue. + Every queue is attempted: failures are collected and reported together, so a + single unreachable CE no longer prevents pilots on the other queues from being + killed (previous versions failed fast on the first error). - ceFactory = ComputingElementFactory() + :param dict pilotRefDict: ``{queueKey: {"GridType": str, "PilotList": [ref, ...]}}`` + :param ceCache: optional :class:`~DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities.QueueCECache` + whose CEs (and their connections) are reused across calls. When omitted, a + transient cache is created for this call only. + :return: S_OK() if all queues succeeded, otherwise S_ERROR aggregating the failures + """ + if ceCache is None: + ceCache = QueueCECache() + failures = {} for key, pilotDict in pilotRefDict.items(): vo, site, ce, queue = key.split("@@@") + result = getQueue(site, ce, queue) if not result["OK"]: - return result - queueDict = result["Value"] - gridType = pilotDict["GridType"] - result = ceFactory.getCE(gridType, ce, queueDict) + failures[key] = result["Message"] + continue + + result = ceCache.getCE(key, pilotDict["GridType"], ce, result["Value"]) if not result["OK"]: - return result - ce = result["Value"] + failures[key] = result["Message"] + continue + computingElement = result["Value"] pilotDict["VO"] = vo - result = setPilotCredentials(ce, pilotDict) + result = setPilotCredentials(computingElement, pilotDict) if not result["OK"]: - return result + failures[key] = result["Message"] + continue - pilotList = pilotDict["PilotList"] - result = ce.killJob(pilotList) + result = computingElement.killJob(pilotDict["PilotList"]) if not result["OK"]: - return result + # Drop the (possibly stale) CE so it is rebuilt on the next call + ceCache.drop(key) + failures[key] = result["Message"] + continue + if failures: + return S_ERROR(f"Failed to kill pilots in queues: {failures}") return S_OK() diff --git a/src/DIRAC/WorkloadManagementSystem/Service/test/Test_WMSUtilities.py b/src/DIRAC/WorkloadManagementSystem/Service/test/Test_WMSUtilities.py new file mode 100644 index 00000000000..de1992b1a4d --- /dev/null +++ b/src/DIRAC/WorkloadManagementSystem/Service/test/Test_WMSUtilities.py @@ -0,0 +1,67 @@ +""" Test class for WMSUtilities +""" +from unittest.mock import MagicMock + +from DIRAC import S_OK, S_ERROR +from DIRAC.WorkloadManagementSystem.Service.WMSUtilities import killPilotsInQueues +from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import QueueCECache + +# The CE factory is reached through QueueCECache; getQueue and setPilotCredentials +# are module-level names looked up inside killPilotsInQueues, so we patch them there. +GET_CE = "DIRAC.Resources.Computing.ComputingElementFactory.ComputingElementFactory.getCE" +GET_QUEUE = "DIRAC.WorkloadManagementSystem.Service.WMSUtilities.getQueue" +SET_CREDS = "DIRAC.WorkloadManagementSystem.Service.WMSUtilities.setPilotCredentials" + + +def _twoQueues(): + return { + "vo@@@site1@@@ce1@@@queue1": {"GridType": "ssh", "PilotList": ["p1", "p2"]}, + "vo@@@site2@@@ce2@@@queue2": {"GridType": "ssh", "PilotList": ["p3"]}, + } + + +def test_killPilotsInQueues_allSucceed(mocker): + """One killJob call per queue when everything succeeds.""" + mocker.patch(GET_QUEUE, return_value=S_OK({"CEType": "ssh"})) + mocker.patch(SET_CREDS, return_value=S_OK()) + ce = MagicMock() + ce.killJob = MagicMock(return_value=S_OK()) + mocker.patch(GET_CE, return_value=S_OK(ce)) + + result = killPilotsInQueues(_twoQueues()) + + assert result["OK"] + assert ce.killJob.call_count == 2 # one call per queue + + +def test_killPilotsInQueues_attemptsAllAndAggregates(mocker): + """A queue whose killJob fails must not stop the others from being attempted.""" + mocker.patch(GET_QUEUE, return_value=S_OK({"CEType": "ssh"})) + mocker.patch(SET_CREDS, return_value=S_OK()) + ce = MagicMock() + # First queue fails to kill, second succeeds (dict preserves insertion order) + ce.killJob = MagicMock(side_effect=[S_ERROR("boom"), S_OK()]) + mocker.patch(GET_CE, return_value=S_OK(ce)) + + result = killPilotsInQueues(_twoQueues()) + + assert not result["OK"] # the failure is reported... + assert ce.killJob.call_count == 2 # ...but every queue was still attempted + + +def test_killPilotsInQueues_reusesProvidedCache(mocker): + """A provided QueueCECache reuses the CE across calls instead of rebuilding it.""" + mocker.patch(GET_QUEUE, return_value=S_OK({"CEType": "ssh"})) + mocker.patch(SET_CREDS, return_value=S_OK()) + ce = MagicMock() + ce.killJob = MagicMock(return_value=S_OK()) + getCEMock = mocker.patch(GET_CE, return_value=S_OK(ce)) + + cache = QueueCECache() + refDict = {"vo@@@site1@@@ce1@@@queue1": {"GridType": "ssh", "PilotList": ["p1"]}} + + assert killPilotsInQueues(refDict, ceCache=cache)["OK"] + assert killPilotsInQueues(refDict, ceCache=cache)["OK"] + + # Same queue, unchanged parameters -> the CE is built once and reused on the 2nd call + getCEMock.assert_called_once() diff --git a/src/DIRAC/WorkloadManagementSystem/Utilities/QueueUtilities.py b/src/DIRAC/WorkloadManagementSystem/Utilities/QueueUtilities.py index 504d79cdac4..34fc594a902 100644 --- a/src/DIRAC/WorkloadManagementSystem/Utilities/QueueUtilities.py +++ b/src/DIRAC/WorkloadManagementSystem/Utilities/QueueUtilities.py @@ -2,7 +2,7 @@ import hashlib -from DIRAC import S_OK, S_ERROR +from DIRAC import S_OK, S_ERROR, gLogger from DIRAC.Core.Utilities.List import fromChar from DIRAC.Core.Utilities.ClassAd.ClassAdLight import ClassAd from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getDIRACPlatform @@ -11,12 +11,26 @@ from DIRAC.Resources.Computing.ComputingElementFactory import ComputingElementFactory -def getQueuesResolved(siteDict, queueCECache, vo=None, checkPlatform=False, instantiateCEs=False): +def getQueuesResolved(siteDict, queueCECache=None, vo=None, checkPlatform=False, instantiateCEs=False): """Get the list of relevant CEs (what is in siteDict) and their descriptions. The main goal of this method is to return a dictionary of queues + + :param dict siteDict: sites/CEs/queues structure as returned by ``getQueues`` + :param queueCECache: a :class:`QueueCECache` used to reuse CE instances across + cycles when ``instantiateCEs`` is set. For backward compatibility a plain + ``dict`` (the legacy cache format) is also accepted and adopted as the cache + backing store; if omitted, a fresh cache is used. + :param str vo: VO name + :param bool checkPlatform: resolve the queue platform + :param bool instantiateCEs: instantiate (and cache) a CE object per queue + :return: S_OK(queueDict)/S_ERROR """ + # Backward compatibility: callers historically passed a plain dict (or nothing). + # Adopt it as the cache backing store so they keep cross-cycle reuse unchanged. + if not isinstance(queueCECache, QueueCECache): + queueCECache = QueueCECache(queueCECache if isinstance(queueCECache, dict) else None) + queueDict = {} - ceFactory = ComputingElementFactory() for site in siteDict: for ce in siteDict[site]: @@ -50,27 +64,20 @@ def getQueuesResolved(siteDict, queueCECache, vo=None, checkPlatform=False, inst ceQueueDict.update(queueDict[queueName]["ParametersDict"]) if instantiateCEs: - # Generate the CE object for the queue or pick the already existing one - # if the queue definition did not change - queueHash = generateQueueHash(ceQueueDict) - if queueName in queueCECache and queueCECache[queueName]["Hash"] == queueHash: - queueCE = queueCECache[queueName]["CE"] - else: - result = ceFactory.getCE(ceName=ce, ceType=ceDict["CEType"], ceParametersDict=ceQueueDict) - if not result["OK"]: - queueDict.pop(queueName) - continue - queueCECache.setdefault(queueName, {}) - queueCECache[queueName]["Hash"] = queueHash - queueCECache[queueName]["CE"] = result["Value"] - queueCE = queueCECache[queueName]["CE"] + # Get the CE object for the queue, reusing the cached one if the + # queue definition did not change, or (re)building it otherwise. + result = queueCECache.getCE(queueName, ceDict["CEType"], ce, ceQueueDict) + if not result["OK"]: + queueDict.pop(queueName) + continue + queueCE = result["Value"] queueDict[queueName]["ParametersDict"].update(queueCE.ceParameters) queueDict[queueName]["CE"] = queueCE - result = queueDict[queueName]["CE"].isValid() + result = queueCE.isValid() if not result["OK"]: queueDict.pop(queueName) - queueCECache.pop(queueName) + queueCECache.drop(queueName) continue queueDict[queueName]["CEName"] = ce @@ -141,6 +148,67 @@ def generateQueueHash(queueDict): return hexstring +class QueueCECache: + """A cache of ComputingElement instances keyed by queue. + + CEs -- and, for connection-based CEs such as the SSHComputingElement, their + underlying connections -- are reused across cycles instead of being + re-created on every use. A CE is rebuilt only when its queue parameters + change, detected through a hash of the parameters dictionary: the same + invalidation strategy used by the SiteDirector (see :func:`getQueuesResolved`). + """ + + def __init__(self, backing=None): + """:param dict backing: optional pre-existing ``queueKey -> {"Hash", "CE"}`` dict + to adopt as the cache store. Used to stay backward compatible with callers + that historically passed (and held onto) a plain dict cache; mutating it in + place preserves their cross-cycle reuse. A fresh dict is used when omitted. + """ + # queueKey -> {"Hash": , "CE": } + self._cache = backing if backing is not None else {} + self._ceFactory = ComputingElementFactory() + self.log = gLogger.getSubLogger(self.__class__.__name__) + + def getCE(self, queueKey, ceType, ceName, ceParametersDict): + """Return a cached CE for ``queueKey``, (re)building it when needed. + + :param str queueKey: unique identifier of the queue, used as cache key + :param str ceType: CE type passed to the ComputingElementFactory + :param str ceName: CE name passed to the ComputingElementFactory + :param dict ceParametersDict: queue/CE parameters; a change triggers a rebuild + :return: S_OK(ce)/S_ERROR + """ + queueHash = generateQueueHash(ceParametersDict) + cached = self._cache.get(queueKey) + if cached is not None and cached["Hash"] == queueHash: + return S_OK(cached["CE"]) + + # First use, or the queue definition changed: drop any stale cache entry + # (which shuts the old CE down, closing its connection) and build a fresh one. + self.drop(queueKey) + result = self._ceFactory.getCE(ceType=ceType, ceName=ceName, ceParametersDict=ceParametersDict) + if not result["OK"]: + return result + self._cache[queueKey] = {"Hash": queueHash, "CE": result["Value"]} + return S_OK(result["Value"]) + + def drop(self, queueKey): + """Evict a cached CE so it is rebuilt on the next request. + + The evicted CE is shut down so that any underlying connection (e.g. the + SSHComputingElement's SSH/gateway connections) is closed immediately + rather than left to non-deterministic garbage collection, which Fabric + documents as unsafe to rely on. + """ + cached = self._cache.pop(queueKey, None) + if cached is None: + return + try: + cached["CE"].shutdown() + except Exception as e: + self.log.warn("Failed to shut down evicted CE", f"{queueKey}: {e}") + + def matchQueue(jobJDL, queueDict, fullMatch=False): """ Match the job description to the queue definition diff --git a/src/DIRAC/WorkloadManagementSystem/Utilities/test/Test_QueueUtilities.py b/src/DIRAC/WorkloadManagementSystem/Utilities/test/Test_QueueUtilities.py index 28db38d959c..d52a21f49b2 100644 --- a/src/DIRAC/WorkloadManagementSystem/Utilities/test/Test_QueueUtilities.py +++ b/src/DIRAC/WorkloadManagementSystem/Utilities/test/Test_QueueUtilities.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock import pytest -from DIRAC import S_OK +from DIRAC import S_OK, S_ERROR from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import * siteDict1 = { @@ -153,7 +153,7 @@ def test_setPlatform(ceDict, queueDict, dictExpected): ) def test_getQueuesResolved(mocker, queueDict, queuesExpected): """Test the getQueuesResolvedEnhanced function""" - queueCECache = {} + queueCECache = QueueCECache() queueDictLocal = copy.deepcopy(queueDict) ce = MagicMock() @@ -165,3 +165,117 @@ def test_getQueuesResolved(mocker, queueDict, queuesExpected): assert queueDictResolved["OK"] for qName, qDictResolved in queueDictResolved["Value"].items(): assert sorted(qDictResolved) == sorted(queuesExpected[qName]) + + +# Target used to patch the CE factory used internally by QueueCECache. +# The factory is set to return a DIFFERENT CE per build (side_effect=[ce1, ce2, ...]), +# so that *which* CE we get back distinguishes "served from cache" (ce1 again) from +# "rebuilt" (ce2). That, plus the factory call_count, is what proves the cache logic -- +# asserting we get back the value the mock returned would prove nothing. +GET_CE = "DIRAC.Resources.Computing.ComputingElementFactory.ComputingElementFactory.getCE" + + +def test_getQueuesResolved_acceptsLegacyDict(mocker): + """Backward compatibility: a plain dict cache is accepted and adopted as the + backing store, so legacy callers passing ``{}`` keep working (and keep reuse).""" + ce = MagicMock() + ce.isValid = MagicMock(return_value=S_OK()) + ce.ceParameters = {} + mocker.patch(GET_CE, return_value=S_OK(ce)) + + legacyCache = {} + result = getQueuesResolved(copy.deepcopy(siteDict1), legacyCache, instantiateCEs=True) + + assert result["OK"] + # The plain dict was adopted as the cache backing store: entries were written into it. + assert legacyCache + assert all("CE" in entry and "Hash" in entry for entry in legacyCache.values()) + + +def test_QueueCECache_cacheHitDoesNotRebuild(mocker): + """A second call with unchanged parameters reuses the cached CE instead of rebuilding.""" + ce1, ce2 = MagicMock(), MagicMock() + getCEMock = mocker.patch(GET_CE, side_effect=[S_OK(ce1), S_OK(ce2)]) + + cache = QueueCECache() + params = {"CEType": "SSH", "Host": "host1"} + + first = cache.getCE("queue1", "SSH", "ce1", params) + second = cache.getCE("queue1", "SSH", "ce1", params) + + # Factory invoked exactly once, with the forwarded arguments + getCEMock.assert_called_once_with(ceType="SSH", ceName="ce1", ceParametersDict=params) + # Were the cache broken, the 2nd call would rebuild and hand back ce2 instead of ce1 + assert first["Value"] is ce1 + assert second["Value"] is ce1 + + +def test_QueueCECache_parameterChangeRebuilds(mocker): + """Changed parameters rebuild the CE (new hash) and hand back the fresh one.""" + ce1, ce2 = MagicMock(), MagicMock() + getCEMock = mocker.patch(GET_CE, side_effect=[S_OK(ce1), S_OK(ce2)]) + + cache = QueueCECache() + first = cache.getCE("queue1", "SSH", "ce1", {"Host": "host1"}) + second = cache.getCE("queue1", "SSH", "ce1", {"Host": "host2"}) + + assert getCEMock.call_count == 2 # rebuilt because the parameter hash changed + assert first["Value"] is ce1 + assert second["Value"] is ce2 # the new CE, not the stale cached one + + +def test_QueueCECache_dropForcesRebuild(mocker): + """drop() evicts the cached CE, so the next call rebuilds a fresh one.""" + ce1, ce2 = MagicMock(), MagicMock() + mocker.patch(GET_CE, side_effect=[S_OK(ce1), S_OK(ce2)]) + + cache = QueueCECache() + params = {"Host": "host1"} + + first = cache.getCE("queue1", "SSH", "ce1", params) + assert first["Value"] is ce1 + cache.drop("queue1") + + rebuilt = cache.getCE("queue1", "SSH", "ce1", params) + assert rebuilt["Value"] is ce2 # cache miss after drop -> a fresh CE was built + + +def test_QueueCECache_failedBuildIsNotCached(mocker): + """A failed build leaves no cache entry, so a later call retries rather than re-returning the error.""" + ceOK = MagicMock() + getCEMock = mocker.patch(GET_CE, side_effect=[S_ERROR("boom"), S_OK(ceOK)]) + + cache = QueueCECache() + params = {"Host": "host1"} + + failed = cache.getCE("queue1", "SSH", "ce1", params) + assert not failed["OK"] + + retried = cache.getCE("queue1", "SSH", "ce1", params) + assert retried["OK"] + assert retried["Value"] is ceOK + assert getCEMock.call_count == 2 # the failure cached nothing, so the 2nd call rebuilt + + +def test_QueueCECache_dropMissingKeyIsNoOp(): + """drop() on an unknown queue key does nothing and does not raise.""" + cache = QueueCECache() + cache.drop("does-not-exist") # must not raise + + +def test_QueueCECache_dropToleratesCEShutdownFailure(mocker): + """Evicting a CE must never break the cache, even if tearing the CE down fails. + + The agent loop relies on drop()/rebuild always succeeding; a CE that errors + while releasing its connection must not propagate and must still be evicted. + """ + ce = MagicMock() + ce.shutdown.side_effect = RuntimeError("boom") + mocker.patch(GET_CE, return_value=S_OK(ce)) + + cache = QueueCECache() + cache.getCE("queue1", "SSH", "ce1", {"Host": "host1"}) + cache.drop("queue1") # must not raise + + # The entry is gone, so the next call rebuilds rather than serving the dead CE + assert cache.getCE("queue1", "SSH", "ce1", {"Host": "host1"})["OK"] diff --git a/src/DIRAC/WorkloadManagementSystem/scripts/dirac_admin_debug_ce.py b/src/DIRAC/WorkloadManagementSystem/scripts/dirac_admin_debug_ce.py index da8d8e519ee..ff9395fd1b0 100644 --- a/src/DIRAC/WorkloadManagementSystem/scripts/dirac_admin_debug_ce.py +++ b/src/DIRAC/WorkloadManagementSystem/scripts/dirac_admin_debug_ce.py @@ -142,7 +142,6 @@ def buildQueues(vo, sites, ces, ceTypes): result = getQueuesResolved( siteDict=result["Value"], - queueCECache={}, vo=vo, instantiateCEs=True, ) diff --git a/src/DIRAC/WorkloadManagementSystem/scripts/dirac_wms_match.py b/src/DIRAC/WorkloadManagementSystem/scripts/dirac_wms_match.py index 14a4ff5c81c..83126af7823 100644 --- a/src/DIRAC/WorkloadManagementSystem/scripts/dirac_wms_match.py +++ b/src/DIRAC/WorkloadManagementSystem/scripts/dirac_wms_match.py @@ -56,7 +56,7 @@ def main(): gLogger.error("Failed to get CE information") DIRACExit(-1) siteDict = resultQueues["Value"] - result = getQueuesResolved(siteDict, {}, checkPlatform=True) + result = getQueuesResolved(siteDict, checkPlatform=True) if not resultQueues["OK"]: gLogger.error("Failed to get CE information") DIRACExit(-1) From ec7ae5e1261b1aa5c16a4cb3807b73cbfb36ad7d Mon Sep 17 00:00:00 2001 From: aldbr Date: Tue, 9 Jun 2026 17:18:50 +0200 Subject: [PATCH 2/3] chore: fix CECache init docstring --- .../Utilities/QueueUtilities.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/DIRAC/WorkloadManagementSystem/Utilities/QueueUtilities.py b/src/DIRAC/WorkloadManagementSystem/Utilities/QueueUtilities.py index 34fc594a902..a165f03ac35 100644 --- a/src/DIRAC/WorkloadManagementSystem/Utilities/QueueUtilities.py +++ b/src/DIRAC/WorkloadManagementSystem/Utilities/QueueUtilities.py @@ -159,10 +159,12 @@ class QueueCECache: """ def __init__(self, backing=None): - """:param dict backing: optional pre-existing ``queueKey -> {"Hash", "CE"}`` dict - to adopt as the cache store. Used to stay backward compatible with callers - that historically passed (and held onto) a plain dict cache; mutating it in - place preserves their cross-cycle reuse. A fresh dict is used when omitted. + """Initialise the cache, optionally adopting an existing backing store. + + :param dict backing: optional pre-existing ``queueKey -> {"Hash", "CE"}`` dict + to adopt as the cache store. Used to stay backward compatible with callers + that historically passed (and held onto) a plain dict cache; mutating it in + place preserves their cross-cycle reuse. A fresh dict is used when omitted. """ # queueKey -> {"Hash": , "CE": } self._cache = backing if backing is not None else {} From fc1f85bf115e08c7c484a6f58710a8109d7802da Mon Sep 17 00:00:00 2001 From: aldbr Date: Wed, 10 Jun 2026 11:09:16 +0200 Subject: [PATCH 3/3] chore: narrow exceptions in SSHCE --- src/DIRAC/Resources/Computing/SSHComputingElement.py | 4 ++-- .../WorkloadManagementSystem/Utilities/QueueUtilities.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/DIRAC/Resources/Computing/SSHComputingElement.py b/src/DIRAC/Resources/Computing/SSHComputingElement.py index a2e9e3c7fb0..98ca6578f7f 100644 --- a/src/DIRAC/Resources/Computing/SSHComputingElement.py +++ b/src/DIRAC/Resources/Computing/SSHComputingElement.py @@ -528,13 +528,13 @@ def _closeConnection(self, connection): gateway = getattr(connection, "gateway", None) try: connection.close() - except Exception as e: # a close failure must not break cache eviction + except (OSError, EOFError, SSHException) as e: # a close failure must not break cache eviction self.log.warn("Failed to close SSH connection", str(e)) # The jump host (SSHTunnel) is a separate Connection that close() ignores if isinstance(gateway, Connection): try: gateway.close() - except Exception as e: + except (OSError, EOFError, SSHException) as e: self.log.warn("Failed to close SSH gateway connection", str(e)) def shutdown(self): diff --git a/src/DIRAC/WorkloadManagementSystem/Utilities/QueueUtilities.py b/src/DIRAC/WorkloadManagementSystem/Utilities/QueueUtilities.py index a165f03ac35..00b6c6e5a21 100644 --- a/src/DIRAC/WorkloadManagementSystem/Utilities/QueueUtilities.py +++ b/src/DIRAC/WorkloadManagementSystem/Utilities/QueueUtilities.py @@ -207,6 +207,7 @@ def drop(self, queueKey): return try: cached["CE"].shutdown() + # CE.shutdown() is polymorphic across CE types; eviction must never fail except Exception as e: self.log.warn("Failed to shut down evicted CE", f"{queueKey}: {e}")