Skip to content

Commit

Permalink
Form wsclean sub-band cubes, convol and then linmos (#173)
Browse files Browse the repository at this point in the history
* workings for smoothing cubes started

* fixed cube beam getting / tests

* using from radio_beam method

* added test, smoothing beams to common res / tests

* removed strange test

* assert check for cubes in CLI

* bounding boxes to handle cubes / tests

* add cube weights to linmos and tests

* updated cli help

* updated changelog

* added --pol-axis to CLI

* added parent to holofile bind_dirs

* reforming the datacube

* writing out correct beam extension / linmos weights txt update

* comment

* update correct header card

* added initial linmos cubes

* removed type from --coadd-cubes

* added unmapped

* added  field_options.beam_cutoff appropriately

* check for filter

* added a tag for cubes

* patched bounding box check

* added for loop in convol

* added a stride

* 0 weights, nan check, strides to 8

* nan data checkl

* added a few checks, correct conv cube path return

* released numpy version constraint

* python3.11 because of spectralcube

* numpy version regress

* new test

* added to changelog

---------

Co-authored-by: tgalvin <[email protected]>
  • Loading branch information
tjgalvin and tgalvin authored Sep 24, 2024
1 parent d59d357 commit 6848f6d
Show file tree
Hide file tree
Showing 14 changed files with 852 additions and 84 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/python-lint-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ jobs:

steps:
- uses: actions/checkout@v3
- name: Set up Python 3.12
- name: Set up Python 3.11
uses: actions/setup-python@v3
with:
python-version: "3.12"
python-version: "3.11"
- name: Install casacore and boost
run: |
sudo apt install -y build-essential libcfitsio-dev liblapack-dev libboost-python-dev python3-dev wcslib-dev casacore-dev
Expand Down
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

# dev

- added in convolving of cubes to common resolution across channels
- cubes are supported when computing the yandasoft linmos weights and trimming
- `--coadd-cubes` option added to co-add cubes on the final imaging round
together to form a single field spectral cube

# 0.2.6

- if `-temp-dir` used in wsclean then imaging products are produced here and
Expand Down
214 changes: 170 additions & 44 deletions flint/coadd/linmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from argparse import ArgumentParser
from pathlib import Path
from typing import Collection, List, NamedTuple, Optional, Tuple
from typing import Collection, List, NamedTuple, Optional, Tuple, Literal

import numpy as np
from astropy.io import fits
Expand Down Expand Up @@ -37,21 +37,21 @@ class BoundingBox(NamedTuple):
ymax: int
"""Maximum y pixel"""
original_shape: Tuple[int, int]
"""The original shape of the image"""
"""The original shape of the image. If constructed against a cube this is the shape of a single plane."""


def create_bound_box(image_data: np.ndarray, is_masked: bool = False) -> BoundingBox:
"""Construct a bounding box around finite pixels for a 2D image. This does not
support cube type images.
If ``is_mask` is ``False``, the ``image_data`` will be masked internally using ``numpy.isfinite``.
def _create_bound_box_plane(
image_data: np.ndarray, is_masked: bool = False
) -> Optional[BoundingBox]:
"""Create a bounding box around pixels in a 2D image. If all
pixels are not valid, then ``None`` is returned.
Args:
image_data (np.ndarray): The image data that will have a bounding box constructed for.
is_masked (bool, optional): if this is ``True`` the ``image_data`` are treated as a boolean mask array. Defaults to False.
image_data (np.ndarray): The 2D ina==mage to construct a bounding box around
is_masked (bool, optional): Whether to treat the image as booleans or values. Defaults to False.
Returns:
BoundingBox: The tight bounding box around pixels.
Optional[BoundingBox]: None if no valid pixels, a bounding box with the (xmin,xmax,ymin,ymax) of valid pixels
"""
assert (
len(image_data.shape) == 2
Expand All @@ -60,6 +60,10 @@ def create_bound_box(image_data: np.ndarray, is_masked: bool = False) -> Boundin
# First convert to a boolean array
image_valid = image_data if is_masked else np.isfinite(image_data)

if not any(image_valid.reshape(-1)):
logger.info("No pixels to creating bounding box for")
return None

# Then make them 1D arrays
x_valid = np.any(image_valid, axis=1)
y_valid = np.any(image_valid, axis=0)
Expand All @@ -68,6 +72,62 @@ def create_bound_box(image_data: np.ndarray, is_masked: bool = False) -> Boundin
xmin, xmax = np.where(x_valid)[0][[0, -1]]
ymin, ymax = np.where(y_valid)[0][[0, -1]]

return BoundingBox(
xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax, original_shape=image_data.shape[-2:]
)


def create_bound_box(image_data: np.ndarray, is_masked: bool = False) -> BoundingBox:
"""Construct a bounding box around finite pixels for a 2D image.
If a cube ids provided, the bounding box is constructed from pixels
as broadcast across all of the non-spatial dimensions. That is to
say the single bounding box can be projected across all channel/stokes
channels
If ``is_mask` is ``False``, the ``image_data`` will be masked internally using ``numpy.isfinite``.
Args:
image_data (np.ndarray): The image data that will have a bounding box constructed for.
is_masked (bool, optional): if this is ``True`` the ``image_data`` are treated as a boolean mask array. Defaults to False.
Returns:
BoundingBox: The tight bounding box around pixels.
"""
reshaped_image_data = image_data.reshape((-1, *image_data.shape[-2:]))
logger.info(f"New image shape {reshaped_image_data.shape} from {image_data.shape}")

bounding_boxes = [
_create_bound_box_plane(image_data=image, is_masked=is_masked)
for image in reshaped_image_data
]
bounding_boxes = [bb for bb in bounding_boxes if bb is not None]

if len(bounding_boxes) == 0:
logger.info("No valid bounding box found. Constructing one for all pixels")
return BoundingBox(
xmin=0,
xmax=image_data.shape[-1] - 1,
ymin=0,
ymax=image_data.shape[-2] - 1,
original_shape=tuple(image_data.shape[-2:]), # type: ignore
)
elif len(bounding_boxes) == 1:
assert bounding_boxes[0] is not None, "This should not happen"
return bounding_boxes[0]

assert all([bb is not None for bb in bounding_boxes])

logger.info(
f"Boounding boxes across {len(bounding_boxes)} constructed. Finsing limits. "
)
# The type ignores below are to avoid mypy believe bound_boxes could
# include None. The above checks should be sufficient
xmin = min([bb.xmin for bb in bounding_boxes]) # type: ignore
xmax = max([bb.xmax for bb in bounding_boxes]) # type: ignore
ymin = min([bb.ymin for bb in bounding_boxes]) # type: ignore
ymax = max([bb.ymax for bb in bounding_boxes]) # type: ignore

return BoundingBox(
xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax, original_shape=image_data.shape
)
Expand Down Expand Up @@ -100,15 +160,17 @@ def trim_fits_image(
data = fits_image[0].data # type: ignore
logger.info(f"Original data shape: {data.shape}")

image_shape = (data.shape[-2], data.shape[-1])
image_shape = data.shape[-2:]
logger.info(f"The image dimensions are: {image_shape}")

if not bounding_box:
logger.info("Constructing a new bounding box")
bounding_box = create_bound_box(
image_data=np.squeeze(data), is_masked=False
)
logger.info(f"Constructed {bounding_box=}")
else:
if image_shape != bounding_box.original_shape:
if image_shape != bounding_box.original_shape[-2:]:
raise ValueError(
f"Bounding box constructed against {bounding_box.original_shape}, but being applied to {image_shape=}"
)
Expand All @@ -130,9 +192,53 @@ def trim_fits_image(
return TrimImageResult(path=image_path, bounding_box=bounding_box)


def get_image_weight(
image_path: Path, mode: str = "mad", image_slice: int = 0
def _get_image_weight_plane(
image_data: np.ndarray, mode: Literal["std", "mad"] = "mad", stride: int = 4
) -> float:
"""Extract the inverse variance weight for an input plane of data
Modes are 'std' or 'mad'.
Args:
image_data (np.ndarray): Data to consider
mode (str, optional): Statistic computation mode. Defaults to "mad".
stride (int, optional): Include every n'th pixel when computing the weight. '1' includes all pixels. Defaults to 1.
Raises:
ValueError: Raised when modes unknown
Returns:
float: The inverse variance weight computerd
"""

weight_modes = ("mad", "std")
assert (
mode in weight_modes
), f"Invalid {mode=} specified. Available modes: {weight_modes}"

# remove non-finite numbers that would ruin the statistic
image_data = image_data[np.isfinite(image_data)][::stride]

if np.all(~np.isfinite(image_data)):
return 0.0

if mode == "mad":
median = np.median(image_data)
mad = np.median(np.abs(image_data - median))
weight = 1.0 / mad**2
elif mode == "std":
std = np.std(image_data)
weight = 1.0 / std**2
else:
raise ValueError(f"Invalid {mode=} specified. Available modes: {weight_modes}")

float_weight = float(weight)
return float_weight if np.isfinite(float_weight) else 0.0


def get_image_weight(
image_path: Path, mode: str = "mad", stride: int = 1, image_slice: int = 0
) -> List[float]:
"""Compute an image weight supplied to linmos, which is used for optimally
weighting overlapping images. Supported modes are 'mad' and 'mtd', which
simply resolve to their numpy equivalents.
Expand All @@ -142,55 +248,57 @@ def get_image_weight(
the same way, it does not necessarily have to correspond to an optimatelly
calculated RMS.
The stride parameter will only include every N'th pixel when computing the
weights. A smaller set of pixels will reduce the time required to calculate
the weights, but may come at the cost of accuracy with large values.
Args:
image (Path): The path to the image fits file to inspect.
mode (str, optional): Which mode should be used when calculating the weight. Defaults to 'mad'.
stride (int, optional): Include every n'th pixel when computing the weight. '1' includes all pixels. Defaults to 1.
image_slice (int, optional): The image slice in the HDU list of the `image` fits file to inspect. Defaults to 0.
Raises:
ValueError: Raised when a mode is requested but does not exist
Returns:
float: The weight to supply to linmos
List[float]: The weight per channel to supply to linmos
"""

logger.debug(
f"Computing linmos weight using {mode=}, {image_slice=} for {image_path}. "
)
weight_modes = ("mad", "std")
assert (
mode in weight_modes
), f"Invalid {mode=} specified. Available modes: {weight_modes}"

weights: List[float] = []
with fits.open(image_path, memmap=True) as in_fits:
image_data = in_fits[image_slice].data # type: ignore

assert len(
image_data.shape
assert (
len(image_data.shape) >= 2
), f"{len(image_data.shape)=} is less than two. Is this really an image?"

# remove non-finite numbers
image_data = image_data[np.isfinite(image_data)]

logger.debug(f"Data shape is: {image_data.shape}")
if mode == "mad":
median = np.median(image_data)
mad = np.median(np.abs(image_data - median))
weight = 1.0 / mad**2
elif mode == "std":
std = np.std(image_data)
weight = 1.0 / std**2
else:
raise ValueError(
f"Invalid {mode=} specified. Available modes: {weight_modes}"
)
image_shape = image_data.shape[-2:]
image_data = (
image_data.reshape((-1, *image_shape))
if len(image_data.shape)
else image_data
)

assert (
len(image_data.shape) == 3
), f"Expected to have shape (chan, dec, ra), got {image_data.shape}"

for idx, chan_image_data in enumerate(image_data):
weight = _get_image_weight_plane(image_data=chan_image_data, stride=stride)
logger.info(f"Channel {idx} {weight=:.3f} for {image_path}")

weights.append(weight)

logger.info(f"Weight {weight:.3f} for {image_path}")
return float(weight)
return weights


def generate_weights_list_and_files(
image_paths: Collection[Path], mode: str = "mad"
image_paths: Collection[Path], mode: str = "mad", stride: int = 1
) -> str:
"""Generate the expected linmos weight files, and construct an appropriate
string that can be embedded into a linmos partset. These weights files will
Expand All @@ -207,6 +315,10 @@ def generate_weights_list_and_files(
the moment it is only intended to work on MFS images. It __is not__ currently
intended to be used on image cubes.
The stride parameter will only include every N'th pixel when computing the
weights. A smaller set of pixels will reduce the time required to calculate
the weights, but may come at the cost of accuracy with large values.
Args:
image_paths (Collection[Path]): Images to iterate over to create a corresponding weights.txt file.
mode (str, optional): The mode to use when calling get_image_weight
Expand All @@ -232,8 +344,12 @@ def generate_weights_list_and_files(
with open(weight_file, "w") as out_file:
logger.info(f"Writing {weight_file}")
out_file.write("#Channel Weight\n")
image_weight = get_image_weight(image_path=image, mode=mode)
out_file.write(f"0 {image_weight}\n")
image_weights = get_image_weight(image_path=image, mode=mode, stride=stride)
weights = "\n".join(
[f"{idx} {weight}" for idx, weight in enumerate(image_weights)]
)
out_file.write(weights)
out_file.write("\n") # Required for linmos to properly process weights

weight_str = [
str(weight_file) for weight_file in weight_file_list if weight_file.exists()
Expand Down Expand Up @@ -355,7 +471,9 @@ def generate_linmos_parameter_set(
# quality. In reality, this should be updated to provide a RMS noise
# estimate per-pixel of each image.
if weight_list is None:
weight_list = generate_weights_list_and_files(image_paths=images, mode="mad")
weight_list = generate_weights_list_and_files(
image_paths=images, mode="mad", stride=8
)

beam_order_strs = [str(extract_beam_from_name(str(p.name))) for p in images]
beam_order_list = "[" + ",".join(beam_order_strs) + "]"
Expand Down Expand Up @@ -446,7 +564,7 @@ def linmos_images(
linmos_parset.absolute().parent
]
if holofile:
bind_dirs.append(holofile.absolute())
bind_dirs.append(holofile.absolute().parent)

run_singularity_command(
image=container, command=linmos_cmd_str, bind_dirs=bind_dirs
Expand Down Expand Up @@ -501,6 +619,12 @@ def get_parser() -> ArgumentParser:
default=None,
help="Path to the holography FITS cube used for primary beam corrections",
)
parset_parser.add_argument(
"--pol-axis",
type=float,
default=2 * np.pi / 8,
help="The rotation in radians of the third-axis of the obseration. Defaults to PI/4",
)
parset_parser.add_argument(
"--yandasoft-container",
type=Path,
Expand All @@ -509,7 +633,7 @@ def get_parser() -> ArgumentParser:
)

trim_parser = subparsers.add_parser(
"trim", help="Generate a yandasoft linmos parset"
"trim", help="Remove blank border of FITS image"
)

trim_parser.add_argument(
Expand All @@ -532,6 +656,7 @@ def cli() -> None:
linmos_names=linmos_names,
weight_list=args.weight_list,
holofile=args.holofile,
pol_axis=args.pol_axis,
)
else:
linmos_images(
Expand All @@ -541,6 +666,7 @@ def cli() -> None:
weight_list=args.weight_list,
holofile=args.holofile,
container=args.yandasoft_container,
pol_axis=args.pol_axis,
)
elif args.mode == "trim":
images = args.images
Expand Down
Loading

0 comments on commit 6848f6d

Please sign in to comment.