# Source code for pclean.parallel.cluster

"""Dask cluster lifecycle management.

Supports:

* Starting a ``LocalCluster`` (default)
* Submitting workers as SLURM batch jobs via ``dask_jobqueue.SLURMCluster``
* Connecting to an existing ``distributed.Client`` via scheduler address
* Graceful shutdown with image cleanup
"""

from __future__ import annotations

import logging
import os
import re
from pprint import pformat

log = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Lazy imports so the module is importable without dask installed.
# ---------------------------------------------------------------------------

# Cache for the lazily imported ``dask.distributed`` module; filled in by the
# first call to ``_dd()`` and reused afterwards.
_dask_distributed = None
# Seconds that ``DaskClusterManager.start()`` waits for LocalCluster workers
# to register (passed as ``timeout`` to ``Client.wait_for_workers``).
QUEUE_WAIT = 60
# Also applied by ``_dd()`` to the ``distributed.comm.timeouts.*`` settings.
COMM_TIMEOUT = 600  # client wait timeout threshold to initialize connection to the scheduler, in seconds

# ---------------------------------------------------------------------------
# Monkey-patch: guard against garbage TCP frame sizes
# ---------------------------------------------------------------------------
# When Dask workers all reconnect simultaneously after a long blocking C++ call
# (e.g. CASA makePSF), the OS TCP stack sometimes delivers a new connection
# carrying stale residual bytes.  Dask reads those bytes as an 8-byte uint64
# frame-length header and passes the absurd value to numpy.empty(), triggering
# a MemoryError logged as a tornado ERROR.  The run still completes, but the
# log is alarming.
#
# The fix wraps distributed.comm.tcp.read_bytes_rw AND the numpy host_array
# allocator.  Both patches are needed because:
#
#   1. read_bytes_rw:  TCP.read() resolves this name from the tcp module
#      globals at every call, so replacing the attribute works.
#
#   2. host_array:  In some Dask versions / code paths the original
#      read_bytes_rw body is reached directly (e.g. via closures or
#      inlined buffers).  Patching host_array provides a second safety
#      net so that numpy.empty(7.27 EiB) never happens.
#
# We also must be careful about *how* we raise StreamClosedError.
# Dask's convert_stream_closed_error() inspects exc.real_error:
#
#     if hasattr(exc, "real_error"):
#         if exc.real_error is None:
#             raise CommClosedError(...) from exc
#         exc = exc.real_error           # <-- replaces exc !
#     ...
#     raise CommClosedError(...) from exc
#
# If we pass a *string* as real_error, ``exc`` becomes that string and
# ``raise ... from <string>`` fails with:
#     TypeError: exception causes must derive from BaseException
#
# Fix: always pass ``real_error=None`` so the first branch fires.

_MAX_FRAME_BYTES = 1 << 30  # 1 GiB — no real Dask message exceeds this
# Set to True by ``_patch_dask_tcp()`` once both guards have been installed,
# so that subsequent calls return immediately instead of re-wrapping.
_dask_tcp_patched = False


def _patch_dask_tcp() -> None:
    """Install frame-size guards into ``distributed`` (idempotent, best effort).

    Two callables are wrapped so that a length larger than
    ``_MAX_FRAME_BYTES`` is rejected before any buffer is allocated:

    * ``distributed.comm.tcp.read_bytes_rw`` -- raises tornado's
      ``StreamClosedError`` with ``real_error=None``.
    * ``distributed.protocol.utils.host_array`` -- raises ``MemoryError``.

    Any failure while importing or patching is logged as a warning and
    swallowed; the module-level ``_dask_tcp_patched`` flag is only set after
    both guards are in place.
    """
    global _dask_tcp_patched
    if _dask_tcp_patched:
        return

    try:
        import distributed.comm.tcp as tcp_mod
        import distributed.protocol.utils as proto_utils
        from tornado.iostream import StreamClosedError

        # --- Guard 1: read_bytes_rw (primary) -----------------------------
        original_read = tcp_mod.read_bytes_rw

        async def guarded_read_bytes_rw(stream, n: int):
            if n <= _MAX_FRAME_BYTES:
                return await original_read(stream, n)
            log.debug(
                "Rejected implausible TCP frame size %s B "
                "(> %s B limit) — stale/reset connection",
                format(n, ","),
                format(_MAX_FRAME_BYTES, ","),
            )
            # real_error must be None (not a string): Dask's
            # convert_stream_closed_error() then takes its
            # ``exc.real_error is None`` branch and chains correctly.
            raise StreamClosedError(real_error=None)

        tcp_mod.read_bytes_rw = guarded_read_bytes_rw

        # --- Guard 2: host_array (secondary) ------------------------------
        # Covers code paths that reach the allocator without passing
        # through the wrapper installed above.
        original_host_array = proto_utils.host_array

        def guarded_host_array(n):
            if n <= _MAX_FRAME_BYTES:
                return original_host_array(n)
            log.debug(
                "host_array: rejected %s B allocation "
                "(> %s B limit)",
                format(n, ","),
                format(_MAX_FRAME_BYTES, ","),
            )
            raise MemoryError(
                f"Implausible allocation request {n:,} B "
                f"(> {_MAX_FRAME_BYTES:,} B limit) — "
                "likely garbage TCP frame from stale connection"
            )

        proto_utils.host_array = guarded_host_array
        # tcp.py may hold its own imported reference; replace that one too.
        if hasattr(tcp_mod, 'host_array'):
            tcp_mod.host_array = guarded_host_array

        _dask_tcp_patched = True
        log.debug("Patched distributed.comm.tcp.read_bytes_rw + host_array with frame-size guard")
    except Exception as exc:  # noqa: BLE001
        log.warning(
            "Could not patch distributed.comm.tcp.read_bytes_rw and "
            "distributed.protocol.utils.host_array: %s",
            exc,
        )


