Source code for pclean.parallel.continuum_parallel
"""Parallel continuum (MFS) imaging engine.
Distributes *visibility rows* across Dask workers. Each worker runs
its own ``synthesisimager`` on a data chunk to produce a partial image.
The coordinator then uses ``synthesisnormalizer`` to **gather** partial
images, normalize, run the (serial) minor cycle, and **scatter** the
updated model back to workers for the next major cycle.
Parallelism pattern::
Major cycle (gridding / degridding) -- parallel across row chunks
Minor cycle (deconvolution) -- serial on the gathered full image
"""
from __future__ import annotations
import logging
import os
import shutil
from typing import TYPE_CHECKING
from pclean.imaging.deconvolver import Deconvolver
from pclean.imaging.normalizer import Normalizer
from pclean.parallel.cluster import DaskClusterManager
from pclean.parallel.worker_tasks import _WorkerGridder
from pclean.utils.partition import partition_continuum
if TYPE_CHECKING:
from pclean.config import PcleanConfig
log = logging.getLogger(__name__)
_casatools = None
def _ct():
global _casatools
if _casatools is None:
import casatools as ct
_casatools = ct
return _casatools
[docs]
class ParallelContinuumImager:
"""Row-parallel continuum (MFS) CLEAN imager.
Args:
config: Full imaging configuration (specmode should be ``'mfs'``).
cluster: Running Dask cluster.
"""
def __init__(self, config: PcleanConfig, cluster: DaskClusterManager):
self.config = config
self.cluster = cluster
self._part_bundles: list[dict] = []
self._actors: list = []
self._normalizer: Normalizer | None = None
self._deconvolver: Deconvolver | None = None
self._ib_tool = None
self._major_count = 0
# ------------------------------------------------------------------
# Public
# ------------------------------------------------------------------
[docs]
def run(self) -> dict:
"""Execute the full parallel continuum pipeline.
Returns:
Convergence summary.
"""
try:
self._partition_data()
self._create_actors()
self._setup_normalizer()
self._setup_deconvolver()
self._setup_iteration_control()
# PSF
self._parallel_make_psf()
self._normalizer.normalize_psf()
# PB
self._parallel_make_pb()
self._normalizer.normalize_pb()
# Initial residual
if self.config.misc.calcres:
self._parallel_major_cycle(is_first=True)
self._normalizer.post_major_mfs()
# Major / minor loop
if self.config.niter > 0:
converged = self._check_convergence()
while not converged:
self._deconvolver.setup_mask()
did = self._run_minor_cycle()
if did:
self._normalizer.pre_major_mfs()
self._parallel_major_cycle()
self._normalizer.post_major_mfs()
converged = self._check_convergence() or (not did)
if self.config.deconvolution.restoration:
self._deconvolver.restore()
if self.config.deconvolution.pbcor:
self._deconvolver.pbcor()
# Clean up partial images unless keep_partimages is set
if not self.config.cluster.keep_partimages:
self._cleanup_partimages()
return self._summary()
finally:
self._teardown()
# ------------------------------------------------------------------
# Private — partitioning & actor management
# ------------------------------------------------------------------
def _partition_data(self) -> None:
nworkers = self.cluster.nworkers
self._part_bundles = partition_continuum(self.config, nworkers)
log.info('Continuum imaging: %d row-chunks on %d workers', len(self._part_bundles), nworkers)
def _create_actors(self) -> None:
"""Create persistent ``_WorkerGridder`` actors on each worker."""
client = self.cluster.client
self._actors = []
for bundle in self._part_bundles:
actor_future = client.submit(
_WorkerGridder,
bundle,
actor=True,
)
self._actors.append(actor_future.result()) # blocks until ready
def _teardown(self) -> None:
for actor in self._actors:
try:
actor.done().result()
except Exception:
pass
self._actors.clear()
if self._normalizer is not None:
self._normalizer.teardown()
if self._deconvolver is not None:
self._deconvolver.teardown()
if self._ib_tool is not None:
self._ib_tool.done()
self._ib_tool = None
# ------------------------------------------------------------------
# Private — normalizer & deconvolver on coordinator
# ------------------------------------------------------------------
def _setup_normalizer(self) -> None:
partimagenames = [b['allimpars']['0']['imagename'] for b in self._part_bundles]
normpars = dict(self.config.to_casa_normpars()['0'])
normpars['partimagenames'] = partimagenames
self._normalizer = Normalizer(normpars, partimagenames)
self._normalizer.setup()
def _setup_deconvolver(self) -> None:
self._deconvolver = Deconvolver(
imagename=self.config.imagename,
decpars=dict(self.config.to_casa_decpars()['0']),
)
self._deconvolver.setup()
def _setup_iteration_control(self) -> None:
ct = _ct()
self._ib_tool = ct.iterbotsink()
self._ib_tool.setupiteration(iterpars=self.config.to_casa_iterpars())
# ------------------------------------------------------------------
# Private — parallel major-cycle operations
# ------------------------------------------------------------------
def _parallel_make_psf(self) -> None:
log.info('Computing PSF (parallel) …')
futures = [a.make_psf() for a in self._actors]
_wait_all(futures)
def _parallel_make_pb(self) -> None:
log.info('Computing PB (parallel) …')
futures = [a.make_pb() for a in self._actors]
_wait_all(futures)
def _parallel_major_cycle(self, is_first: bool = False) -> None:
log.info('Major cycle %d (parallel) …', self._major_count)
last = False
if self._ib_tool is not None and not is_first:
last = self._ib_tool.cleanComplete(lastcyclecheck=True)
controls = {'lastcycle': last}
futures = [a.execute_major_cycle(controls) for a in self._actors]
_wait_all(futures)
self._major_count += 1
if self._ib_tool is not None:
self._ib_tool.endmajorcycle()
# ------------------------------------------------------------------
# Private — serial minor cycle on coordinator
# ------------------------------------------------------------------
def _run_minor_cycle(self) -> bool:
iterbotrec = self._ib_tool.getminorcyclecontrols()
exrec = self._deconvolver.execute_minor(iterbotrec)
self._ib_tool.mergeexecrecord(exrec, 0)
return exrec.get('iterdone', 0) > 0
def _check_convergence(self) -> bool:
self._ib_tool.resetminorcycleinfo()
initrec = self._deconvolver.init_minor()
self._ib_tool.mergeinitrecord(initrec)
nmajor = self.config.iteration.nmajor
reached = nmajor > 0 and self._major_count >= nmajor
return self._ib_tool.cleanComplete(reachedMajorLimit=reached)
# ------------------------------------------------------------------
# Private — partial-image cleanup
# ------------------------------------------------------------------
def _cleanup_partimages(self) -> None:
"""Remove intermediate per-worker partial images.
Uses a glob on each partial-image prefix so that both
single-term (``.psf``) and multi-term (``.psf.tt0``,
``.weight.tt2``, ...) products are found and removed.
"""
import glob
removed = 0
for bundle in self._part_bundles:
abs_name = os.path.abspath(bundle['allimpars']['0']['imagename'])
for path in glob.glob(f'{abs_name}.*'):
if os.path.isdir(path):
shutil.rmtree(path, ignore_errors=True)
removed += 1
elif os.path.isfile(path):
os.remove(path)
removed += 1
if removed:
log.info('Cleaned up %d partial-image artifacts', removed)
# ------------------------------------------------------------------
# Summary
# ------------------------------------------------------------------
def _summary(self) -> dict:
return {
'imagename': self.config.imagename,
'major_cycles': self._major_count,
'nparts': len(self._part_bundles),
}
# ======================================================================
# Internal helper
# ======================================================================
def _wait_all(actor_futures: list) -> list:
"""Block until all actor method futures complete."""
return [f.result() for f in actor_futures]