Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Form wsclean sub-band cubes, convol and then linmos #173

Merged
merged 33 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
34209b3
workings for smoothing cubes started
Sep 13, 2024
88e8450
fixed cube beam getting / tests
Sep 14, 2024
094b2ab
using from radio_beam method
Sep 14, 2024
b5f52ca
added test, smoothing beams to common res / tests
Sep 15, 2024
f492472
removed strange test
Sep 15, 2024
cb19c73
assert check for cubes in CLI
Sep 15, 2024
7ec5f56
bounding boxes to handle cubes / tests
Sep 16, 2024
52b26b3
add cube weights to linmos and tests
Sep 16, 2024
0c25fa2
updated cli help
Sep 16, 2024
9e24ea4
updated changelog
Sep 16, 2024
a6d63d7
added --pol-axis to CLI
Sep 16, 2024
55428a0
added parent to holofile bind_dirs
Sep 16, 2024
12d27fd
reforming the datacube
Sep 16, 2024
324e825
writing out correct beam extension / linmos weights txt update
Sep 17, 2024
6312cde
comment
Sep 17, 2024
aecd201
update correct header card
Sep 17, 2024
27dd197
added initial linmos cubes
Sep 17, 2024
16a4f8d
removed type from --coadd-cubes
Sep 17, 2024
6b4cc8a
added unmapped
Sep 17, 2024
05ebc9c
added field_options.beam_cutoff appropriately
Sep 17, 2024
fb345b6
check for filter
Sep 17, 2024
5b1b83b
added a tag for cubes
Sep 17, 2024
4bdd3f4
patched bounding box check
Sep 18, 2024
17152e9
added for loop in convol
Sep 19, 2024
2fdfede
added a stride
Sep 20, 2024
65af5cb
0 weights, nan check, strides to 8
Sep 20, 2024
73dc2a8
nan data checkl
Sep 21, 2024
7cadf6a
added a few checks, correct conv cube path return
Sep 21, 2024
ef2c0ac
released numpy version constraint
Sep 22, 2024
b78726a
python3.11 because of spectralcube
Sep 22, 2024
f9378ba
numpy version regress
Sep 22, 2024
d5e2bc0
new test
Sep 22, 2024
7eea24d
added to changelog
Sep 22, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/python-lint-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ jobs:

steps:
- uses: actions/checkout@v3
- name: Set up Python 3.12
- name: Set up Python 3.11
uses: actions/setup-python@v3
with:
python-version: "3.12"
python-version: "3.11"
- name: Install casacore and boost
run: |
sudo apt install -y build-essential libcfitsio-dev liblapack-dev libboost-python-dev python3-dev wcslib-dev casacore-dev
Expand Down
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

# dev

- added in convolving of cubes to common resolution across channels
- cubes are supported when computing the yandasoft linmos weights and trimming
- `--coadd-cubes` option added to co-add cubes on the final imaging round
together to form a single field spectral cube

# 0.2.6

- if `-temp-dir` used in wsclean then imaging products are produced here and
Expand Down
214 changes: 170 additions & 44 deletions flint/coadd/linmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from argparse import ArgumentParser
from pathlib import Path
from typing import Collection, List, NamedTuple, Optional, Tuple
from typing import Collection, List, NamedTuple, Optional, Tuple, Literal

import numpy as np
from astropy.io import fits
Expand Down Expand Up @@ -37,21 +37,21 @@ class BoundingBox(NamedTuple):
ymax: int
"""Maximum y pixel"""
original_shape: Tuple[int, int]
"""The original shape of the image"""
"""The original shape of the image. If constructed against a cube this is the shape of a single plane."""


def create_bound_box(image_data: np.ndarray, is_masked: bool = False) -> BoundingBox:
"""Construct a bounding box around finite pixels for a 2D image. This does not
support cube type images.

If ``is_mask` is ``False``, the ``image_data`` will be masked internally using ``numpy.isfinite``.
def _create_bound_box_plane(
image_data: np.ndarray, is_masked: bool = False
) -> Optional[BoundingBox]:
"""Create a bounding box around pixels in a 2D image. If all
pixels are not valid, then ``None`` is returned.

