Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/DIRAC/Resources/Computing/SSHBatchComputingElement.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,16 @@ def killJob(self, jobIDs):

return result

def shutdown(self):
    """Tear down all per-host SSH connections held by this CE.

    Each entry of ``self.connections`` has its ``"connection"`` value (if any)
    handed to ``_closeConnection``, which also takes care of the gateway, so
    that the associated Paramiko Transport threads are released. The
    connection table is then reset to an empty dict. Invoked when the CE is
    evicted/rebuilt by the
    :class:`~DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities.QueueCECache`.

    :return: S_OK()
    """
    hostTable = self.connections
    for hostName in hostTable:
        # .get() tolerates entries that never stored a connection object
        self._closeConnection(hostTable[hostName].get("connection"))
    # Start again from an empty table so no closed connection is reused
    self.connections = {}
    return S_OK()

#############################################################################

def getCEStatus(self):
Expand Down
32 changes: 32 additions & 0 deletions src/DIRAC/Resources/Computing/SSHComputingElement.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,6 +514,38 @@ def killJob(self, jobIDList):
jobIDList = [jobIDList]
return self._killJobOnHost(self.connection, jobIDList)

def _closeConnection(self, connection):
    """Close an SSH connection together with its gateway, when it has one.

    ``fabric.Connection.close()`` only closes the connection's own client and
    SFTP sessions. A gateway ``Connection`` (jump host) is opened separately and
    is NOT closed by it, so it is closed explicitly here; otherwise its Paramiko
    ``Transport`` thread (and the resources it pins) would leak. Close failures
    are logged as warnings and never propagated. Passing ``None`` is a no-op.

    :param connection: the connection to close, or None
    """
    if connection is None:
        return

    def _closeQuietly(endpoint, failureMessage):
        # A close failure must not break cache eviction: warn and carry on
        try:
            endpoint.close()
        except (OSError, EOFError, SSHException) as closeError:
            self.log.warn(failureMessage, str(closeError))

    # Grab the gateway reference before closing the main connection
    jumpHost = getattr(connection, "gateway", None)
    _closeQuietly(connection, "Failed to close SSH connection")
    # The jump host (SSHTunnel) is a separate Connection that close() ignores
    if isinstance(jumpHost, Connection):
        _closeQuietly(jumpHost, "Failed to close SSH gateway connection")

def shutdown(self):
    """Release the SSH connection owned by this CE.

    The current connection (and its gateway) is closed through
    ``_closeConnection``, which frees the Paramiko Transport thread, and the
    ``connection`` attribute is then cleared. Invoked when the CE is
    evicted/rebuilt by the
    :class:`~DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities.QueueCECache`.

    :return: S_OK()
    """
    currentConnection = self.connection
    self._closeConnection(currentConnection)
    # Forget the closed connection so it cannot be reused afterwards
    self.connection = None
    return S_OK()

