diff --git a/.github/workflows/python-lint-test.yml b/.github/workflows/python-lint-test.yml
index 739f0fae..25dc7014 100644
--- a/.github/workflows/python-lint-test.yml
+++ b/.github/workflows/python-lint-test.yml
@@ -18,10 +18,10 @@ jobs:
     steps:
       - uses: actions/checkout@v3
-      - name: Set up Python 3.12
+      - name: Set up Python 3.11
         uses: actions/setup-python@v3
         with:
-          python-version: "3.12"
+          python-version: "3.11"
       - name: Install casacore and boost
         run: |
           sudo apt install -y build-essential libcfitsio-dev liblapack-dev libboost-python-dev python3-dev wcslib-dev casacore-dev
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 23b284c3..16f07f6b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,11 @@
 # dev
 
+- added convolution of cubes to a common resolution across channels
+- cubes are supported when computing the yandasoft linmos weights and trimming
+- `--coadd-cubes` option added to co-add cubes on the final imaging round
+  together to form a single field spectral cube
+
 # 0.2.6
 
 - if `-temp-dir` used in wsclean then imaging products are produced here and
diff --git a/flint/coadd/linmos.py b/flint/coadd/linmos.py
index 50ac2199..30c4883b 100644
--- a/flint/coadd/linmos.py
+++ b/flint/coadd/linmos.py
@@ -2,7 +2,7 @@
 from argparse import ArgumentParser
 from pathlib import Path
-from typing import Collection, List, NamedTuple, Optional, Tuple
+from typing import Collection, List, Literal, NamedTuple, Optional, Tuple
 
 import numpy as np
 from astropy.io import fits
@@ -37,21 +37,21 @@ class BoundingBox(NamedTuple):
     ymax: int
     """Maximum y pixel"""
     original_shape: Tuple[int, int]
-    """The original shape of the image"""
+    """The original shape of the image. If constructed against a cube this is the shape of a single plane."""
 
-def create_bound_box(image_data: np.ndarray, is_masked: bool = False) -> BoundingBox:
-    """Construct a bounding box around finite pixels for a 2D image. This does not
-    support cube type images.
-
-    If ``is_mask` is ``False``, the ``image_data`` will be masked internally using ``numpy.isfinite``.
+def _create_bound_box_plane(
+    image_data: np.ndarray, is_masked: bool = False
+) -> Optional[BoundingBox]:
+    """Create a bounding box around pixels in a 2D image. If no
+    pixels are valid, ``None`` is returned.
 
     Args:
-        image_data (np.ndarray): The image data that will have a bounding box constructed for.
-        is_masked (bool, optional): if this is ``True`` the ``image_data`` are treated as a boolean mask array. Defaults to False.
+        image_data (np.ndarray): The 2D image to construct a bounding box around
+        is_masked (bool, optional): Whether to treat the image as booleans or values. Defaults to False.
 
     Returns:
-        BoundingBox: The tight bounding box around pixels.
+        Optional[BoundingBox]: None if no valid pixels, a bounding box with the (xmin,xmax,ymin,ymax) of valid pixels
     """
     assert (
         len(image_data.shape) == 2
@@ -60,6 +60,10 @@ def create_bound_box(image_data: np.ndarray, is_masked: bool = False) -> Boundin
     # First convert to a boolean array
     image_valid = image_data if is_masked else np.isfinite(image_data)
 
+    if not np.any(image_valid):
+        logger.info("No valid pixels to create a bounding box from")
+        return None
+
     # Then make them 1D arrays
     x_valid = np.any(image_valid, axis=1)
     y_valid = np.any(image_valid, axis=0)
@@ -68,6 +72,62 @@ def create_bound_box(image_data: np.ndarray, is_masked: bool = False) -> Boundin
     xmin, xmax = np.where(x_valid)[0][[0, -1]]
     ymin, ymax = np.where(y_valid)[0][[0, -1]]
 
+    return BoundingBox(
+        xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax, original_shape=image_data.shape[-2:]
+    )
+
+
+def create_bound_box(image_data: np.ndarray, is_masked: bool = False) -> BoundingBox:
+    """Construct a bounding box around finite pixels for a 2D image.
+
+    If a cube is provided, the bounding box is constructed from pixels
+    as broadcast across all of the non-spatial dimensions. That is to
+    say, the single bounding box can be projected across all channel/Stokes
+    planes.
+
+    If ``is_masked`` is ``False``, the ``image_data`` will be masked internally using ``numpy.isfinite``.
+
+    Args:
+        image_data (np.ndarray): The image data that will have a bounding box constructed for.
+        is_masked (bool, optional): if this is ``True`` the ``image_data`` are treated as a boolean mask array. Defaults to False.
+
+    Returns:
+        BoundingBox: The tight bounding box around pixels.
+    """
+    reshaped_image_data = image_data.reshape((-1, *image_data.shape[-2:]))
+    logger.info(f"New image shape {reshaped_image_data.shape} from {image_data.shape}")
+
+    bounding_boxes = [
+        _create_bound_box_plane(image_data=image, is_masked=is_masked)
+        for image in reshaped_image_data
+    ]
+    bounding_boxes = [bb for bb in bounding_boxes if bb is not None]
+
+    if len(bounding_boxes) == 0:
+        logger.info("No valid bounding box found. Constructing one for all pixels")
+        return BoundingBox(
+            xmin=0,
+            xmax=image_data.shape[-1] - 1,
+            ymin=0,
+            ymax=image_data.shape[-2] - 1,
+            original_shape=tuple(image_data.shape[-2:]),  # type: ignore
+        )
+    elif len(bounding_boxes) == 1:
+        assert bounding_boxes[0] is not None, "This should not happen"
+        return bounding_boxes[0]
+
+    assert all([bb is not None for bb in bounding_boxes])
+
+    logger.info(
+        f"Bounding boxes across {len(bounding_boxes)} planes constructed. Finding limits. "
+    )
+    # The type ignores below are to stop mypy believing bounding_boxes could
+    # include None. The above checks should be sufficient
+    xmin = min([bb.xmin for bb in bounding_boxes])  # type: ignore
+    xmax = max([bb.xmax for bb in bounding_boxes])  # type: ignore
+    ymin = min([bb.ymin for bb in bounding_boxes])  # type: ignore
+    ymax = max([bb.ymax for bb in bounding_boxes])  # type: ignore
+
     return BoundingBox(
         xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax, original_shape=image_data.shape
     )
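To make the plane-wise behaviour concrete, here is a small sketch of the union that `create_bound_box` now computes for a cube (the array and values are illustrative only):

```python
import numpy as np

from flint.coadd.linmos import create_bound_box

# Two planes with valid pixels in different places, plus one fully-blank plane
cube = np.full((3, 100, 100), np.nan)
cube[0, 10:20, 30:40] = 1.0
cube[1, 50:60, 5:15] = 1.0

# The blank plane contributes no per-plane box; the result is the union of
# the remaining planes, so the one box can be applied to every channel
bb = create_bound_box(image_data=cube)
assert (bb.xmin, bb.xmax, bb.ymin, bb.ymax) == (10, 59, 5, 39)
```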
@@ -100,15 +160,17 @@ def trim_fits_image(
     data = fits_image[0].data  # type: ignore
     logger.info(f"Original data shape: {data.shape}")
 
-    image_shape = (data.shape[-2], data.shape[-1])
+    image_shape = data.shape[-2:]
     logger.info(f"The image dimensions are: {image_shape}")
 
     if not bounding_box:
+        logger.info("Constructing a new bounding box")
         bounding_box = create_bound_box(
             image_data=np.squeeze(data), is_masked=False
         )
+        logger.info(f"Constructed {bounding_box=}")
     else:
-        if image_shape != bounding_box.original_shape:
+        if image_shape != bounding_box.original_shape[-2:]:
             raise ValueError(
                 f"Bounding box constructed against {bounding_box.original_shape}, but being applied to {image_shape=}"
             )
@@ -130,9 +192,53 @@ def trim_fits_image(
     return TrimImageResult(path=image_path, bounding_box=bounding_box)
 
-def get_image_weight(
-    image_path: Path, mode: str = "mad", image_slice: int = 0
+def _get_image_weight_plane(
+    image_data: np.ndarray, mode: Literal["std", "mad"] = "mad", stride: int = 4
 ) -> float:
+    """Extract the inverse variance weight for an input plane of data
+
+    Modes are 'std' or 'mad'.
+
+    Args:
+        image_data (np.ndarray): Data to consider
+        mode (str, optional): Statistic computation mode. Defaults to "mad".
+        stride (int, optional): Include every n'th pixel when computing the weight. '1' includes all pixels. Defaults to 4.
+
+    Raises:
+        ValueError: Raised when an unknown mode is specified
+
+    Returns:
+        float: The inverse variance weight computed
+    """
+
+    weight_modes = ("mad", "std")
+    assert (
+        mode in weight_modes
+    ), f"Invalid {mode=} specified. Available modes: {weight_modes}"
+
+    # remove non-finite numbers that would ruin the statistic
+    image_data = image_data[np.isfinite(image_data)][::stride]
+
+    # If no finite pixels remain there is nothing to weight
+    if image_data.size == 0:
+        return 0.0
+
+    if mode == "mad":
+        median = np.median(image_data)
+        mad = np.median(np.abs(image_data - median))
+        weight = 1.0 / mad**2
+    elif mode == "std":
+        std = np.std(image_data)
+        weight = 1.0 / std**2
+    else:
+        raise ValueError(f"Invalid {mode=} specified. Available modes: {weight_modes}")
+
+    float_weight = float(weight)
+    return float_weight if np.isfinite(float_weight) else 0.0
+
+
+def get_image_weight(
+    image_path: Path, mode: str = "mad", stride: int = 1, image_slice: int = 0
+) -> List[float]:
     """Compute an image weight supplied to linmos, which is used for optimally
     weighting overlapping images. Supported modes are 'mad' and 'std', which
     simply resolve to their numpy equivalents.
@@ -142,55 +248,57 @@
     the same way, it does not necessarily have to correspond to an optimally
     calculated RMS.
 
+    The stride parameter will only include every N'th pixel when computing the
+    weights. A smaller set of pixels will reduce the time required to calculate
+    the weights, but may come at the cost of accuracy with large values.
+
     Args:
         image_path (Path): The path to the image fits file to inspect.
         mode (str, optional): Which mode should be used when calculating the weight. Defaults to 'mad'.
+        stride (int, optional): Include every n'th pixel when computing the weight. '1' includes all pixels. Defaults to 1.
        image_slice (int, optional): The image slice in the HDU list of the `image` fits file to inspect. Defaults to 0.
 
     Raises:
         ValueError: Raised when a mode is requested but does not exist
 
     Returns:
-        float: The weight to supply to linmos
+        List[float]: The weight per channel to supply to linmos
     """
     logger.debug(
         f"Computing linmos weight using {mode=}, {image_slice=} for {image_path}. "
     )
 
-    weight_modes = ("mad", "std")
-    assert (
-        mode in weight_modes
-    ), f"Invalid {mode=} specified. Available modes: {weight_modes}"
+    weights: List[float] = []
     with fits.open(image_path, memmap=True) as in_fits:
         image_data = in_fits[image_slice].data  # type: ignore
 
-        assert len(
-            image_data.shape
+        assert (
+            len(image_data.shape) >= 2
         ), f"{len(image_data.shape)=} is less than two. Is this really an image?"
 
-        # remove non-finite numbers
-        image_data = image_data[np.isfinite(image_data)]
-
-        logger.debug(f"Data shape is: {image_data.shape}")
-        if mode == "mad":
-            median = np.median(image_data)
-            mad = np.median(np.abs(image_data - median))
-            weight = 1.0 / mad**2
-        elif mode == "std":
-            std = np.std(image_data)
-            weight = 1.0 / std**2
-        else:
-            raise ValueError(
-                f"Invalid {mode=} specified. Available modes: {weight_modes}"
-            )
+        # Collapse any leading axes so the data is (chan, dec, ra);
+        # a 2D image becomes a single-plane cube
+        image_shape = image_data.shape[-2:]
+        image_data = (
+            image_data.reshape((-1, *image_shape))
+            if len(image_data.shape) > 2
+            else image_data[np.newaxis, ...]
+        )
+
+        assert (
+            len(image_data.shape) == 3
+        ), f"Expected to have shape (chan, dec, ra), got {image_data.shape}"
+
+        for idx, chan_image_data in enumerate(image_data):
+            weight = _get_image_weight_plane(image_data=chan_image_data, stride=stride)
+            logger.info(f"Channel {idx} {weight=:.3f} for {image_path}")
+
+            weights.append(weight)
 
-    logger.info(f"Weight {weight:.3f} for {image_path}")
-    return float(weight)
+    return weights
 
 def generate_weights_list_and_files(
-    image_paths: Collection[Path], mode: str = "mad"
+    image_paths: Collection[Path], mode: str = "mad", stride: int = 1
 ) -> str:
     """Generate the expected linmos weight files, and construct an appropriate
     string that can be embedded into a linmos parset. These weights files will
@@ -207,6 +315,10 @@ def generate_weights_list_and_files(
     the moment it is only intended to work on MFS images. It __is not__
     currently intended to be used on image cubes.
 
+    The stride parameter will only include every N'th pixel when computing the
+    weights. A smaller set of pixels will reduce the time required to calculate
+    the weights, but may come at the cost of accuracy with large values.
+
     Args:
         image_paths (Collection[Path]): Images to iterate over to create a corresponding weights.txt file.
         mode (str, optional): The mode to use when calling get_image_weight
@@ -232,8 +344,12 @@ def generate_weights_list_and_files(
         with open(weight_file, "w") as out_file:
             logger.info(f"Writing {weight_file}")
             out_file.write("#Channel Weight\n")
-            image_weight = get_image_weight(image_path=image, mode=mode)
-            out_file.write(f"0 {image_weight}\n")
+            image_weights = get_image_weight(image_path=image, mode=mode, stride=stride)
+            weights = "\n".join(
+                [f"{idx} {weight}" for idx, weight in enumerate(image_weights)]
+            )
+            out_file.write(weights)
+            out_file.write("\n")  # Required for linmos to properly process weights
 
     weight_str = [
         str(weight_file) for weight_file in weight_file_list if weight_file.exists()
     ]
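For reference, a minimal sketch of the per-channel weight calculation and the weights file layout that the block above writes (the numbers shown are illustrative; the trailing newline is the one linmos requires):

```python
import numpy as np

# MAD-based inverse-variance weight for a single plane, mirroring the
# 'mad' mode of _get_image_weight_plane: weight = 1 / MAD(plane)^2
rng = np.random.default_rng(42)
plane = rng.normal(scale=2.0, size=(128, 128))
mad = np.median(np.abs(plane - np.median(plane)))
weight = 1.0 / mad**2  # roughly 0.55 for Gaussian noise with sigma=2

# For a three-channel cube, the resulting weights file contains one
# "<channel> <weight>" row per plane:
#
#   #Channel Weight
#   0 0.551
#   1 0.548
#   2 0.553
```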
@@ -355,7 +471,9 @@
     # quality. In reality, this should be updated to provide a RMS noise
     # estimate per-pixel of each image.
     if weight_list is None:
-        weight_list = generate_weights_list_and_files(image_paths=images, mode="mad")
+        weight_list = generate_weights_list_and_files(
+            image_paths=images, mode="mad", stride=8
+        )
 
     beam_order_strs = [str(extract_beam_from_name(str(p.name))) for p in images]
     beam_order_list = "[" + ",".join(beam_order_strs) + "]"
@@ -446,7 +564,7 @@ def linmos_images(
         linmos_parset.absolute().parent
     ]
     if holofile:
-        bind_dirs.append(holofile.absolute())
+        bind_dirs.append(holofile.absolute().parent)
 
     run_singularity_command(
         image=container, command=linmos_cmd_str, bind_dirs=bind_dirs
@@ -501,6 +619,12 @@ def get_parser() -> ArgumentParser:
         default=None,
         help="Path to the holography FITS cube used for primary beam corrections",
     )
+    parset_parser.add_argument(
+        "--pol-axis",
+        type=float,
+        default=2 * np.pi / 8,
+        help="The rotation in radians of the third-axis of the observation. Defaults to PI/4",
+    )
     parset_parser.add_argument(
         "--yandasoft-container",
         type=Path,
@@ -509,7 +633,7 @@
     )
 
     trim_parser = subparsers.add_parser(
-        "trim", help="Generate a yandasoft linmos parset"
+        "trim", help="Remove the blank borders of a FITS image"
     )
 
     trim_parser.add_argument(
@@ -532,6 +656,7 @@ def cli() -> None:
             linmos_names=linmos_names,
             weight_list=args.weight_list,
             holofile=args.holofile,
+            pol_axis=args.pol_axis,
         )
     else:
         linmos_images(
             weight_list=args.weight_list,
             holofile=args.holofile,
             container=args.yandasoft_container,
+            pol_axis=args.pol_axis,
         )
     elif args.mode == "trim":
         images = args.images
diff --git a/flint/convol.py b/flint/convol.py
index 403a4c4c..02376941 100644
--- a/flint/convol.py
+++ b/flint/convol.py
@@ -10,9 +10,11 @@
 from typing import Collection, List, NamedTuple, Optional
 
 import astropy.units as u
+import numpy as np
+from astropy.io import fits
 from astropy.wcs import FITSFixedWarning
-from racs_tools import beamcon_2D
-from radio_beam import Beam
+from racs_tools import beamcon_2D, beamcon_3D
+from radio_beam import Beam, Beams
 
 from flint.logging import logger
@@ -53,6 +55,127 @@ def from_radio_beam(cls, radio_beam: Beam) -> BeamShape:
         )
 
+def check_if_cube_fits(fits_file: Path) -> bool:
+    """Check to see whether the data component of a FITS image is a cube.
+    Returns ``True`` if the data shape needs 3 dimensions to be represented.
+
+    Note: Unclear on usefulness
+
+    Args:
+        fits_file (Path): FITS file that will be examined
+
+    Returns:
+        bool: Whether the input FITS file is a cube or not.
+    """
+
+    try:
+        squeeze_data = np.squeeze(fits.getdata(fits_file))  # type: ignore
+    except FileNotFoundError:
+        return False
+    except OSError:
+        return False
+
+    return len(squeeze_data.shape) == 3
+
+
+def get_cube_common_beam(
+    cube_paths: Collection[Path], cutoff: Optional[float] = None
+) -> List[BeamShape]:
+    """Given a set of input cube FITS files, compute a common beam
+    for each channel.
+
+    Args:
+        cube_paths (Collection[Path]): Set of cube FITS files to inspect to derive a common beam
+        cutoff (Optional[float], optional): A cutoff value, in arcsec, that specifies the maximum BMAJ allowed. Defaults to None.
+
+    Returns:
+        List[BeamShape]: List of target beam shapes to use, corresponding to each channel
+    """
+
+    _, common_beam_data_list = beamcon_3D.smooth_fits_cube(
+        infiles_list=list(cube_paths),
+        dryrun=True,
+        cutoff=cutoff,
+        mode="natural",
+        conv_mode="robust",
+        ncores=1,
+    )
+    # Make proper check here that accounts for NaNs
+    for file in common_beam_data_list:
+        assert all(
+            (file[0].major == common_beam_data_list[0][0].major)
+            | np.isnan(file[0].major)
+        )
+        assert all(
+            (file[0].minor == common_beam_data_list[0][0].minor)
+            | np.isnan(file[0].minor)
+        )
+        assert all(
+            (file[0].pa == common_beam_data_list[0][0].pa) | np.isnan(file[0].pa)
+        )
+
+    first_cube_fits_beam = common_beam_data_list[0][0]
+    assert isinstance(
+        first_cube_fits_beam, Beams
+    ), f"Unexpected type for common beams. Expected Beams, got {type(first_cube_fits_beam)}"
+
+    beam_shape_list = [
+        BeamShape.from_radio_beam(radio_beam=beam)  # type: ignore
+        for beam in first_cube_fits_beam
+    ]
+    return beam_shape_list
+
+
+def convolve_cubes(
+    cube_paths: Collection[Path],
+    beam_shapes: List[BeamShape],
+    cutoff: Optional[float] = None,
+    convol_suffix: str = "conv",
+) -> Collection[Path]:
+    """Convolve a set of cubes channel-wise to the supplied common beam shapes using beamcon_3D"""
+    logger.info(f"Will attempt to convolve {len(cube_paths)} cubes")
+    if cutoff:
+        logger.info(f"Supplied cutoff {cutoff}")
+
+    # Extracting the beam properties
+    beam_major_list = [float(beam.bmaj_arcsec) for beam in beam_shapes]
+    beam_minor_list = [float(beam.bmin_arcsec) for beam in beam_shapes]
+    beam_pa_list = [float(beam.bpa_deg) for beam in beam_shapes]
+
+    # Sanity test
+    assert len(beam_major_list) == len(beam_minor_list) == len(beam_pa_list)
+
+    logger.info("Convolving cubes")
+    cube_data_list, _ = beamcon_3D.smooth_fits_cube(
+        infiles_list=list(cube_paths),
+        dryrun=False,
+        cutoff=cutoff,
+        mode="natural",
+        conv_mode="robust",
+        bmaj=beam_major_list,
+        bmin=beam_minor_list,
+        bpa=beam_pa_list,
+        suffix=convol_suffix,
+    )
+
+    # Construct the name of the new file created. For the moment this is done
+    # manually as it is not part of the returned object
+    # TODO: Extend the return struct from beamcon_3D to include output name
+    convol_cubes_path = [
+        Path(cube_data.filename).with_suffix(f".{convol_suffix}.fits")
+        for cube_data in cube_data_list
+    ]
+
+    # Show the mapping as a sanity check
+    for input_cube, output_cube in zip(list(cube_paths), convol_cubes_path):
+        logger.info(f"{input_cube=} convolved to {output_cube}")
+
+    # Trust no one
+    assert all(
+        [p.exists() for p in convol_cubes_path]
+    ), "A convolved cube does not exist"
+    return convol_cubes_path
+
+
 def get_common_beam(
     image_paths: Collection[Path], cutoff: Optional[float] = None
 ) -> BeamShape:
@@ -158,6 +281,12 @@ def get_parser() -> ArgumentParser:
         default="conv",
         help="The suffix added to convolved images. ",
     )
+    convol_parser.add_argument(
+        "--cubes",
+        action="store_true",
+        default=False,
+        help="Treat the input files as cubes and use the corresponding 3D beam selection and convolution. ",
+    )
 
     maxbeam_parser = subparsers.add_parser(
         "maxbeam", help="Find the optimal beam size for a set of images."
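A minimal usage sketch of the helpers added above, chained the same way the new `convol --cubes` CLI path uses them (the cube file names here are hypothetical, and running this requires real data for `beamcon_3D` to operate on):

```python
from pathlib import Path

from flint.convol import check_if_cube_fits, convolve_cubes, get_cube_common_beam

# Hypothetical per-beam spectral cubes produced by the imager
cubes = [Path(f"SB1234.field.beam{b:02d}.i.image.cube.fits") for b in range(2)]
assert all(check_if_cube_fits(fits_file=cube) for cube in cubes)

# One common beam per channel, derived across all input cubes together
beam_shapes = get_cube_common_beam(cube_paths=cubes, cutoff=150.0)

# Convolve every channel of every cube to the matching common beam
conv_cubes = convolve_cubes(
    cube_paths=cubes, beam_shapes=beam_shapes, cutoff=150.0, convol_suffix="conv"
)
```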
@@ -176,6 +305,23 @@
         help="Beams whose major-axis are larger than this (in arcseconds) are ignored from the calculation of the optimal beam.",
     )
 
+    cubemaxbeams_parser = subparsers.add_parser(
+        "cubemaxbeam",
+        help="Calculate the set of common beams across channels in a set of cubes",
+    )
+    cubemaxbeams_parser.add_argument(
+        "cubes",
+        type=Path,
+        nargs="+",
+        help="The cubes from which the common per-channel beams will be calculated",
+    )
+    cubemaxbeams_parser.add_argument(
+        "--cutoff",
+        type=float,
+        default=None,
+        help="Beams whose major-axis are larger than this (in arcseconds) are ignored from the calculation of the optimal beam.",
+    )
+
     return parser
@@ -187,13 +333,38 @@ def cli() -> None:
     if args.mode == "maxbeam":
         get_common_beam(image_paths=args.images, cutoff=args.cutoff)
     if args.mode == "convol":
-        common_beam = get_common_beam(image_paths=args.images, cutoff=args.cutoff)
-        _ = convolve_images(
-            image_paths=args.images,
-            beam_shape=common_beam,
-            cutoff=args.cutoff,
-            convol_suffix=args.convol_suffix,
+        if args.cubes:
+            assert all(
+                [check_if_cube_fits(fits_file=f) for f in args.images]
+            ), "Not all input files are FITS cubes"
+            common_beams = get_cube_common_beam(
+                cube_paths=args.images, cutoff=args.cutoff
+            )
+            for image in args.images:
+                logger.info(f"Convolving {image}")
+                _ = convolve_cubes(
+                    cube_paths=[image],
+                    beam_shapes=common_beams,
+                    cutoff=args.cutoff,
+                    convol_suffix=args.convol_suffix,
+                )
+
+        else:
+            assert not any(
+                [check_if_cube_fits(fits_file=f) for f in args.images]
+            ), "Some input files are FITS cubes, not 2D images"
+            common_beam = get_common_beam(image_paths=args.images, cutoff=args.cutoff)
+            _ = convolve_images(
+                image_paths=args.images,
+                beam_shape=common_beam,
+                cutoff=args.cutoff,
+                convol_suffix=args.convol_suffix,
+            )
+    if args.mode == "cubemaxbeam":
+        common_beam_shape_list = get_cube_common_beam(
+            cube_paths=args.cubes, cutoff=args.cutoff
         )
+        logger.info(f"Extracted {common_beam_shape_list=}")
 
 if __name__ == "__main__":
diff --git a/flint/data/tests/sub_cube_fits_examples.zip b/flint/data/tests/sub_cube_fits_examples.zip
new file mode 100644
index 00000000..acd2baf3
Binary files /dev/null and b/flint/data/tests/sub_cube_fits_examples.zip differ
diff --git a/flint/imager/wsclean.py b/flint/imager/wsclean.py
index d8599dca..5b0d073c 100644
--- a/flint/imager/wsclean.py
+++ b/flint/imager/wsclean.py
@@ -610,11 +610,32 @@ def combine_subbands_to_cube(
     logger.info(f"Combining {len(subband_images)} images. {subband_images=}")
     hdu1, freqs = combine_fits(file_list=subband_images)
 
+    # This changes the output cube to a shape of (chan, pol, dec, ra)
+    # which is what the yandasoft linmos task likes
+    new_header = hdu1[0].header  # type: ignore
+    data_cube = hdu1[0].data  # type: ignore
+
+    tmp_header = new_header.copy()
+    # Need to swap NAXIS 3 and 4 to make LINMOS happy - booo
+    for a, b in ((3, 4), (4, 3)):
+        new_header[f"CTYPE{a}"] = tmp_header[f"CTYPE{b}"]
+        new_header[f"CRPIX{a}"] = tmp_header[f"CRPIX{b}"]
+        new_header[f"CRVAL{a}"] = tmp_header[f"CRVAL{b}"]
+        new_header[f"CDELT{a}"] = tmp_header[f"CDELT{b}"]
+        new_header[f"CUNIT{a}"] = tmp_header[f"CUNIT{b}"]
+
+    # Cube is currently (stokes, chan, dec, ra) - needs to be (chan, stokes, dec, ra)
+    data_cube = np.moveaxis(data_cube, 1, 0)
+    hdu1[0].data = data_cube  # type: ignore
+    hdu1[0].header = new_header  # type: ignore
+
     output_cube_name = create_image_cube_name(
         image_prefix=Path(imageset.prefix), mode=mode
     )
+
+    # Write out the hdu to preserve the beam table constructed in fitscube
     logger.info(f"Writing {output_cube_name=}")
-    hdu1.writeto(output_cube_name, overwrite=True)
+    hdu1.writeto(output_cube_name)
 
     output_freqs_name = Path(output_cube_name).with_suffix(".freqs_Hz.txt")
     np.savetxt(output_freqs_name, freqs.to("Hz").value)
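The header swap and `np.moveaxis` above have to stay in lockstep: FITS counts axes from the fastest-varying (NAXIS1) upward, while the numpy array is indexed in the reverse order. A small sketch of the same reordering on dummy data (shapes and keyword values are illustrative; only two keywords are swapped here for brevity):

```python
import numpy as np
from astropy.io import fits

# Dummy cube laid out as (stokes, chan, dec, ra), i.e. NAXIS4=stokes, NAXIS3=chan
data = np.zeros((1, 20, 32, 32))
header = fits.Header(
    {"CTYPE3": "FREQ", "CRVAL3": 1.4e9, "CTYPE4": "STOKES", "CRVAL4": 1.0}
)

# Swap the NAXIS3/NAXIS4 keywords in the header...
tmp = header.copy()
for a, b in ((3, 4), (4, 3)):
    for key in ("CTYPE", "CRVAL"):
        header[f"{key}{a}"] = tmp[f"{key}{b}"]

# ...and move the matching numpy axes so the data becomes (chan, stokes, dec, ra)
data = np.moveaxis(data, 1, 0)
assert data.shape == (20, 1, 32, 32)
assert header["CTYPE4"] == "FREQ" and header["CTYPE3"] == "STOKES"
```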
{subband_images=}") hdu1, freqs = combine_fits(file_list=subband_images) + # This changes the output cube to a shape of (chan, pol, dec, ra) + # which is what yandasoft linmos tasks like + new_header = hdu1[0].header # type: ignore + data_cube = hdu1[0].data # type: ignore + + tmp_header = new_header.copy() + # Need to swap NAXIS 3 and 4 to make LINMOS happy - booo + for a, b in ((3, 4), (4, 3)): + new_header[f"CTYPE{a}"] = tmp_header[f"CTYPE{b}"] + new_header[f"CRPIX{a}"] = tmp_header[f"CRPIX{b}"] + new_header[f"CRVAL{a}"] = tmp_header[f"CRVAL{b}"] + new_header[f"CDELT{a}"] = tmp_header[f"CDELT{b}"] + new_header[f"CUNIT{a}"] = tmp_header[f"CUNIT{b}"] + + # Cube is currently STOKES, FREQ, RA, DEC - needs to be FREQ, STOKES, RA, DEC + data_cube = np.moveaxis(data_cube, 1, 0) + hdu1[0].data = data_cube # type: ignore + hdu1[0].header = new_header # type: ignore + output_cube_name = create_image_cube_name( image_prefix=Path(imageset.prefix), mode=mode ) + + # Write out the hdu to preserve the beam table constructed in fitscube logger.info(f"Writing {output_cube_name=}") - hdu1.writeto(output_cube_name, overwrite=True) + hdu1.writeto(output_cube_name) output_freqs_name = Path(output_cube_name).with_suffix(".freqs_Hz.txt") np.savetxt(output_freqs_name, freqs.to("Hz").value) diff --git a/flint/naming.py b/flint/naming.py index a59e96da..76a7bb5b 100644 --- a/flint/naming.py +++ b/flint/naming.py @@ -11,6 +11,24 @@ from flint.options import MS +def get_fits_cube_from_paths(paths: List[Path]) -> List[Path]: + """Given a list of files, find the ones that appear to be FITS files + and contain the ``.cube.`` field indicator. A regular expression searching + for both the ``.cube.`` and ``.fits`` file type is used. + + Args: + paths (List[Path]): The set of paths to examine to identify potential cube fits images from + + Returns: + List[Path]: Set of paths matching the search criteria + """ + cube_expression = re.compile(r"\.cube\..*fits$") + + cube_files = [path for path in paths if bool(cube_expression.search(str(path)))] + + return cube_files + + def create_image_cube_name( image_prefix: Path, mode: str, suffix: str = "cube.fits" ) -> Path: diff --git a/flint/options.py b/flint/options.py index e43d9940..d6534eb5 100644 --- a/flint/options.py +++ b/flint/options.py @@ -152,6 +152,8 @@ class FieldOptions(NamedTuple): """Rename MSs throughout rounds of imaging and self-cal instead of creating copies. This will delete data-columns throughout. """ stokes_v_imaging: bool = False """Specifies whether Stokes-V imaging will be carried out after the final round of imagine (whether or not self-calibration is enabled). """ + coadd_cubes: bool = False + """Co-add cubes formed throughout imaging together. Cubes will be smoothed channel-wise to a common resolution. 
diff --git a/flint/options.py b/flint/options.py
index e43d9940..d6534eb5 100644
--- a/flint/options.py
+++ b/flint/options.py
@@ -152,6 +152,8 @@ class FieldOptions(NamedTuple):
     """Rename MSs throughout rounds of imaging and self-cal instead of creating copies. This will delete data-columns throughout. """
     stokes_v_imaging: bool = False
     """Specifies whether Stokes-V imaging will be carried out after the final round of imaging (whether or not self-calibration is enabled). """
+    coadd_cubes: bool = False
+    """Co-add cubes formed throughout imaging together. Cubes will be smoothed channel-wise to a common resolution. Only performed on final set of images"""
 
     def with_options(self, **kwargs) -> FieldOptions:
         _dict = self._asdict()
diff --git a/flint/prefect/common/imaging.py b/flint/prefect/common/imaging.py
index 7dcd5f78..60359dd8 100644
--- a/flint/prefect/common/imaging.py
+++ b/flint/prefect/common/imaging.py
@@ -5,7 +5,7 @@
 """
 from pathlib import Path
-from typing import Any, Collection, Dict, List, Optional, TypeVar, Union, Tuple
+from typing import Any, Collection, Dict, List, Literal, Optional, TypeVar, Union, Tuple
 
 import pandas as pd
 from prefect import task, unmapped
@@ -18,7 +18,13 @@
     select_aosolution_for_ms,
 )
 from flint.coadd.linmos import LinmosCommand, linmos_images
-from flint.convol import BeamShape, convolve_images, get_common_beam
+from flint.convol import (
+    BeamShape,
+    convolve_images,
+    get_common_beam,
+    get_cube_common_beam,
+    convolve_cubes,
+)
 from flint.flagging import flag_ms_aoflagger
 from flint.imager.wsclean import (
     ImageSet,
@@ -39,7 +45,12 @@
     rename_column_in_ms,
     split_by_field,
 )
-from flint.naming import FITSMaskNames, get_beam_resolution_str, processed_ms_format
+from flint.naming import (
+    FITSMaskNames,
+    get_beam_resolution_str,
+    get_fits_cube_from_paths,
+    processed_ms_format,
+)
 from flint.options import FieldOptions
 from flint.peel.potato import potato_peel
 from flint.prefect.common.utils import upload_image_as_artifact
@@ -384,6 +395,95 @@ def task_get_common_beam(
     return beam_shape
 
+@task
+def task_get_cube_common_beam(
+    wsclean_cmds: Collection[WSCleanCommand],
+    cutoff: float = 25,
+) -> List[BeamShape]:
+    """Compute a common beam size for input cubes.
+
+    Args:
+        wsclean_cmds (Collection[WSCleanCommand]): Completed wsclean commands whose cube images will be considered
+        cutoff (float, optional): Major axes larger than this value, in arcseconds, will be ignored. Defaults to 25.
+
+    Returns:
+        List[BeamShape]: The final convolving beam size to be used per channel in cubes
+    """
+
+    images_to_consider: List[Path] = []
+
+    # TODO: This should support other image types
+    for wsclean_cmd in wsclean_cmds:
+        if wsclean_cmd.imageset is None:
+            logger.warning(
+                f"No imageset for {wsclean_cmd.ms} found. Has imager finished?"
+            )
+            continue
+        images_to_consider.extend(wsclean_cmd.imageset.image)
+
+    images_to_consider = get_fits_cube_from_paths(paths=images_to_consider)
+
+    logger.info(
+        f"Considering {len(images_to_consider)} images across {len(wsclean_cmds)} outputs. "
+    )
+
+    beam_shapes = get_cube_common_beam(cube_paths=images_to_consider, cutoff=cutoff)
+
+    return beam_shapes
+
+
+@task
+def task_convolve_cube(
+    wsclean_cmd: WSCleanCommand,
+    beam_shapes: List[BeamShape],
+    cutoff: float = 60,
+    mode: Literal["image"] = "image",
+    convol_suffix_str: str = "conv",
+) -> Collection[Path]:
+    """Convolve cubes to the specified per-channel resolutions
+
+    Args:
+        wsclean_cmd (WSCleanCommand): The wsclean command whose output cube images will be convolved
+        beam_shapes (List[BeamShape]): The per-channel beam shapes the cubes will be convolved to
+        cutoff (float, optional): Maximum major beam axis an image is allowed to have before it will not be convolved. Defaults to 60.
+        mode (Literal["image"], optional): The image product to extract from the imageset. Defaults to "image".
+        convol_suffix_str (str, optional): The suffix added to the convolved images. Defaults to 'conv'.
+
+    Returns:
+        Collection[Path]: Path to the output images that have been convolved.
+    """
+    assert (
+        wsclean_cmd.imageset is not None
+    ), f"{wsclean_cmd.ms} has no attached imageset."
+
+    supported_modes = ("image",)
+    logger.info(f"Extracting {mode}")
+    if mode == "image":
+        image_paths = list(wsclean_cmd.imageset.image)
+    else:
+        raise ValueError(f"{mode=} is not supported. Known modes are {supported_modes}")
+
+    logger.info(f"Extracting cubes from imageset {mode=}")
+    image_paths = get_fits_cube_from_paths(paths=image_paths)
+
+    # It is possible depending on how aggressively cleaning image products are deleted
+    # that some cleaning products (e.g. residuals) do not exist. There are a number of ways one could consider
+    # handling this. The pirate in me feels like less is more, so an error will be enough. Keeping
+    # things simple and avoiding the problem is probably the better way of dealing with this
+    # situation. In time this would mean that we inspect and handle conflicting pipeline options.
+    assert (
+        image_paths is not None
+    ), f"{image_paths=} for {mode=} and {wsclean_cmd.imageset=}"
+
+    logger.info(f"Will convolve {image_paths}")
+
+    return convolve_cubes(
+        cube_paths=image_paths,
+        beam_shapes=beam_shapes,
+        cutoff=cutoff,
+        convol_suffix=convol_suffix_str,
+    )
+
+
 @task
 def task_convolve_image(
     wsclean_cmd: WSCleanCommand,
@@ -466,7 +566,7 @@
 def task_linmos_images(
     images: Collection[Collection[Path]],
     container: Path,
-    filter: str = ".MFS.",
+    filter: Optional[str] = ".MFS.",
     field_name: Optional[str] = None,
     suffix_str: str = "noselfcal",
     holofile: Optional[Path] = None,
@@ -480,7 +580,7 @@ def task_linmos_images(
     Args:
         images (Collection[Collection[Path]]): Images that will be co-added together
         container (Path): Path to singularity container that contains yandasoft
-        filter (str, optional): Filter to extract the images that will be extracted from the set of input images. These will be co-added. Defaults to ".MFS.".
+        filter (Optional[str], optional): Filter used to select the images to co-add from the set of input images. If None, all images are co-added. Defaults to ".MFS.".
         field_name (Optional[str], optional): Name of the field, which is included in the output images created. Defaults to None.
         suffix_str (str, optional): Additional string added to the prefix of the output linmos image products. Defaults to "noselfcal".
         holofile (Optional[Path], optional): The FITS cube with the beam corrections derived from ASKAP holography. Defaults to None.
@@ -495,10 +595,15 @@ def task_linmos_images(
 
     # TODO: Need to flatten images
     # TODO: Need a better way of getting field names
+    # TODO: Need a better filter approach. Would probably be better to
+    #       have literals for the type of product (MFS, cube, model) to be
+    #       sure of appropriate extraction
     all_images = [img for beam_images in images for img in beam_images]
     logger.info(f"Number of images to examine {len(all_images)}")
 
-    filter_images = [img for img in all_images if filter in str(img)]
+    filter_images = (
+        [img for img in all_images if filter in str(img)] if filter else all_images
+    )
     logger.info(f"Number of filtered images to linmos: {len(filter_images)}")
 
     candidate_image = filter_images[0]
@@ -551,7 +656,6 @@ def _convolve_linmos(
     beam_shape: BeamShape,
     field_options: FieldOptions,
     linmos_suffix_str: str,
-    cutoff: float = 0.05,
     field_summary: Optional[FieldSummary] = None,
     convol_mode: str = "image",
     convol_filter: str = ".MFS.",
@@ -565,7 +669,6 @@ def _convolve_linmos(
         beam_shape (BeamShape): The beam shape that residual images will be convolved to
         field_options (FieldOptions): Options related to the processing of the field
         linmos_suffix_str (str): The suffix string passed to the linmos parset name
-        cutoff (float, optional): The primary beam attenuation cutoff supplied to linmos when coadding. Defaults to 0.05.
         field_summary (Optional[FieldSummary], optional): The summary of the field, including (importantly) the orientation of the third-axis. Defaults to None.
         convol_mode (str, optional): The mode passed to the convol task to describe the images to extract. Support image or residual. Defaults to image.
         convol_filter (str, optional): A text filter applied when assessing images to co-add. Defaults to '.MFS.'.
@@ -577,20 +680,21 @@ def _convolve_linmos(
 
     conv_images = task_convolve_image.map(
         wsclean_cmd=wsclean_cmds,
-        beam_shape=unmapped(beam_shape),
-        cutoff=150.0,
+        beam_shape=unmapped(beam_shape),  # type: ignore
+        cutoff=field_options.beam_cutoff,
         mode=convol_mode,
         filter=convol_filter,
         convol_suffix_str=convol_suffix_str,
     )
+    assert field_options.yandasoft_container is not None
     parset = task_linmos_images.submit(
-        images=conv_images,
+        images=conv_images,  # type: ignore
         container=field_options.yandasoft_container,
         suffix_str=linmos_suffix_str,
         holofile=field_options.holofile,
-        cutoff=cutoff,
+        cutoff=field_options.pb_cutoff,
         field_summary=field_summary,
-    )
+    )  # type: ignore
 
     return parset
@@ -650,10 +754,9 @@ def _create_convol_linmos_images(
         parsets.append(
             _convolve_linmos(
                 wsclean_cmds=wsclean_cmds,
-                beam_shape=beam_shape,
+                beam_shape=beam_shape,  # type: ignore
                 field_options=field_options,
                 linmos_suffix_str=f"{linmos_suffix_str}.residual",
-                cutoff=field_options.pb_cutoff,
                 field_summary=field_summary,
                 convol_mode="residual",
                 convol_filter=".MFS.",
         parsets.append(
             _convolve_linmos(
                 wsclean_cmds=wsclean_cmds,
-                beam_shape=beam_shape,
+                beam_shape=beam_shape,  # type: ignore
                 field_options=field_options,
                 linmos_suffix_str=linmos_suffix_str,
-                cutoff=field_options.pb_cutoff,
                 field_summary=field_summary,
                 convol_mode="image",
                 convol_filter=".MFS.",
@@ -677,6 +779,39 @@ def _create_convol_linmos_images(
 
     return parsets
 
+def _create_convolve_linmos_cubes(
+    wsclean_cmds: Collection[WSCleanCommand],
+    field_options: FieldOptions,
+    current_round: Optional[int] = None,
+    additional_linmos_suffix_str: Optional[str] = "cube",
+):
+    """Convolve the cubes from each beam to a common per-channel resolution and co-add them with linmos"""
+    suffixes = [f"round{current_round}" if current_round else "noselfcal"]
+    if additional_linmos_suffix_str:
+        suffixes.insert(0, additional_linmos_suffix_str)
+    linmos_suffix_str = ".".join(suffixes)
+
+    beam_shapes = task_get_cube_common_beam.submit(
+        wsclean_cmds=wsclean_cmds, cutoff=field_options.beam_cutoff
+    )
+    convolved_cubes = task_convolve_cube.map(
+        wsclean_cmd=wsclean_cmds,  # type: ignore
+        cutoff=field_options.beam_cutoff,
+        mode=unmapped("image"),  # type: ignore
+        beam_shapes=unmapped(beam_shapes),  # type: ignore
+    )
+
+    assert field_options.yandasoft_container is not None
+    parset = task_linmos_images.submit(
+        images=convolved_cubes,  # type: ignore
+        container=field_options.yandasoft_container,
+        suffix_str=linmos_suffix_str,
+        holofile=field_options.holofile,
+        cutoff=field_options.pb_cutoff,
+        filter=None,
+    )
+    return parset
+
+
 @task
 def task_create_image_mask_model(
     image: Union[LinmosCommand, ImageSet, WSCleanCommand],
@@ -841,7 +976,7 @@ def task_create_validation_tables(
             create_table_artifact(
                 table=df_dict,
                 description=f"{table.stem}",
-            )
+            )  # type: ignore
         elif isinstance(table, XMatchTables):
             for subtable in table:
                 if subtable is None:
@@ -853,7 +988,7 @@ def task_create_validation_tables(
                 create_table_artifact(
                     table=df_dict,
                     description=f"{subtable.stem}",
-                )
+                )  # type: ignore
 
     return validation_tables
diff --git a/flint/prefect/flows/continuum_pipeline.py b/flint/prefect/flows/continuum_pipeline.py
index 59c3e65a..2b31a4fe 100644
--- a/flint/prefect/flows/continuum_pipeline.py
+++ b/flint/prefect/flows/continuum_pipeline.py
@@ -33,6 +33,7 @@
 from flint.prefect.clusters import get_dask_runner
 from flint.prefect.common.imaging import (
     _create_convol_linmos_images,
+    _create_convolve_linmos_cubes,
     _validation_items,
     task_copy_and_preprocess_casda_askap_ms,
     task_create_apply_solutions_cmd,
@@ -63,6 +64,13 @@ def _check_field_options(field_options: FieldOptions) -> None:
     run_aegean = (
         False if field_options.aegean_container is None else field_options.run_aegean
     )
+    if (
+        field_options.imaging_strategy is not None
+        and not field_options.imaging_strategy.exists()
+    ):
+        raise ValueError(
+            f"Imaging strategy file {field_options.imaging_strategy} is set, but the path does not exist"
+        )
     if field_options.use_beam_masks is True and run_aegean is False:
         raise ValueError(
             "run_aegean and aegean container both need to be set if beam masks are being used. "
" @@ -79,6 +87,14 @@ def _check_field_options(field_options: FieldOptions) -> None: raise ValueError( "CASA Container needs to be set if self-calibraiton is to be performed" ) + if field_options.coadd_cubes: + if ( + field_options.yandasoft_container is None + or not field_options.yandasoft_container + ): + raise ValueError( + "Unable to create linmos cubes without a yandasoft container" + ) def _check_create_output_split_science_path( @@ -254,7 +270,7 @@ def process_science_fields( mss=preprocess_science_mss, cal_sbid_path=bandpass_path, holography_path=field_options.holofile, - ) + ) # type: ignore logger.info(f"{field_summary=}") if field_options.wsclean_container is None: @@ -312,12 +328,12 @@ def process_science_fields( if run_aegean: aegean_field_output = task_run_bane_and_aegean.submit( image=parset, aegean_container=unmapped(field_options.aegean_container) - ) + ) # type: ignore field_summary = task_update_field_summary.submit( field_summary=field_summary, aegean_outputs=aegean_field_output, linmos_command=parset, - ) + ) # type: ignore archive_wait_for.append(field_summary) if run_validation and field_options.reference_catalogue_directory: @@ -358,7 +374,7 @@ def process_science_fields( wait_for=[ field_summary ], # To make sure field summary is created with unzipped MSs - ) + ) # type: ignore stokes_v_mss = cal_mss fits_beam_masks = None @@ -420,12 +436,12 @@ def process_science_fields( aegean_outputs = task_run_bane_and_aegean.submit( image=parsets_self[-1], aegean_container=unmapped(field_options.aegean_container), - ) + ) # type: ignore field_summary = task_update_field_summary.submit( field_summary=field_summary, aegean_outputs=aegean_outputs, round=current_round, - ) + ) # type: ignore if run_validation: assert field_options.reference_catalogue_directory, f"Reference catalogue directory should be set when {run_validation=}" val_results = _validation_items( @@ -435,6 +451,16 @@ def process_science_fields( ) archive_wait_for.append(val_results) + if field_options.coadd_cubes: + with tags("cubes"): + cube_parset = _create_convolve_linmos_cubes( + wsclean_cmds=wsclean_cmds, # type: ignore + field_options=field_options, + current_round=(field_options.rounds if field_options.rounds else None), + additional_linmos_suffix_str="cube", + ) + archive_wait_for.append(cube_parset) + if field_options.stokes_v_imaging: with tags("stokes-v"): stokes_v_wsclean_options = get_options_from_strategy( @@ -446,7 +472,7 @@ def process_science_fields( update_wsclean_options=unmapped(stokes_v_wsclean_options), fits_mask=fits_beam_masks, wait_for=wsclean_cmds, # Ensure that measurement sets are doubled up during imaging - ) + ) # type: ignore if field_options.yandasoft_container: parsets = _create_convol_linmos_images( wsclean_cmds=wsclean_cmds, @@ -474,7 +500,7 @@ def process_science_fields( max_round=field_options.rounds if field_options.rounds else None, update_archive_options=update_archive_options, wait_for=archive_wait_for, - ) + ) # type: ignore def setup_run_process_science_field( @@ -718,6 +744,12 @@ def get_parser() -> ArgumentParser: action="store_true", default=False, ) + parser.add_argument( + "--coadd-cubes", + default=False, + action="store_true", + help="Co-add cubes formed throughout imaging together. Cubes will be smoothed channel-wise to a common resolution. 
+    )
 
     return parser
@@ -764,6 +796,7 @@ def cli() -> None:
         sbid_copy_path=args.sbid_copy_path,
         rename_ms=args.rename_ms,
         stokes_v_imaging=args.stokes_v_imaging,
+        coadd_cubes=args.coadd_cubes,
     )
 
     setup_run_process_science_field(
diff --git a/tests/test_convol.py b/tests/test_convol.py
new file mode 100644
index 00000000..ba5bbdaa
--- /dev/null
+++ b/tests/test_convol.py
@@ -0,0 +1,75 @@
+"""Bits around testing the convolution utilities"""
+
+import pytest
+import shutil
+from pathlib import Path
+
+import numpy as np
+from astropy.io import fits
+
+from flint.convol import (
+    check_if_cube_fits,
+    get_cube_common_beam,
+    BeamShape,
+)
+from flint.utils import get_packaged_resource_path
+
+
+@pytest.fixture
+def image_fits() -> Path:
+    image = Path(
+        get_packaged_resource_path(
+            package="flint.data.tests",
+            filename="SB39400.RACS_0635-31.beam0-MFS-subimage_rms.fits",
+        )
+    )
+
+    return image
+
+
+@pytest.fixture
+def cube_fits(tmpdir) -> Path:
+    tmp_dir = Path(tmpdir)
+    cube_dir = Path(tmp_dir / "cubes")
+    cube_dir.mkdir(parents=True, exist_ok=True)
+
+    cubes_zip = Path(
+        get_packaged_resource_path(
+            package="flint.data.tests", filename="sub_cube_fits_examples.zip"
+        )
+    )
+    assert cubes_zip.exists()
+    shutil.unpack_archive(cubes_zip, cube_dir)
+
+    return cube_dir
+
+
+def test_check_if_cube_fits(cube_fits, image_fits):
+    """See if the cube fits checker is picking up cubes with three axes"""
+    fits_files = list(cube_fits.glob("*sub.fits"))
+    assert len(fits_files) == 10
+    assert all([check_if_cube_fits(fits_file=f) for f in fits_files])
+
+    assert not check_if_cube_fits(fits_file=image_fits)
+    assert not check_if_cube_fits(fits_file=Path("ThisDoesNotExist"))
+
+
+def test_get_cube_common_beam_and_convol_cubes(cube_fits) -> None:
+    """Ensure that the common beam functionality from beamcon_3D works. Also test the
+    convolution to the cubes, as the initial compute can be expensive"""
+    fits_files = list(cube_fits.glob("*sub.fits"))
+    assert len(fits_files) == 10
+
+    data = fits.getdata(fits_files[0])
+    data_shape = np.squeeze(data).shape  # type: ignore
+
+    beam_list = get_cube_common_beam(cube_paths=fits_files, cutoff=150.0)
+    assert len(beam_list) == data_shape[0]
+    assert all([isinstance(b, BeamShape) for b in beam_list])
+
+    # This appears to make pytest lock up
+    # cube_paths = convolve_cubes(
+    #     cube_paths=fits_files, beam_shapes=beam_list, cutoff=150.0
+    # )
+    # assert all([isinstance(p, Path) for p in cube_paths])
+    # assert all([p.exists() for p in cube_paths])
diff --git a/tests/test_linmos_coadd.py b/tests/test_linmos_coadd.py
index a3968a0d..c62cc598 100644
--- a/tests/test_linmos_coadd.py
+++ b/tests/test_linmos_coadd.py
@@ -11,13 +11,38 @@
 from flint.coadd.linmos import (
     BoundingBox,
+    _create_bound_box_plane,
     _get_alpha_linmos_option,
     _get_holography_linmos_options,
+    _get_image_weight_plane,
     create_bound_box,
+    generate_weights_list_and_files,
     trim_fits_image,
 )
 
+def test_get_image_weight_plane():
+    """The extraction of weights per plane"""
+    data = np.arange(100).reshape((10, 10))
+
+    with pytest.raises(AssertionError):
+        _get_image_weight_plane(image_data=data, mode="noexists")  # type: ignore
+
+    assert np.isclose(
+        0.0016,
+        _get_image_weight_plane(image_data=data, mode="mad", stride=1),
+        atol=0.0001,
+    )
+    assert np.isclose(
+        0.00120012,
+        _get_image_weight_plane(image_data=data, mode="std", stride=1),
+        atol=0.0001,
+    )
+
+    data = np.arange(100).reshape((10, 10)) * np.nan
+    assert _get_image_weight_plane(image_data=data) == 0.0
+
+
 def create_fits_image(out_path, image_size=(1000, 1000)):
     data = np.zeros(image_size)
     data[10:600, 20:500] = 1
@@ -28,6 +53,13 @@ def create_fits_image(out_path, image_size=(1000, 1000)):
     fits.writeto(out_path, data=data, header=header)
 
+def create_image_cube(out_path):
+    data = np.arange(20 * 100).reshape((20, 10, 10))
+    header = fits.header.Header({"CRPIX1": 10, "CRPIX2": 20, "CRPIX3": 1})
+
+    fits.writeto(out_path, header=header, data=data)
+
+
 def test_linmos_alpha_option():
     """Ensure the rotation string supplied to linmos is calculated appropriately"""
 
@@ -42,6 +74,40 @@ def test_linmos_alpha_option():
         _get_alpha_linmos_option(pol_axis=1234)
 
+def test_get_image_weights(tmpdir):
+    """See whether the weights computed per plane in a cube work appropriately"""
+    cube_weight = Path(tmpdir) / "cubeweight"
+    cube_weight.mkdir(parents=True, exist_ok=True)
+    cube_fits = cube_weight / "cube.fits"
+
+    create_image_cube(out_path=cube_fits)
+    weight_file = cube_fits.with_suffix(".weights.txt")
+    assert not weight_file.exists()
+
+    generate_weights_list_and_files(image_paths=[cube_fits], mode="mad")
+    assert weight_file.exists()
+    # The file must end with a newline for linmos to work
+    lines = weight_file.read_text().split("\n")
+    assert len(lines) == 22, f"{lines}"
+
+
+def test_get_image_weight_with_strides(tmpdir):
+    """See whether the weights computed per plane in a cube work appropriately when striding over data"""
+    cube_weight = Path(tmpdir) / "cubeweight"
+    cube_weight.mkdir(parents=True, exist_ok=True)
+    cube_fits = cube_weight / "cube.fits"
+
+    create_image_cube(out_path=cube_fits)
+    weight_file = cube_fits.with_suffix(".weights.txt")
+    assert not weight_file.exists()
+
+    generate_weights_list_and_files(image_paths=[cube_fits], mode="mad", stride=10)
+    assert weight_file.exists()
+    # The file must end with a newline for linmos to work
+    lines = weight_file.read_text().split("\n")
+    assert len(lines) == 22, f"{lines}"
+
+
 def test_linmos_holo_options(tmpdir):
     holofile = Path(tmpdir) / "testholooptions/holo_file.fits"
     holofile.parent.mkdir(parents=True, exist_ok=True)
@@ -87,6 +153,34 @@ def test_trim_fits(tmp_path):
     assert trim_data.shape == (589, 479)
 
+def test_trim_fits_cube(tmp_path):
+    """Ensure that FITS files containing cubes can be trimmed appropriately based on rows/columns with valid pixels"""
+    tmp_dir = tmp_path / "cube"
+    tmp_dir.mkdir()
+
+    out_fits = tmp_dir / "example.fits"
+
+    cube_size = (12, 1000, 1000)
+    data = np.zeros(cube_size)
+    data[3, 10:600, 20:500] = 1
+    data[data == 0] = np.nan
+
+    header = fits.header.Header({"CRPIX1": 10, "CRPIX2": 20})
+
+    fits.writeto(out_fits, data=data, header=header)
+
+    og_hdr = fits.getheader(out_fits)
+    assert og_hdr["CRPIX1"] == 10
+    assert og_hdr["CRPIX2"] == 20
+
+    trim_fits_image(out_fits)
+    trim_hdr = fits.getheader(out_fits)
+    trim_data = fits.getdata(out_fits)
+    assert trim_hdr["CRPIX1"] == -10
+    assert trim_hdr["CRPIX2"] == 10
+    assert trim_data.shape == (12, 589, 479)  # type: ignore
+
+
 def test_trim_fits_image_matching(tmp_path):
     """See that the bounding box can be passed through for matching to cutout"""
 
@@ -133,14 +227,57 @@ def test_bounding_box():
     assert bb.ymax == 499  # slices upper limit not inclusive
 
+def test_bounding_box_none():
+    """Return None if there are no valid pixels to create a bounding box around"""
+    data = np.zeros((1000, 1000)) * np.nan
+
+    bb = _create_bound_box_plane(image_data=data)
+    assert bb is None
+
+    bb = create_bound_box(image_data=data)
+    assert isinstance(bb, BoundingBox)
+    assert bb.xmin == 0
+    assert bb.ymin == 0
+    assert bb.xmax == 999
+    assert bb.ymax == 999
+
+
 def test_bounding_box_cube():
-    """Cube cut bounding boxes. Currently not supported."""
+    """Cube cut bounding boxes."""
     data = np.zeros((3, 1000, 1000))
     data[:, 10:600, 20:500] = 1
     data[data == 0] = np.nan
 
     with pytest.raises(AssertionError):
-        create_bound_box(image_data=data)
+        _create_bound_box_plane(image_data=data)
+
+    bb = create_bound_box(image_data=data)
+    assert isinstance(bb, BoundingBox)
+    assert bb.xmin == 10
+    assert bb.xmax == 599
+    assert bb.ymin == 20
+    assert bb.ymax == 499
+
+
+def test_bounding_box_cube_different_bounds():
+    """Cube cut bounding boxes, ensuring the largest bounding box that
+    captures all valid pixels across planes is returned"""
+    data = np.zeros((3, 1000, 1000))
+    data[0, 10:600, 20:500] = 1
+    data[1, 100:200, 600:800] = 1
+    data[2, 800:888, 20:500] = 1
+
+    data[data == 0] = np.nan
+
+    with pytest.raises(AssertionError):
+        _create_bound_box_plane(image_data=data)
+
+    bb = create_bound_box(image_data=data)
+    assert isinstance(bb, BoundingBox)
+    assert bb.xmin == 10
+    assert bb.xmax == 887
+    assert bb.ymin == 20
+    assert bb.ymax == 799
 
 def test_bounding_box_with_mask():
diff --git a/tests/test_naming.py b/tests/test_naming.py
index fa767c95..948792ce 100644
--- a/tests/test_naming.py
+++ b/tests/test_naming.py
@@ -19,6 +19,7 @@
     extract_components_from_name,
     get_aocalibrate_output_path,
     get_beam_resolution_str,
+    get_fits_cube_from_paths,
     get_potato_output_base_path,
     get_sbid_from_path,
     get_selfcal_ms_name,
@@ -27,6 +28,27 @@
 )
 
+def test_get_cube_fits_from_paths():
+    """Identify the files that contain the cube field and are fits"""
+    files = [
+        "SB63789.EMU_1743-51.beam03.round4.i.image.cube.fits",
+        "SB63789.EMU_1743-51.beam03.round4.i.image.cube.other.fields.fits",
+        "SB63789.EMU_1743-51.beam03.round4.i.MFS.image.optimal.conv.fits",
+        "SB63789.EMU_1743-51.beam03.round4.i.MFS.residual.optimal.conv.fits",
+        "SB63789.EMU_1743-51.beam03.round4.i.MFS.image.fits",
+        "SB63789.EMU_1743-51.beam03.round4.i.MFS.residual.fits",
+    ]
+    files = [Path(f) for f in files]
+
+    cube_files = get_fits_cube_from_paths(paths=files)
+
+    assert len(cube_files) == 2
+    assert cube_files[0] == Path("SB63789.EMU_1743-51.beam03.round4.i.image.cube.fits")
+    assert cube_files[1] == Path(
+        "SB63789.EMU_1743-51.beam03.round4.i.image.cube.other.fields.fits"
+    )
+
+
 def test_create_image_cube_name():
     """Put together a consistent file cube name"""
     name = create_image_cube_name(
diff --git a/tests/test_wsclean.py b/tests/test_wsclean.py
index 9f091248..7a2ace49 100644
--- a/tests/test_wsclean.py
+++ b/tests/test_wsclean.py
@@ -193,10 +193,33 @@ def test_combine_subbands_to_cube(tmpdir):
     with pytest.raises(TypeError):
         _ = combine_subbands_to_cube(imageset=files, remove_original_images=False)  # type: ignore
 
+
+def test_combine_subbands_to_cube2(tmpdir):
+    """Load in example fits images to combine into a cube, removing the original subband images"""
+    files = [
+        get_packaged_resource_path(
+            package="flint.data.tests",
+            filename=f"SB56659.RACS_0940-04.beam17.round3-000{i}-image.sub.fits",
+        )
+        for i in range(3)
+    ]
+    files = [Path(shutil.copy(Path(f), Path(tmpdir))) for f in files]
+
+    assert len(files) == 3
+    assert all([f.exists() for f in files])
+    file_parent = files[0].parent
+    prefix = f"{file_parent}/SB56659.RACS_0940-04.beam17.round3"
+    imageset = ImageSet(
+        prefix=prefix,
+        image=files,
+    )
+
     new_imageset = combine_subbands_to_cube(
         imageset=imageset, remove_original_images=True
     )
     assert all([not file.exists() for file in files])
+    assert new_imageset.prefix == imageset.prefix
+    assert len(new_imageset.image) == 1
 
 def test_resolve_key_value_to_cli():
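Taken together, the `--coadd-cubes` path added in this change reduces to the following sketch (Prefect task wrappers and option plumbing omitted; the final linmos invocation is indicated only in outline):

```python
from pathlib import Path
from typing import List

from flint.coadd.linmos import generate_weights_list_and_files
from flint.convol import convolve_cubes, get_cube_common_beam

def coadd_field_cubes(cube_paths: List[Path], cutoff: float = 150.0) -> None:
    # 1. Derive one common beam per channel across every beam's cube
    beam_shapes = get_cube_common_beam(cube_paths=cube_paths, cutoff=cutoff)

    # 2. Convolve each cube channel-wise to those common beams
    conv_cubes = convolve_cubes(
        cube_paths=cube_paths, beam_shapes=beam_shapes, cutoff=cutoff
    )

    # 3. Write per-channel inverse-variance weights for linmos
    #    (stride thins the pixels used to compute each statistic)
    weight_list = generate_weights_list_and_files(
        image_paths=list(conv_cubes), mode="mad", stride=8
    )

    # 4. linmos is then run with filter=None so the cubes themselves, not
    #    just the MFS images, are co-added into the field spectral cube
    ...
```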