Args:
image_data (np.ndarray): The image data that will have a bounding box constructed for.
is_masked (bool, optional): if this is ``True`` the ``image_data`` are treated as a boolean mask array. Defaults to False.
image_data (np.ndarray): The 2D image to construct a bounding box around
is_masked (bool, optional): Whether to treat the image as booleans or values. Defaults to False.

Returns:
BoundingBox: The tight bounding box around pixels.
Optional[BoundingBox]: None if no valid pixels, a bounding box with the (xmin,xmax,ymin,ymax) of valid pixels
"""
assert (
len(image_data.shape) == 2
Expand All @@ -60,6 +60,10 @@ def create_bound_box(image_data: np.ndarray, is_masked: bool = False) -> Boundin
# First convert to a boolean array
image_valid = image_data if is_masked else np.isfinite(image_data)

if not any(image_valid.reshape(-1)):
logger.info("No pixels to creating bounding box for")
return None

# Then make them 1D arrays
x_valid = np.any(image_valid, axis=1)
y_valid = np.any(image_valid, axis=0)
Expand All @@ -68,6 +72,62 @@ def create_bound_box(image_data: np.ndarray, is_masked: bool = False) -> Boundin
xmin, xmax = np.where(x_valid)[0][[0, -1]]
ymin, ymax = np.where(y_valid)[0][[0, -1]]

return BoundingBox(
xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax, original_shape=image_data.shape[-2:]
)


def create_bound_box(image_data: np.ndarray, is_masked: bool = False) -> BoundingBox:
    """Construct a bounding box around the valid pixels of an image or cube.

    If a cube is provided, a bounding box is computed for every plane (all
    leading, non-spatial dimensions are flattened together) and the union of
    those boxes is returned. That is to say the single bounding box can be
    projected across all channel/stokes planes.

    If ``is_masked`` is ``False``, the ``image_data`` will be masked internally
    using ``numpy.isfinite``.

    Args:
        image_data (np.ndarray): The image data that will have a bounding box constructed for. The last two axes are treated as the image plane.
        is_masked (bool, optional): if this is ``True`` the ``image_data`` are treated as a boolean mask array. Defaults to False.

    Returns:
        BoundingBox: The tight bounding box around valid pixels. Its ``original_shape`` is always the shape of a single plane. If no plane contains a valid pixel the box spans the whole plane.
    """
    plane_shape = tuple(image_data.shape[-2:])

    reshaped_image_data = image_data.reshape((-1, *plane_shape))
    logger.info(f"New image shape {reshaped_image_data.shape} from {image_data.shape}")

    # Planes without any valid pixel yield ``None`` and are dropped here, so
    # everything left in the list is a real bounding box.
    candidate_boxes = (
        _create_bound_box_plane(image_data=image, is_masked=is_masked)
        for image in reshaped_image_data
    )
    bounding_boxes = [bb for bb in candidate_boxes if bb is not None]

    if not bounding_boxes:
        logger.info("No valid bounding box found. Constructing one for all pixels")
        # The per-plane helper derives x from the second-to-last axis and y
        # from the last axis, so the fallback uses the same convention.
        return BoundingBox(
            xmin=0,
            xmax=plane_shape[0] - 1,
            ymin=0,
            ymax=plane_shape[1] - 1,
            original_shape=plane_shape,  # type: ignore
        )

    if len(bounding_boxes) == 1:
        return bounding_boxes[0]

    logger.info(
        f"Boounding boxes across {len(bounding_boxes)} constructed. Finsing limits. "
    )
    # Union of the per-plane boxes: smallest minimum and largest maximum
    xmin = min(bb.xmin for bb in bounding_boxes)
    xmax = max(bb.xmax for bb in bounding_boxes)
    ymin = min(bb.ymin for bb in bounding_boxes)
    ymax = max(bb.ymax for bb in bounding_boxes)

    # Record the plane shape (not the full cube shape) so that every return
    # path reports ``original_shape`` consistently.
    return BoundingBox(
        xmin=xmin,
        xmax=xmax,
        ymin=ymin,
        ymax=ymax,
        original_shape=plane_shape,  # type: ignore
    )
Expand Down Expand Up @@ -100,15 +160,17 @@ def trim_fits_image(
data = fits_image[0].data # type: ignore
logger.info(f"Original data shape: {data.shape}")

image_shape = (data.shape[-2], data.shape[-1])
image_shape = data.shape[-2:]
logger.info(f"The image dimensions are: {image_shape}")

if not bounding_box:
logger.info("Constructing a new bounding box")
bounding_box = create_bound_box(
image_data=np.squeeze(data), is_masked=False
)
logger.info(f"Constructed {bounding_box=}")
else:
if image_shape != bounding_box.original_shape:
if image_shape != bounding_box.original_shape[-2:]:
raise ValueError(
f"Bounding box constructed against {bounding_box.original_shape}, but being applied to {image_shape=}"
)
Expand All @@ -130,9 +192,53 @@ def trim_fits_image(
return TrimImageResult(path=image_path, bounding_box=bounding_box)


def get_image_weight(
image_path: Path, mode: str = "mad", image_slice: int = 0
def _get_image_weight_plane(
image_data: np.ndarray, mode: Literal["std", "mad"] = "mad", stride: int = 4
) -> float:
"""Extract the inverse variance weight for an input plane of data