def _dd():
    """Return the ``dask.distributed`` module, importing it on first use.

    The first call also installs the TCP frame-size guards via
    ``_patch_dask_tcp()`` and applies process-wide Dask configuration
    overrides; later calls return the cached module unchanged.
    """
    global _dask_distributed
    if _dask_distributed is not None:
        return _dask_distributed

    import dask.distributed as dd
    import dask.config

    _patch_dask_tcp()

    # A worker stuck inside a long blocking C++ call (reached through a Python
    # binding such as casatools) cannot be introspected by Dask.  The generous
    # timeouts below minimise the risk of the scheduler declaring such a
    # worker unresponsive and killing it.
    timeout = f'{COMM_TIMEOUT}s'
    overrides = {
        # Per-message connection timeout during operation.
        'distributed.comm.timeouts.connect': timeout,
        # Per-message read/write timeout during operation.
        'distributed.comm.timeouts.tcp': timeout,
        'distributed.worker.heartbeat-interval': '10s',
        # 20 minutes before a silent worker is considered lost.
        'distributed.scheduler.worker-ttl': '1200s',
        # No forced worker restart.
        'distributed.worker.lifetime.duration': None,
        # Retry failed connections: guards against garbage TCP frames from
        # OS-level half-open connections when many workers reconnect at once.
        'distributed.comm.retry.count': 5,
        'distributed.comm.retry.delay.min': '1s',
        'distributed.comm.retry.delay.max': '30s',
    }
    dask.config.set(overrides)

    _dask_distributed = dd
    return dd


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