def _killJobOnHost(self, connection: Connection, jobIDList: list[str]):
"""Kill the jobs for the given list of job IDs"""
batchSystemJobList = []
Expand Down
47 changes: 33 additions & 14 deletions src/DIRAC/WorkloadManagementSystem/Agent/PilotStatusAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@
"""
import datetime

from DIRAC import S_OK, gConfig
from DIRAC import S_OK
from DIRAC.AccountingSystem.Client.DataStoreClient import gDataStoreClient
from DIRAC.AccountingSystem.Client.Types.Pilot import Pilot as PilotAccounting
from DIRAC.ConfigurationSystem.Client.Helpers.Resources import getCESiteMapping
from DIRAC.Core.Base.AgentModule import AgentModule
from DIRAC.Core.Utilities import TimeUtilities
from DIRAC.Interfaces.API.DiracAdmin import DiracAdmin
from DIRAC.WorkloadManagementSystem.Client import PilotStatus
from DIRAC.WorkloadManagementSystem.Client.PilotManagerClient import PilotManagerClient
from DIRAC.WorkloadManagementSystem.DB.JobDB import JobDB
from DIRAC.WorkloadManagementSystem.DB.PilotAgentsDB import PilotAgentsDB
from DIRAC.WorkloadManagementSystem.Service.WMSUtilities import killPilotsInQueues
from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import QueueCECache


class PilotStatusAgent(AgentModule):
Expand All @@ -39,14 +40,15 @@ def __init__(self, *args, **kwargs):

self.jobDB = None
self.pilotDB = None
self.diracadmin = None
# Cache of ComputingElement instances keyed by queue.
self.ceCache = None

#############################################################################
def initialize(self):
"""Sets defaults"""

self.pilotDB = PilotAgentsDB()
self.diracadmin = DiracAdmin()
self.ceCache = QueueCECache()
self.jobDB = JobDB()
self.clearPilotsDelay = self.am_getOption("ClearPilotsDelay", 30)
self.clearAbortedDelay = self.am_getOption("ClearAbortedPilotsDelay", 7)
Expand Down Expand Up @@ -200,16 +202,33 @@ def _addPilotsAccountingReport(self, pilotsData):
return S_OK()

def _killPilots(self, acc):
for i in sorted(acc.keys()):
result = self.diracadmin.getPilotInfo(i)
if result["OK"] and i in result["Value"] and "Status" in result["Value"][i]:
ret = self.diracadmin.killPilot(str(i))
if ret["OK"]:
self.log.info("Successfully deleted", f": {i} (Status : {result['Value'][i]['Status']})")
else:
self.log.error("Failed to delete pilot: ", f"{i} : {ret['Message']}")
else:
self.log.error("Failed to get pilot info", f"{i} : {str(result)}")
"""Declare the given pilots killed on their CEs.

