"""Data and image partitioning utilities.
Uses ``casatools.synthesisutils`` to divide data for continuum
(row-based) and cube (frequency-based) parallelism, and also
provides pure-Python fallback partitioners.
"""
from __future__ import annotations
import copy
import logging
import re
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from pclean.config import PcleanConfig
log = logging.getLogger(__name__)
_casatools = None
def _ct():
global _casatools
if _casatools is None:
import casatools as ct
_casatools = ct
return _casatools
# ======================================================================
# Frequency / quantity parsing helpers
# ======================================================================
_FREQ_UNITS: dict[str, float] = {
'hz': 1.0,
'khz': 1e3,
'mhz': 1e6,
'ghz': 1e9,
'thz': 1e12,
}
_QTY_RE = re.compile(r'^([+-]?[\d.eE+-]+)\s*([a-zA-Z/]+)$')
def _parse_freq_hz(val: int | float | str) -> float | None:
"""Parse a frequency quantity string to Hz.
Returns *None* if the value cannot be interpreted as a frequency.
"""
if isinstance(val, (int, float)):
return None # bare number = channel index, not a frequency
val = str(val).strip()
if not val:
return None
m = _QTY_RE.match(val)
if m is None:
return None
number, unit = float(m.group(1)), m.group(2).lower()
factor = _FREQ_UNITS.get(unit)
if factor is None:
return None
return number * factor
def _format_freq_ghz(hz: float) -> str:
"""Format a value in Hz as a GHz string."""
return f'{hz / 1e9:.10f}GHz'
# ======================================================================
# Continuum (row-based) partitioning
# ======================================================================
[docs]
def partition_continuum(
config: PcleanConfig,
nparts: int,
) -> list[dict]:
"""Partition data by visibility rows for parallel continuum imaging.
Uses ``synthesisutils.contdatapartition()`` to split each MS across
*nparts* workers. Each returned dict is a CASA-native parameter
bundle with selection narrowed to its row chunk and a unique partial
image name.
Args:
config: Full imaging configuration.
nparts: Number of partitions.
Returns:
One CASA-native bundle (dict) per worker.
"""
ct = _ct()
su = ct.synthesisutils()
base_selpars = config.to_casa_selpars()
try:
partselpars = su.contdatapartition(
selpars=base_selpars,
npart=nparts,
)
finally:
su.done()
base_bundle = config.to_casa_bundle()
result: list[dict] = []
for part_idx in range(nparts):
bundle = copy.deepcopy(base_bundle)
# contdatapartition returns a nested dict:
# {'0': {'ms0': {selpars}, 'ms1': ...}, '1': ...}
# outer key = partition index, inner keys = 'ms0', 'ms1', ...
part_key = str(part_idx)
for ms_key in sorted(base_selpars):
if part_key in partselpars and ms_key in partselpars[part_key]:
bundle['allselpars'][ms_key] = partselpars[part_key][ms_key]
# Override imagename in all CASA dict groups
new_name = f'{config.imagename}.part.{part_idx}'
bundle['allimpars']['0']['imagename'] = new_name
bundle['allnormpars']['0']['imagename'] = new_name
bundle['allgridpars']['0']['imagename'] = new_name
if 'allimages' in bundle['iterpars']:
bundle['iterpars']['allimages']['0']['imagename'] = new_name
result.append(bundle)
log.info('Partitioned continuum data into %d chunks', len(result))
return result
# ======================================================================
# Cube (frequency-based) partitioning
# ======================================================================
[docs]
def partition_cube(
config: PcleanConfig,
nparts: int,
) -> list[PcleanConfig]:
"""Partition the output cube by frequency channels for parallel cube imaging.
Uses ``synthesisutils.cubedataimagepartition()`` when possible,
falling back to an even-split heuristic.
Args:
config: Full imaging configuration.
nparts: Number of partitions.
Returns:
One ``PcleanConfig`` per worker, covering a non-overlapping
range of output channels.
"""
nchan = config.image.nchan
if nchan <= 0:
nchan = 1
# Try the casatools utility first
try:
return _partition_cube_via_su(config, nparts, nchan)
except Exception as exc:
log.debug('synthesisutils cube partition failed (%s); using even-split fallback', exc)
return _partition_cube_even(config, nparts, nchan)
def _partition_cube_via_su(
config: PcleanConfig,
nparts: int,
nchan: int,
) -> list[PcleanConfig]:
"""Partition cube using ``synthesisutils.cubedataimagepartition``.
Requires a coordinate system (csys) to be available in impars,
which is typically not the case before imaging starts.
"""
impars = config.to_casa_impars()
csys = impars['0'].get('csys', {})
if not csys:
raise RuntimeError(
'No coordinate system (csys) available; '
'cannot use synthesisutils for cube partitioning'
)
ct = _ct()
su = ct.synthesisutils()
selpars = config.to_casa_selpars()
try:
allpars = su.cubedataimagepartition(
selpars=selpars,
incsys=csys,
npart=nparts,
nchannel=nchan,
)
finally:
su.done()
result: list[PcleanConfig] = []
total_sub_nchan = 0
for pidx in range(nparts):
part_key = str(pidx)
if part_key not in allpars:
continue
part_rec = allpars[part_key]
sub_nc = part_rec.get('nchan', nchan)
sub_start = str(part_rec.get('start', pidx))
sub = config.make_subcube_config(sub_start, sub_nc, str(pidx))
total_sub_nchan += sub_nc
result.append(sub)
if nchan > 0 and total_sub_nchan != nchan:
raise RuntimeError(
f'synthesisutils partition produced {total_sub_nchan} total '
f'channels across {nparts} subcubes, expected {nchan}'
)
return result
def _resolve_frequency_grid(
config: PcleanConfig,
nchan: int,
) -> list[float] | None:
"""Compute the actual CASA output frequency grid for the full cube.
Creates a temporary ``synthesisimager``, calls ``selectdata`` +
``defineimage`` with the full *nchan*, and reads back the
per-channel frequencies that ``MSTransformRegridder::calcChanFreqs``
produces. This gives us the *exact* grid that a monolithic
``tclean(nchan=N)`` would use, so that subcube start frequencies
are consistent with the regridded data channels.
Returns:
A list of *nchan* channel centre frequencies in Hz, or *None*
if the grid could not be resolved.
"""
import shutil
import tempfile
ct = _ct()
# Build a unique temporary imagename so that concurrent calls
# (e.g. tests) do not collide.
tmpdir = tempfile.mkdtemp(prefix='pclean_freqgrid_')
imgname = f'{tmpdir}/_freqgrid'
si = None
sn = None
try:
si = ct.synthesisimager()
selpars = config.to_casa_selpars()
for ms_key in sorted(selpars):
selrec = dict(selpars[ms_key])
selrec.setdefault('usescratch', False)
selrec.setdefault('readonly', True)
log.debug(
'_resolve_frequency_grid selectdata[%s]: msname=%r type=%s',
ms_key,
selrec.get('msname'),
type(selrec.get('msname')).__name__,
)
si.selectdata(selpars=selrec)
# Disable cube gridding so makepsf runs in-process (no
# sub-imager / normalizer setup needed for the grid query).
# setcubegridding is only available in patched casatools;
# skip gracefully on unpatched builds — the monolithic path
# still works, just slightly slower for large nchan.
if hasattr(si, 'setcubegridding'):
si.setcubegridding(False)
impars = dict(config.to_casa_impars()['0'])
impars['imagename'] = imgname
# Use a tiny 32×32 spatial grid — we only need the spectral
# axis. Memory for the full-cube PSF is therefore just
# 32×32×1×nchan×4 bytes (e.g. ~16 MB for 4000 channels),
# so even with cube gridding disabled (monolithic allocation)
# this is safe on the coordinator.
impars['imsize'] = [32, 32]
impars['nchan'] = nchan
impars['restart'] = False
gridpars = dict(config.to_casa_gridpars()['0'])
gridpars['imagename'] = imgname
si.defineimage(impars=impars, gridpars=gridpars)
# We need makepsf to materialise the image on disk so we can
# read its coordinate system. A normalizer is required for
# makepsf to succeed (it gathers/divides PSF weights).
sn = ct.synthesisnormalizer()
normpars = dict(config.to_casa_normpars()['0'])
normpars['imagename'] = imgname
sn.setupnormalizer(normpars=normpars)
si.makepsf()
sn.gatherpsfweight()
sn.dividepsfbyweight()
ia = ct.image()
cs = None
try:
ia.open(imgname + '.psf')
cs = ia.coordsys()
shape = ia.shape()
n = int(shape[3])
freqs = [float(cs.toworld([0, 0, 0, i])['numeric'][3]) for i in range(n)]
finally:
if cs is not None:
try:
cs.done()
except Exception:
pass
try:
ia.done()
except Exception:
pass
if n != nchan:
log.warning(
'Frequency grid resolution produced %d channels '
'(expected %d) — falling back to arithmetic grid',
n, nchan,
)
return None
log.info(
'Resolved frequency grid: %d channels, '
'freq[0]=%.6f GHz, delta=%.6f MHz',
n, freqs[0] / 1e9,
(freqs[1] - freqs[0]) / 1e6 if n > 1 else 0.0,
)
return freqs
except Exception as exc:
log.debug(
'Could not resolve frequency grid via defineImage: %s',
exc,
)
return None
finally:
if si is not None:
try:
si.done()
except Exception:
pass
if sn is not None:
try:
sn.done()
except Exception:
pass
shutil.rmtree(tmpdir, ignore_errors=True)
def _partition_cube_even(
config: PcleanConfig,
nparts: int,
nchan: int,
) -> list[PcleanConfig]:
"""Simple even partition of channels across *nparts* workers.
When ``start`` and ``width`` are both frequency strings the output
grid is deterministic (``start_hz + i * width_hz``) and we compute
it arithmetically — no MS access or ``makepsf()`` needed. CASA's
``MSTransformRegridder::calcChanFreqs`` produces the same
equidistant grid for frequency-specified start/width, so the
arithmetic result is exact.
When ``start`` or ``width`` are channel-based integers we fall back
to ``_resolve_frequency_grid()`` (which calls ``defineimage`` +
``makepsf`` on a tiny 32×32 image) only if ``fracbw`` is needed
for ``briggsbwtaper``.
"""
if nchan <= 0:
log.warning('nchan unknown — falling back to single partition')
nparts = 1
nchan = -1
orig_start = config.image.start
orig_width = config.image.width
start_hz = _parse_freq_hz(orig_start)
width_hz = _parse_freq_hz(orig_width)
# When start and width are both in frequency units the output grid
# is just start_hz + i * width_hz — no need to materialise an
# image via makepsf(). Only fall back to the heavy
# _resolve_frequency_grid() path for channel-based start/width
# (where we genuinely need MS metadata to map channels → Hz).
resolved_freqs: list[float] | None = None
if start_hz is not None and width_hz is not None and nchan > 1:
resolved_freqs = [start_hz + i * width_hz for i in range(nchan)]
log.info(
'Resolved frequency grid (arithmetic): %d channels, '
'freq[0]=%.6f GHz, delta=%.6f MHz',
nchan, resolved_freqs[0] / 1e9,
width_hz / 1e6,
)
# For briggsbwtaper: pre-compute fracbw from the *full* cube so that
# single-channel subcubes inherit a valid fractional bandwidth.
# Without this, nchan=1 subcubes get fracbw=0 and CASA's
# BriggsCubeWeightor rejects the value.
if (
config.weight.weighting == 'briggsbwtaper'
and config.weight.fracbw is None
and nchan > 1
):
if resolved_freqs is not None and len(resolved_freqs) >= 2:
min_f = min(resolved_freqs)
max_f = max(resolved_freqs)
config.weight.fracbw = 2.0 * (max_f - min_f) / (max_f + min_f)
elif start_hz is not None and width_hz is not None:
end_f = start_hz + (nchan - 1) * width_hz
min_f = min(start_hz, end_f)
max_f = max(start_hz, end_f)
config.weight.fracbw = 2.0 * (max_f - min_f) / (max_f + min_f)
else:
# Integer start/width: resolve frequency grid to get fracbw only.
# Do not assign to resolved_freqs — the user asked for channel-based
# partitioning, so subcube starts should remain channel indices.
freqs = _resolve_frequency_grid(config, nchan)
if freqs is not None and len(freqs) >= 2:
min_f = min(freqs)
max_f = max(freqs)
config.weight.fracbw = 2.0 * (max_f - min_f) / (max_f + min_f)
if config.weight.fracbw is not None:
log.info(
'Pre-computed fracbw=%.6g for briggsbwtaper from full cube',
config.weight.fracbw,
)
# Greedy distribution: first (nchan % nparts) subcubes get one
# extra channel, matching CASA's C++ cubedataimagepartition.
chans_per_base = nchan // nparts
remainder = nchan % nparts
# Compute the frequency-domain channel width so that subcubes whose
# ``start`` is a frequency string also carry a matching ``width``.
# Without this, CASA rejects the mixed unit types (e.g. start in
# GHz but width as a bare channel count).
freq_width: str | None = None
if resolved_freqs is not None and len(resolved_freqs) >= 2:
freq_width = _format_freq_ghz(resolved_freqs[1] - resolved_freqs[0])
elif start_hz is not None and width_hz is not None:
freq_width = _format_freq_ghz(width_hz)
result: list[PcleanConfig] = []
chan_offset = 0
for i in range(nparts):
nc = chans_per_base + (1 if i < remainder else 0)
if nc <= 0:
break
if resolved_freqs is not None:
# Use the exact frequency from the resolved grid.
sub_start = _format_freq_ghz(resolved_freqs[chan_offset])
sub_width = freq_width
elif start_hz is not None and width_hz is not None:
sub_start_hz = start_hz + chan_offset * width_hz
sub_start = _format_freq_ghz(sub_start_hz)
sub_width = freq_width
else:
sub_start = str(chan_offset)
sub_width = None
log.info(
' subcube %d: start=%s nchan=%d (chan_offset=%d)',
i,
sub_start,
nc,
chan_offset,
)
sub = config.make_subcube_config(sub_start, nc, str(i), width=sub_width)
result.append(sub)
chan_offset += nc
log.info('Even-split cube partition: %d sub-cubes, total_chan=%d', len(result), chan_offset)
return result
# ======================================================================
# Helpers for partial-image naming
# ======================================================================
[docs]
def partial_image_name(base: str, part_index: int) -> str:
"""Return the partial-image path for a given partition index."""
workdir = f'{base}.workdirectory'
return f'{workdir}/{base}.n{part_index}'