Skip to content

Commit

Permalink
Merge pull request #78 from tjgalvin/maskoptions
Browse files Browse the repository at this point in the history
Created a MaskingOptions structure
  • Loading branch information
tjgalvin authored Apr 9, 2024
2 parents bf3d60d + c392435 commit eb3cbee
Show file tree
Hide file tree
Showing 7 changed files with 97 additions and 179 deletions.
1 change: 0 additions & 1 deletion flint/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ def get_selfcal_options_from_yaml(input_yaml: Optional[Path] = None) -> Dict:
2: {"solint": "30s", "calmode": "p", "uvrange": ">235m", "nspw": 4},
3: {"solint": "60s", "calmode": "ap", "uvrange": ">235m", "nspw": 4},
4: {"solint": "30s", "calmode": "ap", "uvrange": ">235m", "nspw": 4},
5: {"solint": "30s", "calmode": "ap", "uvrange": ">235m", "nspw": 1},
}


Expand Down
198 changes: 64 additions & 134 deletions flint/masking.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
thought being towards FITS images.
"""

from __future__ import annotations

from argparse import ArgumentParser
from pathlib import Path
from typing import Optional, Tuple
from typing import NamedTuple, Optional

import numpy as np
from astropy.io import fits
Expand All @@ -17,13 +19,46 @@
binary_erosion as scipy_binary_erosion, # Rename to distinguish from skimage
)
from scipy.ndimage import label
from skimage.filters import butterworth
from skimage.morphology import binary_erosion

from flint.logging import logger
from flint.naming import FITSMaskNames, create_fits_mask_names


class MaskingOptions(NamedTuple):
"""Contains options for the creation of clean masks from some subject
image. Clipping levels specified are in units of RMS (or sigma). They
are NOT in absolute units.
"""

base_snr_clip: float = 4
"""A base clipping level to be used should other options not be activated"""
flood_fill: bool = True
"""Whether to attempt to flood fill when constructing a mask. This should be `True` for `grow_low_snr_islands` and `suppress_artefacts to have an effect. """
flood_fill_positive_seed_clip: float = 4.5
"""The clipping level to seed islands that will be grown to lower SNR"""
flood_fill_positive_flood_clip: float = 1.5
"""Clipping level used to grow seeded islands down to"""
suppress_artefacts: bool = True
"""Whether to attempt artefacts based on the presence of sigificant negatives"""
suppress_artefacts_negative_seed_clip: Optional[float] = 5
"""The significance level of a negative island for the sidelobe suppresion to be activated. This should be a positive number (the signal map is internally inverted)"""
suppress_artefacts_guard_negative_dilation: float = 40
"""The minimum positive signifance pixels should have to be guarded when attempting to suppress artefacts around bright sources"""
grow_low_snr_islands: bool = True
"""Whether to attempt to grow a mask to capture islands of low SNR (e.g. diffuse emission)"""
grow_low_snr_clip: float = 1.75
"""The minimum signifance levels of pixels to be to seed low SNR islands for consideration"""
grow_low_snr_island_size: int = 768
"""The number of pixels an island has to be for it to be accepted"""

def with_options(self, **kwargs) -> MaskingOptions:
"""Return a new instance of the MaskingOptions"""
_dict = self._asdict()
_dict.update(**kwargs)

return MaskingOptions(**_dict)


def extract_beam_mask_from_mosaic(
fits_beam_image_path: Path, fits_mosaic_mask_names: FITSMaskNames
) -> FITSMaskNames:
Expand Down Expand Up @@ -165,7 +200,8 @@ def reverse_negative_flood_fill(
signal: Optional[np.ndarray] = None,
positive_seed_clip: float = 4,
positive_flood_clip: float = 2,
negative_seed_clip: Optional[float] = 5,
suppress_artefacts: bool = False,
negative_seed_clip: float = 5,
guard_negative_dilation: float = 50,
grow_low_snr: Optional[float] = 2,
grow_low_island_size: int = 512,
Expand Down Expand Up @@ -205,6 +241,7 @@ def reverse_negative_flood_fill(
signal(Optional[np.ndarray], optional): A signal map. Defaults to None.
positive_seed_clip (float, optional): Initial clip of the mask before islands are grown. Defaults to 4.
positive_flood_clip (float, optional): Pixels above `positive_seed_clip` are dilated to this threshold. Defaults to 2.
suppress_artefacts (boo, optional): Attempt to suppress regions around presumed artefacts. Defaults to False.
negative_seed_clip (Optional[float], optional): Initial clip of negative pixels. This operation is on the inverted signal mask (so this value should be a positive number). If None this second operation is not performed. Defaults to 5.
guard_negative_dilation (float, optional): Positive pixels from the computed signal mask will be above this threshold to be protect from the negative island mask dilation. Defaults to 50.
grow_low__snr (Optional[float], optional): Attempt to grow islands of contigous pixels above thius low SNR ration. If None this is not performed. Defaults to 2.
Expand Down Expand Up @@ -239,14 +276,16 @@ def reverse_negative_flood_fill(
structure=np.ones((3, 3)),
)

# TODO: This function should be divided up
negative_dilated_mask = None

# Now do the same but on negative islands. The assumption here is that:
# - no genuine source of negative sky emission
# - negative islands are around bright sources with deconvolution/calibration errors
# - if there are brightish negative islands there is also positive brightish arteefact islands nearby
# For this reason the guard mask should be sufficently high to protect the
# main source but nuke the fask positive islands
negative_dilated_mask = None
if negative_seed_clip:
if suppress_artefacts:
negative_mask = negative_signal > negative_seed_clip
negative_dilated_mask = scipy_binary_dilation(
input=negative_mask,
Expand All @@ -270,110 +309,12 @@ def reverse_negative_flood_fill(
return positive_dilated_mask.astype(np.int32)


def create_snr_mask_wbutter_from_fits(
fits_image_path: Path,
fits_rms_path: Path,
fits_bkg_path: Path,
create_signal_fits: bool = False,
min_snr: float = 5,
connectivity_shape: Tuple[int, int] = (4, 4),
overwrite: bool = True,
) -> FITSMaskNames:
"""Create a mask for an input FITS image based on a signal to noise given a corresponding pair of RMS and background FITS images.
Internally the signal image is computed as something akin to:
> signal = (image - background) / rms
Before deriving a signal map the image is first smoothed using a butterworth filter, and a
crude rescaling factor is applied based on the ratio of the maximum pixel values before and
after the smoothing is applied.
This is done in a staged manner to minimise the number of (potentially large) images
held in memory.
Each of the input images needs to share the same shape. This means that compression
features offered by some tooling (e.g. BANE --compress) can not be used.
Once the signal map as been computed, all pixels below ``min_snr`` are flagged. The resulting
islands then have a binary erosion applied to contract the resultingn islands.
Args:
fits_image_path (Path): Path to the FITS file containing an image
fits_rms_path (Path): Path to the FITS file with an RMS image corresponding to ``fits_image_path``
fits_bkg_path (Path): Path to the FITS file with an baclground image corresponding to ``fits_image_path``
create_signal_fits (bool, optional): Create an output signal map. Defaults to False.
min_snr (float, optional): Minimum signal-to-noise ratio for the masking to include a pixel. Defaults to 3.5.
connectivity_shape (Tuple[int, int], optional): The connectivity matrix used in the scikit-image binary erosion applied to the mask. Defaults to (4, 4).
overwrite (bool): Passed to `fits.writeto`, and will overwrite files should they exist. Defaults to True.
Returns:
FITSMaskNames: Container describing the signal and mask FITS image paths. If ``create_signal_path`` is None, then the ``signal_fits`` attribute will be None.
"""
logger.info(f"Creating a mask image with SNR>{min_snr:.2f}")
mask_names = create_fits_mask_names(
fits_image=fits_image_path, include_signal_path=create_signal_fits
)

with fits.open(fits_image_path) as fits_image:
fits_header = fits_image[0].header

image_max = np.nanmax(fits_image[0].data)
image_butter = butterworth(
np.nan_to_num(np.squeeze(fits_image[0].data)), 0.045, high_pass=False
)

butter_max = np.nanmax(image_butter)
scale_ratio = image_max / butter_max
logger.info(f"Scaling smoothed image by {scale_ratio:.4f}")
image_butter *= image_max / butter_max

with fits.open(fits_bkg_path) as fits_bkg:
logger.info("Subtracting background")
signal_data = image_butter - np.squeeze(fits_bkg[0].data)

with fits.open(fits_rms_path) as fits_rms:
logger.info("Dividing by RMS")
signal_data /= np.squeeze(fits_rms[0].data)

if create_signal_fits:
logger.info(f"Writing {mask_names.signal_fits}")
fits.writeto(
filename=mask_names.signal_fits,
data=signal_data,
header=fits_header,
overwrite=overwrite,
)

# Following the help in wsclean:
# WSClean accepts masks in CASA format and in fits file format. A mask is a
# normal, single polarization image file, where all zero values are interpreted
# as being not masked, and all non-zero values are interpreted as masked. In the
# case of a fits file, the file may either contain a single frequency or it may
# contain a cube of images.
logger.info(f"Clipping using a {min_snr=}")
mask_data = (signal_data > min_snr).astype(np.int32)

logger.info(f"Applying binary erosion with {connectivity_shape=}")
mask_data = binary_erosion(mask_data, np.ones(connectivity_shape))

logger.info(f"Writing {mask_names.mask_fits}")
fits.writeto(
filename=mask_names.mask_fits,
data=mask_data.astype(np.int32),
header=fits_header,
overwrite=overwrite,
)

return mask_names


def create_snr_mask_from_fits(
fits_image_path: Path,
fits_rms_path: Path,
fits_bkg_path: Path,
masking_options: MaskingOptions,
create_signal_fits: bool = False,
min_snr: float = 3.5,
attempt_reverse_nergative_flood_fill: bool = True,
overwrite: bool = True,
) -> FITSMaskNames:
"""Create a mask for an input FITS image based on a signal to noise given a corresponding pair of RMS and background FITS images.
Expand All @@ -393,15 +334,13 @@ def create_snr_mask_from_fits(
fits_image_path (Path): Path to the FITS file containing an image
fits_rms_path (Path): Path to the FITS file with an RMS image corresponding to ``fits_image_path``
fits_bkg_path (Path): Path to the FITS file with an baclground image corresponding to ``fits_image_path``
masking_options (MaskingOptions): Configurables on the masking operation procedure.
create_signal_fits (bool, optional): Create an output signal map. Defaults to False.
min_snr (float, optional): Minimum signal-to-noise ratio for the masking to include a pixel. Defaults to 3.5.
attempt_negative_flood_fill (bool): Attempt to filter out negative sidelobes from the bask. See `reverse_negative_flood_fill`. Defaults to True.
overwrite (bool): Passed to `fits.writeto`, and will overwrite files should they exist. Defaults to True.
Returns:
FITSMaskNames: Container describing the signal and mask FITS image paths. If ``create_signal_path`` is None, then the ``signal_fits`` attribute will be None.
"""
logger.info(f"Creating a mask image with SNR>{min_snr:.2f}")
mask_names = create_fits_mask_names(
fits_image=fits_image_path, include_signal_path=create_signal_fits
)
Expand Down Expand Up @@ -431,20 +370,20 @@ def create_snr_mask_from_fits(
# as being not masked, and all non-zero values are interpreted as masked. In the
# case of a fits file, the file may either contain a single frequency or it may
# contain a cube of images.
if attempt_reverse_nergative_flood_fill:
if masking_options.flood_fill:
mask_data = reverse_negative_flood_fill(
signal=np.squeeze(signal_data),
positive_seed_clip=4,
positive_flood_clip=1.5,
negative_seed_clip=4,
guard_negative_dilation=30,
grow_low_snr=1.75,
grow_low_island_size=768,
positive_seed_clip=masking_options.flood_fill_positive_seed_clip,
positive_flood_clip=masking_options.flood_fill_positive_flood_clip,
negative_seed_clip=masking_options.suppress_artefacts_negative_seed_clip,
guard_negative_dilation=masking_options.suppress_artefacts_guard_negative_dilation,
grow_low_snr=masking_options.grow_low_snr_clip,
grow_low_island_size=masking_options.grow_low_snr_island_size,
)
mask_data = mask_data.reshape(signal_data.shape)
else:
logger.info(f"Clipping using a {min_snr=}")
mask_data = (signal_data > min_snr).astype(np.int32)
logger.info(f"Clipping using a {masking_options.base_snr_clip=}")
mask_data = (signal_data > masking_options.base_snr_clip).astype(np.int32)

logger.info(f"Writing {mask_names.mask_fits}")
fits.writeto(
Expand Down Expand Up @@ -529,23 +468,14 @@ def cli():
args = parser.parse_args()

if args.mode == "snrmask":
if args.user_butterworth:
create_snr_mask_wbutter_from_fits(
fits_image_path=args.image,
fits_rms_path=args.rms,
fits_bkg_path=args.bkg,
create_signal_fits=args.save_signal,
min_snr=args.min_snr,
connectivity_shape=tuple(args.connectivity_shape),
)
else:
create_snr_mask_from_fits(
fits_image_path=args.image,
fits_rms_path=args.rms,
fits_bkg_path=args.bkg,
create_signal_fits=args.save_signal,
min_snr=args.min_snr,
)
masking_options = MaskingOptions(base_snr_clip=args.min_snr)
create_snr_mask_from_fits(
fits_image_path=args.image,
fits_rms_path=args.rms,
fits_bkg_path=args.bkg,
create_signal_fits=args.save_signal,
masking_options=masking_options,
)

elif args.mode == "extractmask":
extract_beam_mask_from_mosaic(
Expand Down
4 changes: 1 addition & 3 deletions flint/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,9 @@ class FieldOptions(NamedTuple):
"""Primary beam attentuation cutoff to use during linmos"""
use_preflagger: bool = True
"""Whether to apply (or search for solutions with) bandpass solutions that have gone through the preflagging operations"""
use_smoothed: bool = True
use_smoothed: bool = False
"""Whether to apply (or search for solutions with) a bandpass smoothing operation applied"""
use_beam_masks: bool = True
"""Construct beam masks from MFS images to use for the next round of imaging. """
use_beam_masks_from: int = 2
"""If `use_beam_masks` is True, start using them from this round of self-calibration"""
use_beam_mask_wbutterworth: bool = False
"""If `use_beam_masks` is True, this will specify whether a Butterworth filter is used to smooth the image before the S/N clip is applied"""
35 changes: 15 additions & 20 deletions flint/prefect/common/imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from flint.imager.wsclean import ImageSet, WSCleanCommand, wsclean_imager
from flint.logging import logger
from flint.masking import (
MaskingOptions,
create_snr_mask_from_fits,
create_snr_mask_wbutter_from_fits,
extract_beam_mask_from_mosaic,
)
from flint.ms import MS, preprocess_askap_ms, rename_column_in_ms, split_by_field
Expand Down Expand Up @@ -462,7 +462,7 @@ def task_create_image_mask_model(
image: Union[LinmosCommand, ImageSet, WSCleanCommand],
image_products: AegeanOutputs,
min_snr: Optional[float] = 3.5,
with_butterworth: bool = False,
update_masking_options: Optional[Dict[str, Any]] = None,
) -> FITSMaskNames:
"""Create a mask from a linmos image, with the intention of providing it as a clean mask
to an appropriate imager. This is derived using a simple signal to noise cut.
Expand All @@ -471,7 +471,8 @@ def task_create_image_mask_model(
linmos_parset (LinmosCommand): Linmos command and associated meta-data
image_products (AegeanOutputs): Images of the RMS and BKG
min_snr (float, optional): The minimum S/N a pixel should be for it to be included in the clean mask.
with_butterworth (bool, optional): whether to taper the input image with a Butterworth filter before masking.
update_masking_options (Optional[Dict[str,Any]], optional): Updated options supplied to the default MaskingOptions. Defaults to None.
Raises:
ValueError: Raised when ``image_products`` are not known
Expand Down Expand Up @@ -502,23 +503,17 @@ def task_create_image_mask_model(
logger.info(f"Using {source_rms=}")
logger.info(f"Using {source_bkg=}")

if with_butterworth:
mask_names = create_snr_mask_wbutter_from_fits(
fits_image_path=source_image,
fits_bkg_path=source_bkg,
fits_rms_path=source_rms,
create_signal_fits=True,
min_snr=min_snr,
connectivity_shape=(3, 3),
)
else:
mask_names = create_snr_mask_from_fits(
fits_image_path=source_image,
fits_bkg_path=source_bkg,
fits_rms_path=source_rms,
create_signal_fits=True,
min_snr=min_snr,
)
masking_options = MaskingOptions()
if update_masking_options:
masking_options = masking_options.with_options(**update_masking_options)

mask_names = create_snr_mask_from_fits(
fits_image_path=source_image,
fits_bkg_path=source_bkg,
fits_rms_path=source_rms,
masking_options=masking_options,
create_signal_fits=True,
)

logger.info(f"Created {mask_names.mask_fits}")

Expand Down
Loading

0 comments on commit eb3cbee

Please sign in to comment.