Pilots are grouped per queue and killed with a single call per queue,
reusing a cached CE/connection per queue across cycles (see
:func:`~DIRAC.WorkloadManagementSystem.Service.WMSUtilities.killPilotsInQueues`).
"""
# Group the pilots to kill per queue
pilotsByQueue = {}
for pRef in acc:
pilotDict = acc[pRef]
queueFields = [pilotDict["VO"], pilotDict["GridSite"], pilotDict["DestinationSite"], pilotDict["Queue"]]
# A pilot with an incomplete queue definition cannot be located on a
# CE; skip it rather than letting it abort the whole batch.
if not all(queueFields):
self.log.warn("Cannot determine queue for pilot, skipping kill", f"{pRef} : {queueFields}")
continue
queueKey = "@@@".join(queueFields)
queueData = pilotsByQueue.setdefault(queueKey, {"GridType": pilotDict["GridType"], "PilotList": []})
queueData["PilotList"].append(pRef)

if not pilotsByQueue:
return

# The CEs (and their connections) are reused across cycles via self.ceCache.
result = killPilotsInQueues(pilotsByQueue, ceCache=self.ceCache)
if not result["OK"]:
self.log.error("Failed to kill some pilots", result["Message"])

def _checkJobLastUpdateTime(self, joblist, StalledDays):
timeLimitToConsider = datetime.datetime.utcnow() - TimeUtilities.day * StalledDays
Expand Down
4 changes: 2 additions & 2 deletions src/DIRAC/WorkloadManagementSystem/Agent/PushJobAgent.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
transferInputSandbox,
)
from DIRAC.WorkloadManagementSystem.private.ConfigHelper import findGenericPilotCredentials
from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import getQueuesResolved
from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import getQueuesResolved, QueueCECache
from DIRAC.WorkloadManagementSystem.Utilities.Utils import createJobWrapper

MAX_JOBS_MANAGED = 100
Expand All @@ -61,7 +61,7 @@ def __init__(self, agentName, loadName, baseAgentName=False, properties=None):
self.firstPass = True
self.maxJobsToSubmit = MAX_JOBS_MANAGED
self.queueDict = {}
self.queueCECache = {}
self.queueCECache = QueueCECache()

self.pilotDN = ""
self.vo = ""
Expand Down
6 changes: 3 additions & 3 deletions src/DIRAC/WorkloadManagementSystem/Agent/SiteDirector.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
getPilotFilesCompressedEncodedDict,
pilotWrapperScript,
)
from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import getQueuesResolved
from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import getQueuesResolved, QueueCECache

MAX_PILOTS_TO_SUBMIT = 100

Expand All @@ -60,7 +60,7 @@ def __init__(self, *args, **kwargs):

self.queueDict = {}
# self.queueCECache keeps CE information across cycles to avoid re-creating the exact same CEs each cycle
self.queueCECache = {}
self.queueCECache = QueueCECache()
self.failedQueues = defaultdict(int)
self.maxPilotsToSubmit = MAX_PILOTS_TO_SUBMIT

Expand Down Expand Up @@ -575,7 +575,7 @@ def _getExecutable(self, queue: str, proxy: X509Chain, jobExecDir: str = "", env
# in your machine, the executable files will be in the same place
# but it does not matter since they are very temporary

ce = self.queueCECache[queue]["CE"]
ce = self.queueDict[queue]["CE"]
workingDirectory = getattr(ce, "workingDirectory", self.workingDirectory)

executable = self._writePilotScript(
Expand Down
51 changes: 35 additions & 16 deletions src/DIRAC/WorkloadManagementSystem/Service/WMSUtilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from DIRAC.FrameworkSystem.Client.TokenManagerClient import gTokenManager
from DIRAC.Resources.Computing.ComputingElementFactory import ComputingElementFactory
from DIRAC.WorkloadManagementSystem.Client.PilotScopes import PILOT_SCOPES
from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import QueueCECache

# List of files to be inserted/retrieved into/from pilot Output Sandbox
# first will be defined as StdOut in JDL and the second as StdErr
Expand Down Expand Up @@ -101,34 +102,52 @@ def getPilotRef(pilotReference, pilotDict):
return S_OK(pRef)


def killPilotsInQueues(pilotRefDict):
"""kill pilots queue by queue
def killPilotsInQueues(pilotRefDict, ceCache=None):
"""Kill pilots queue by queue.

:params dict pilotRefDict: a dict of pilots in queues
"""
Pilots are grouped per queue in ``pilotRefDict`` (key
``"<vo>@@@<site>@@@<ce>@@@<queue>"``) and killed with a single call per queue.
Every queue is attempted: failures are collected and reported together, so a
single unreachable CE no longer prevents pilots on the other queues from being
killed (previous versions failed fast on the first error).