class DaskClusterManager:
    """Thin wrapper that owns a ``dask.distributed.Client``.

    Supports three cluster backends selected by *cluster_type*:

    * ``'local'`` — spin up a ``dask.distributed.LocalCluster`` (default).
    * ``'slurm'`` — submit workers as SLURM batch jobs via
      ``dask_jobqueue.SLURMCluster``.  Requires the optional
      ``dask-jobqueue`` package (``pip install dask-jobqueue``).
    * ``'address'`` — connect to a pre-existing scheduler at
      *scheduler_address*.

    For backward compatibility, if *scheduler_address* is set and
    *cluster_type* is left at ``'local'``, the manager silently switches to
    ``'address'`` mode.

    The manager is also a context manager: ``__enter__`` calls
    :meth:`start` and ``__exit__`` calls :meth:`shutdown`.

    Args:
        nworkers: Number of workers.  ``None`` -> ``os.cpu_count()``.
        scheduler_address: Scheduler URL for ``'address'`` mode.
        threads_per_worker: Threads per Dask worker (default 1 -- CASA
            tools are not thread-safe).
        memory_limit: Per-worker memory limit.  Default ``'0'`` disables
            Dask's memory management, which is correct for CASA workloads
            because all heavy allocations happen inside C++ casatools
            (reported as "unmanaged memory").  Dask cannot free this
            memory, so its pause/spill heuristics only cause workers to
            stall.  Concurrency is bounded by ``as_completed`` instead.
        local_directory: Scratch directory for Dask spill-to-disk.
        cluster_type: ``'local'``, ``'slurm'``, or ``'address'``.
        slurm_queue: SLURM partition name (``--partition``).
        slurm_account: SLURM account string (``--account``).
        slurm_walltime: Per-job wall time (``--time``).
        slurm_job_mem: Per-job memory (``--mem``).
        slurm_cores_per_job: CPUs per SLURM job (``--cpus-per-task``).
        slurm_job_name: SLURM job name (``--job-name``).  Appears in
            ``squeue`` output under the NAME column, making workers easy
            to identify.  Defaults to ``None`` (dask-jobqueue uses
            ``'dask-worker'``).
        slurm_job_extra_directives: Extra ``#SBATCH`` lines.
        slurm_python: Path to the Python executable on compute nodes.
        slurm_local_directory: Worker scratch directory on compute nodes.
        slurm_log_directory: Directory for SLURM stdout/stderr logs.
        slurm_job_script_prologue: Shell commands injected before the
            worker process starts (e.g. ``module load`` or
            ``conda activate``).
    """

    def __init__(
        self,
        nworkers: int | None = None,
        scheduler_address: str | None = None,
        threads_per_worker: int = 1,
        memory_limit: str = '0',
        local_directory: str | None = None,
        cluster_type: str = 'local',
        slurm_queue: str | None = None,
        slurm_account: str | None = None,
        slurm_walltime: str = '24:00:00',
        slurm_job_mem: str = '20GB',
        slurm_cores_per_job: int = 1,
        slurm_job_name: str | None = None,
        slurm_job_extra_directives: list[str] | None = None,
        slurm_python: str | None = None,
        slurm_local_directory: str | None = None,
        slurm_log_directory: str = 'logs',
        slurm_job_script_prologue: list[str] | None = None,
    ):
        # Falsy nworkers (None or 0) falls back to the CPU count, then to 4.
        self.nworkers = nworkers or os.cpu_count() or 4
        self.scheduler_address = scheduler_address
        self.threads_per_worker = threads_per_worker
        self.memory_limit = memory_limit
        self.local_directory = local_directory

        # Backward compat: scheduler_address implies 'address' mode.
        if scheduler_address and cluster_type == 'local':
            cluster_type = 'address'
        self.cluster_type = cluster_type

        # SLURM-specific
        self.slurm_queue = slurm_queue
        self.slurm_account = slurm_account
        self.slurm_walltime = slurm_walltime
        self.slurm_job_mem = slurm_job_mem
        self.slurm_cores_per_job = slurm_cores_per_job
        self.slurm_job_name = slurm_job_name
        self.slurm_job_extra_directives = slurm_job_extra_directives or []
        self.slurm_python = slurm_python
        self.slurm_local_directory = slurm_local_directory
        self.slurm_log_directory = slurm_log_directory
        self.slurm_job_script_prologue = slurm_job_script_prologue or []

        # Populated by start(); reset to None by shutdown().
        self._cluster = None
        self._client: object | None = None

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def start(self) -> 'DaskClusterManager':
        """Start (or connect to) the Dask cluster and return *self*.

        If any step fails after a cluster or client has been created, the
        partially started resources are closed before the exception is
        re-raised.  Without this, a failure inside ``with
        DaskClusterManager(...)`` would leak them because ``__exit__`` is
        never invoked when ``__enter__`` raises.
        """
        dd = _dd()

        if self.threads_per_worker > 1:
            log.warning(
                'threads_per_worker=%d: casatools are NOT thread-safe; '
                'use threads_per_worker=1 (default) for correctness',
                self.threads_per_worker,
            )

        try:
            if self.cluster_type == 'address':
                log.info('Connecting to existing scheduler at %s', self.scheduler_address)
                self._client = dd.Client(self.scheduler_address, timeout=COMM_TIMEOUT)
            elif self.cluster_type == 'slurm':
                self._start_slurm(dd)
            else:  # 'local' (any unrecognised cluster_type also lands here)
                self._start_local(dd)
            self._log_cluster_summary()
        except BaseException:
            # Best-effort cleanup; the original error is the one that matters.
            try:
                self.shutdown()
            except Exception:  # noqa: BLE001
                log.exception('Cleanup after failed start() raised')
            raise

        return self

    def _start_local(self, dd) -> None:
        """Create a ``LocalCluster``, connect a client and wait for workers."""
        log.info('Starting LocalCluster with %d workers', self.nworkers)
        self._cluster = dd.LocalCluster(
            n_workers=self.nworkers,
            processes=True,
            threads_per_worker=self.threads_per_worker,
            memory_limit=self.memory_limit,
            local_directory=self.local_directory,
        )
        self._client = dd.Client(self._cluster, timeout=COMM_TIMEOUT)

        # Block until all requested workers have registered with the
        # scheduler.  Without this, worker_count can return a smaller
        # number than nworkers due to a startup race condition.
        # NOTE(review): this presumably raises on timeout when fewer workers
        # arrive, in which case the adjustment below is never reached for
        # the "fewer workers" case -- confirm the intended behaviour.
        self._client.wait_for_workers(self.nworkers, timeout=QUEUE_WAIT)

        # Verify the cluster actually created the requested workers.
        # Use nthreads() for a fresh, synchronous query to the scheduler
        # (scheduler_info() can return a stale cached snapshot).
        actual = len(self._client.nthreads())
        if actual != self.nworkers:
            log.warning(
                'Requested %d workers but LocalCluster only created %d '
                '(system may lack resources). Adjusting nworkers.',
                self.nworkers,
                actual,
            )
            self.nworkers = actual

    def _log_cluster_summary(self) -> None:
        """Log worker count, dashboard link and per-worker status."""
        log.info('Dask cluster ready: %d workers registered', self.worker_count)
        log.info('Dask dashboard: %s', self._client.dashboard_link)
        log.info(' client: %s', self._client)
        log.info(' cluster: %s', self._client.cluster)

        def get_status(dask_worker) -> tuple[str, str]:
            return dask_worker.status, dask_worker.id

        status: dict[str, tuple[str, str]] = self._client.run(get_status)
        if status:
            log.info('worker status: \n %s', pformat(status))

    @staticmethod
    def _merge_stderr_into_stdout(job_header: str) -> str:
        """Rewrite ``#SBATCH -e <path>.err`` directives to end in ``.out``.

        With both streams pointing at the same file each worker produces a
        single ``.out`` log instead of separate ``.out`` and ``.err`` files.
        """
        return re.sub(r'(#SBATCH\s+-e\s+\S+)\.err', r'\1.out', job_header)

    def _start_slurm(self, dd) -> None:
        """Create a ``dask_jobqueue.SLURMCluster`` and scale to *nworkers* jobs.

        Raises:
            ImportError: If the optional ``dask-jobqueue`` package is missing.
        """
        try:
            from dask_jobqueue import SLURMCluster
            from dask_jobqueue.slurm import SLURMJob
        except ImportError as exc:
            raise ImportError(
                "cluster_type='slurm' requires dask-jobqueue: "
                'pip install dask-jobqueue'
            ) from exc

        log.info(
            'Starting SLURMCluster (queue=%s, nworkers=%d, mem=%s, walltime=%s)',
            self.slurm_queue,
            self.nworkers,
            self.slurm_job_mem,
            self.slurm_walltime,
        )

        slurm_kwargs: dict = dict(
            queue=self.slurm_queue,
            account=self.slurm_account,
            walltime=self.slurm_walltime,
            cores=self.slurm_cores_per_job,
            memory=self.slurm_job_mem,
            processes=1,  # one Dask worker per SLURM job
            local_directory=self.slurm_local_directory or self.local_directory,
            log_directory=self.slurm_log_directory,
            job_extra_directives=self.slurm_job_extra_directives,
        )
        if self.slurm_job_name is not None:
            slurm_kwargs['job_name'] = self.slurm_job_name

        merge_logs = self._merge_stderr_into_stdout

        class _SingleLogSLURMJob(SLURMJob):
            """SLURM job whose stderr is written to the same file as stdout."""

            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                if self.job_header is not None:
                    self.job_header = merge_logs(self.job_header)

        slurm_kwargs['job_cls'] = _SingleLogSLURMJob

        if self.slurm_python:
            slurm_kwargs['python'] = self.slurm_python
        if self.slurm_job_script_prologue:
            slurm_kwargs['job_script_prologue'] = self.slurm_job_script_prologue

        self._cluster = SLURMCluster(**slurm_kwargs)
        self._cluster.scale(jobs=self.nworkers)
        self._client = dd.Client(self._cluster, timeout=COMM_TIMEOUT)

    def shutdown(self) -> None:
        """Close client and cluster.

        Safe to call repeatedly and before :meth:`start`.  The cluster is
        closed even when closing the client raises; in that case the
        client's exception still propagates afterwards.
        """
        # Detach both handles first so the manager is left in a clean
        # "not started" state regardless of what close() does.
        client, self._client = self._client, None
        cluster, self._cluster = self._cluster, None
        try:
            if client is not None:
                client.close()
        finally:
            if cluster is not None:
                cluster.close()

    # ------------------------------------------------------------------
    # Accessors
    # ------------------------------------------------------------------

    @property
    def client(self):
        """Return the ``dask.distributed.Client``.

        Raises:
            RuntimeError: If the cluster has not been started.
        """
        if self._client is None:
            raise RuntimeError('Cluster not started — call .start() first')
        return self._client

    @property
    def worker_count(self) -> int:
        """Number of workers currently registered with the scheduler.

        Uses ``client.nthreads()`` which is a direct synchronous query to
        the scheduler, avoiding stale cached snapshots from
        ``scheduler_info()``.  The value can differ from the requested
        *nworkers*; for a local cluster :meth:`start` logs a warning and
        adjusts ``nworkers`` when it observes a mismatch.
        """
        return len(self.client.nthreads())

    # ------------------------------------------------------------------
    # Context-manager protocol
    # ------------------------------------------------------------------

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, *exc):
        self.shutdown()
        return False