Modes are 'std' or 'mad'.

Args:
image_data (np.ndarray): Data to consider
mode (str, optional): Statistic computation mode. Defaults to "mad".
stride (int, optional): Include every n'th pixel when computing the weight. '1' includes all pixels. Defaults to 1.

Raises:
ValueError: Raised when modes unknown

Returns:
float: The inverse variance weight computerd
"""

weight_modes = ("mad", "std")
assert (
mode in weight_modes
), f"Invalid {mode=} specified. Available modes: {weight_modes}"

# remove non-finite numbers that would ruin the statistic
image_data = image_data[np.isfinite(image_data)][::stride]

if np.all(~np.isfinite(image_data)):
return 0.0

if mode == "mad":
median = np.median(image_data)
mad = np.median(np.abs(image_data - median))
weight = 1.0 / mad**2
elif mode == "std":
std = np.std(image_data)
weight = 1.0 / std**2
else:
raise ValueError(f"Invalid {mode=} specified. Available modes: {weight_modes}")

float_weight = float(weight)
return float_weight if np.isfinite(float_weight) else 0.0


def get_image_weight(
image_path: Path, mode: str = "mad", stride: int = 1, image_slice: int = 0
) -> List[float]:
"""Compute an image weight supplied to linmos, which is used for optimally
weighting overlapping images. Supported modes are 'mad' and 'std', which
simply resolve to their numpy equivalents.
Expand All @@ -142,55 +248,57 @@ def get_image_weight(
the same way, it does not necessarily have to correspond to an optimally
calculated RMS.

The stride parameter will only include every N'th pixel when computing the
weights. A smaller set of pixels will reduce the time required to calculate
the weights, but may come at the cost of accuracy with large values.

Args:
image (Path): The path to the image fits file to inspect.
mode (str, optional): Which mode should be used when calculating the weight. Defaults to 'mad'.
stride (int, optional): Include every n'th pixel when computing the weight. '1' includes all pixels. Defaults to 1.
image_slice (int, optional): The image slice in the HDU list of the `image` fits file to inspect. Defaults to 0.

Raises:
ValueError: Raised when a mode is requested but does not exist

Returns:
float: The weight to supply to linmos
List[float]: The weight per channel to supply to linmos
"""

logger.debug(
f"Computing linmos weight using {mode=}, {image_slice=} for {image_path}. "
)
weight_modes = ("mad", "std")
assert (
mode in weight_modes
), f"Invalid {mode=} specified. Available modes: {weight_modes}"

weights: List[float] = []
with fits.open(image_path, memmap=True) as in_fits:
image_data = in_fits[image_slice].data # type: ignore

assert len(
image_data.shape
assert (
len(image_data.shape) >= 2
), f"{len(image_data.shape)=} is less than two. Is this really an image?"

# remove non-finite numbers
image_data = image_data[np.isfinite(image_data)]

logger.debug(f"Data shape is: {image_data.shape}")
if mode == "mad":
median = np.median(image_data)
mad = np.median(np.abs(image_data - median))
weight = 1.0 / mad**2
elif mode == "std":
std = np.std(image_data)
weight = 1.0 / std**2
else:
raise ValueError(
f"Invalid {mode=} specified. Available modes: {weight_modes}"
)
image_shape = image_data.shape[-2:]
image_data = (
image_data.reshape((-1, *image_shape))
if len(image_data.shape)
else image_data
)

assert (
len(image_data.shape) == 3
), f"Expected to have shape (chan, dec, ra), got {image_data.shape}"

for idx, chan_image_data in enumerate(image_data):
weight = _get_image_weight_plane(image_data=chan_image_data, stride=stride)
logger.info(f"Channel {idx} {weight=:.3f} for {image_path}")

weights.append(weight)

logger.info(f"Weight {weight:.3f} for {image_path}")
return float(weight)
return weights


def generate_weights_list_and_files(
image_paths: Collection[Path], mode: str = "mad"
image_paths: Collection[Path], mode: str = "mad", stride: int = 1
) -> str:
"""Generate the expected linmos weight files, and construct an appropriate
string that can be embedded into a linmos partset. These weights files will
Expand All @@ -207,6 +315,10 @@ def generate_weights_list_and_files(
the moment it is only intended to work on MFS images. It __is not__ currently
intended to be used on image cubes.

The stride parameter will only include every N'th pixel when computing the
weights. A smaller set of pixels will reduce the time required to calculate
the weights, but may come at the cost of accuracy with large values.

Args:
image_paths (Collection[Path]): Images to iterate over to create a corresponding weights.txt file.
mode (str, optional): The mode to use when calling get_image_weight
Expand All @@ -232,8 +344,12 @@ def generate_weights_list_and_files(
with open(weight_file, "w") as out_file:
logger.info(f"Writing {weight_file}")
out_file.write("#Channel Weight\n")
image_weight = get_image_weight(image_path=image, mode=mode)
out_file.write(f"0 {image_weight}\n")
image_weights = get_image_weight(image_path=image, mode=mode, stride=stride)
weights = "\n".join(
[f"{idx} {weight}" for idx, weight in enumerate(image_weights)]
)
out_file.write(weights)
out_file.write("\n") # Required for linmos to properly process weights

weight_str = [
str(weight_file) for weight_file in weight_file_list if weight_file.exists()
Expand Down Expand Up @@ -355,7 +471,9 @@ def generate_linmos_parameter_set(
# quality. In reality, this should be updated to provide a RMS noise
# estimate per-pixel of each image.
if weight_list is None:
weight_list = generate_weights_list_and_files(image_paths=images, mode="mad")
weight_list = generate_weights_list_and_files(
image_paths=images, mode="mad", stride=8
)

beam_order_strs = [str(extract_beam_from_name(str(p.name))) for p in images]
beam_order_list = "[" + ",".join(beam_order_strs) + "]"
Expand Down Expand Up @@ -446,7 +564,7 @@ def linmos_images(
linmos_parset.absolute().parent
]
if holofile:
bind_dirs.append(holofile.absolute())
bind_dirs.append(holofile.absolute().parent)

run_singularity_command(
image=container, command=linmos_cmd_str, bind_dirs=bind_dirs
Expand Down Expand Up @@ -501,6 +619,12 @@ def get_parser() -> ArgumentParser:
default=None,
help="Path to the holography FITS cube used for primary beam corrections",
)
parset_parser.add_argument(
"--pol-axis",
type=float,
default=2 * np.pi / 8,
help="The rotation in radians of the third-axis of the obseration. Defaults to PI/4",
)
parset_parser.add_argument(
"--yandasoft-container",
type=Path,
Expand All @@ -509,7 +633,7 @@ def get_parser() -> ArgumentParser:
)

trim_parser = subparsers.add_parser(
"trim", help="Generate a yandasoft linmos parset"
"trim", help="Remove blank border of FITS image"
)

trim_parser.add_argument(
Expand All @@ -532,6 +656,7 @@ def cli() -> None:
linmos_names=linmos_names,
weight_list=args.weight_list,
holofile=args.holofile,
pol_axis=args.pol_axis,
)
else:
linmos_images(
Expand All @@ -541,6 +666,7 @@ def cli() -> None:
weight_list=args.weight_list,
holofile=args.holofile,
container=args.yandasoft_container,
pol_axis=args.pol_axis,
)
elif args.mode == "trim":
images = args.images
Expand Down
Loading
Loading