ceFactory = ComputingElementFactory()
:param dict pilotRefDict: ``{queueKey: {"GridType": str, "PilotList": [ref, ...]}}``
:param ceCache: optional :class:`~DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities.QueueCECache`
whose CEs (and their connections) are reused across calls. When omitted, a
transient cache is created for this call only.
:return: S_OK() if all queues succeeded, otherwise S_ERROR aggregating the failures
"""
if ceCache is None:
ceCache = QueueCECache()

failures = {}
for key, pilotDict in pilotRefDict.items():
vo, site, ce, queue = key.split("@@@")

result = getQueue(site, ce, queue)
if not result["OK"]:
return result
queueDict = result["Value"]
gridType = pilotDict["GridType"]
result = ceFactory.getCE(gridType, ce, queueDict)
failures[key] = result["Message"]
continue

result = ceCache.getCE(key, pilotDict["GridType"], ce, result["Value"])
if not result["OK"]:
return result
ce = result["Value"]
failures[key] = result["Message"]
continue
computingElement = result["Value"]

pilotDict["VO"] = vo
result = setPilotCredentials(ce, pilotDict)
result = setPilotCredentials(computingElement, pilotDict)
if not result["OK"]:
return result
failures[key] = result["Message"]
continue

pilotList = pilotDict["PilotList"]
result = ce.killJob(pilotList)
result = computingElement.killJob(pilotDict["PilotList"])
if not result["OK"]:
return result
# Drop the (possibly stale) CE so it is rebuilt on the next call
ceCache.drop(key)
failures[key] = result["Message"]
continue

if failures:
return S_ERROR(f"Failed to kill pilots in queues: {failures}")
return S_OK()
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
""" Test class for WMSUtilities
"""
from unittest.mock import MagicMock

from DIRAC import S_OK, S_ERROR
from DIRAC.WorkloadManagementSystem.Service.WMSUtilities import killPilotsInQueues
from DIRAC.WorkloadManagementSystem.Utilities.QueueUtilities import QueueCECache

# The CE factory is reached through QueueCECache; getQueue and setPilotCredentials
# are module-level names looked up inside killPilotsInQueues, so we patch them there.
GET_CE = "DIRAC.Resources.Computing.ComputingElementFactory.ComputingElementFactory.getCE"
GET_QUEUE = "DIRAC.WorkloadManagementSystem.Service.WMSUtilities.getQueue"
SET_CREDS = "DIRAC.WorkloadManagementSystem.Service.WMSUtilities.setPilotCredentials"


def _twoQueues():
return {
"vo@@@site1@@@ce1@@@queue1": {"GridType": "ssh", "PilotList": ["p1", "p2"]},
"vo@@@site2@@@ce2@@@queue2": {"GridType": "ssh", "PilotList": ["p3"]},
}


def test_killPilotsInQueues_allSucceed(mocker):
    """When every step succeeds, killJob is invoked exactly once per queue."""
    fakeCE = MagicMock()
    fakeCE.killJob = MagicMock(return_value=S_OK())

    # Queue lookup, credential setup and CE creation all succeed
    mocker.patch(GET_QUEUE, return_value=S_OK({"CEType": "ssh"}))
    mocker.patch(SET_CREDS, return_value=S_OK())
    mocker.patch(GET_CE, return_value=S_OK(fakeCE))

    outcome = killPilotsInQueues(_twoQueues())

    assert outcome["OK"]
    # Two queues in the fixture -> two killJob calls
    assert fakeCE.killJob.call_count == 2


def test_killPilotsInQueues_attemptsAllAndAggregates(mocker):
    """A killJob failure on one queue is reported without skipping the other queues."""
    fakeCE = MagicMock()
    # The first queue fails to kill and the second succeeds
    # (the fixture dict preserves insertion order)
    killOutcomes = [S_ERROR("boom"), S_OK()]
    fakeCE.killJob = MagicMock(side_effect=killOutcomes)

    mocker.patch(GET_QUEUE, return_value=S_OK({"CEType": "ssh"}))
    mocker.patch(SET_CREDS, return_value=S_OK())
    mocker.patch(GET_CE, return_value=S_OK(fakeCE))

    outcome = killPilotsInQueues(_twoQueues())

    # The overall result carries the failure...
    assert not outcome["OK"]
    # ...yet both queues were attempted
    assert fakeCE.killJob.call_count == 2


def test_killPilotsInQueues_reusesProvidedCache(mocker):
    """Passing the same QueueCECache twice builds the CE once and reuses it afterwards."""
    fakeCE = MagicMock()
    fakeCE.killJob = MagicMock(return_value=S_OK())

    mocker.patch(GET_QUEUE, return_value=S_OK({"CEType": "ssh"}))
    mocker.patch(SET_CREDS, return_value=S_OK())
    factoryGetCE = mocker.patch(GET_CE, return_value=S_OK(fakeCE))

    sharedCache = QueueCECache()
    singleQueue = {"vo@@@site1@@@ce1@@@queue1": {"GridType": "ssh", "PilotList": ["p1"]}}

    # Two consecutive kills on the same queue, sharing one cache
    for _ in range(2):
        assert killPilotsInQueues(singleQueue, ceCache=sharedCache)["OK"]

    # Same queue, unchanged parameters -> the factory is hit only on the first call
    factoryGetCE.assert_called_once()
Loading
Loading