Skip to content

Commit

Permalink
Merge pull request #67 from AlecThomson/validate
Browse files Browse the repository at this point in the history
Validate
  • Loading branch information
AlecThomson authored May 7, 2024
2 parents 1af1389 + 9b39dc2 commit 6999f2e
Show file tree
Hide file tree
Showing 19 changed files with 817 additions and 179 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### What's Changed

* Add validation stage to pipeline, including plot artifacts
* Add LINMOS rotation for rotated fields
* Fixes WSClean argument handling and clean thresholds


## [2.2.2] - 2024-04-18
### What's Changed
* [pre-commit.ci] pre-commit autoupdate by @pre-commit-ci in https://github.com/AlecThomson/arrakis/pull/61
Expand Down
4 changes: 2 additions & 2 deletions arrakis/configs/petrichor.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# Set up for Petrichor
cluster_class: "dask_jobqueue.SLURMCluster"
cluster_kwargs:
cores: 20
cores: 32
processes: 1
name: 'spice-worker'
memory: "160GiB"
memory: "256GiB"
account: 'OD-217087'
#queue: 'workq'
walltime: '0-12:00:00'
Expand Down
3 changes: 2 additions & 1 deletion arrakis/cutout.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from pathlib import Path
from pprint import pformat
from shutil import copyfile
from typing import List, Optional, Set, TypeVar
from typing import List
from typing import NamedTuple as Struct
from typing import Optional, Set, TypeVar

import astropy.units as u
import numpy as np
Expand Down
3 changes: 2 additions & 1 deletion arrakis/frion.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import os
from pathlib import Path
from pprint import pformat
from typing import Callable, Dict, List, Optional, Union
from typing import Callable, Dict, List
from typing import NamedTuple as Struct
from typing import Optional, Union

import astropy.units as u
import numpy as np
Expand Down
98 changes: 83 additions & 15 deletions arrakis/imager.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
from glob import glob
from pathlib import Path
from subprocess import CalledProcessError
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List
from typing import NamedTuple as Struct
from typing import Optional, Tuple, Union

from arrakis.utils.meta import my_ceil
import astropy.units as u
import numpy as np
from astropy.io import fits
from astropy.stats import mad_std
Expand All @@ -31,6 +34,7 @@
beam_from_ms,
field_idx_from_ms,
field_name_from_ms,
get_pol_axis,
wsclean,
)
from arrakis.utils.pipeline import logo_str, workdir_arg_parser
Expand Down Expand Up @@ -161,6 +165,8 @@ def image_beam(
no_mf_weighting: bool = False,
no_update_model_required: bool = True,
beam_fitting_size: Optional[float] = 1.25,
disable_pol_local_rms: bool = False,
disable_pol_force_mask_rounds: bool = False,
) -> ImageSet:
"""Image a single beam"""
logger = get_run_logger()
Expand All @@ -178,6 +184,23 @@ def image_beam(

temp_dir_wsclean = parse_env_path(temp_dir_wsclean)

# Catch mis-matched args
if not local_rms:
logger.warning(
f"Local RMS is disabled. Setting local_rms_window to None. Was set to {local_rms_window}."
)
local_rms_window = None

if not multiscale:
logger.warning(
f"Multiscale is disabled. Setting multiscale_scale_bias to None. Was set to {multiscale_scale_bias}."
)
multiscale_scale_bias = None
logger.warning(
f"Multiscale is disabled. Setting multiscale_scales to None. Was set to {multiscale_scales}."
)
multiscale_scales = None

commands = []
# Do any I cleaning separately
do_stokes_I = "I" in pols
Expand All @@ -194,6 +217,7 @@ def image_beam(
pol="I",
verbose=True,
channels_out=nchan,
parallel_gridding=nchan,
scale=f"{scale}asec",
size=f"{npix} {npix}",
join_polarizations=False, # Only do I
Expand Down Expand Up @@ -229,18 +253,20 @@ def image_beam(

if all([p in pols.upper() for p in ("Q", "U")]):
if squared_channel_joining:
logger.info("Using squared channel joining")
logger.info("Reducing mask by sqrt(2) to account for this")
auto_mask_reduce = np.round(auto_mask / (np.sqrt(2)), decimals=2)
logger.info(
"Squared channel joining is enabled - scaling auto_mask and auto_threshold by power of 2"
)
auto_mask = my_ceil(auto_mask**2, 2)
auto_threshold = my_ceil(auto_threshold**2, 2)

logger.info(f"auto_mask = {auto_mask}")
logger.info(f"auto_mask_reduce = {auto_mask_reduce}")
else:
auto_mask_reduce = auto_mask
if disable_pol_local_rms:
logger.info("Disabling local RMS for polarisation images")
local_rms = False
local_rms_window = None

if local_rms_window:
local_rms_window = int(local_rms_window / 2)
logger.info(f"Scaled local RMS window to {local_rms_window}.")
if disable_pol_force_mask_rounds:
logger.info("Disabling force mask rounds for polarisation images")
force_mask_rounds = None

command = wsclean(
mslist=[ms.resolve(strict=True).as_posix()],
Expand All @@ -254,14 +280,15 @@ def image_beam(
pol=pols,
verbose=True,
channels_out=nchan,
parallel_gridding=nchan,
scale=f"{scale}asec",
size=f"{npix} {npix}",
join_polarizations=join_polarizations,
join_channels=join_channels,
squared_channel_joining=squared_channel_joining,
mgain=mgain,
niter=niter,
auto_mask=auto_mask_reduce,
auto_mask=auto_mask,
force_mask_rounds=force_mask_rounds,
auto_threshold=auto_threshold,
gridder=gridder,
Expand All @@ -276,9 +303,14 @@ def image_beam(
nmiter=nmiter,
local_rms=local_rms,
local_rms_window=local_rms_window,
# Avoid multiscale when using squared channel joining
multiscale=multiscale if not squared_channel_joining else False,
multiscale_scale_bias=multiscale_scale_bias,
multiscale_scales=multiscale_scales,
multiscale_scale_bias=multiscale_scale_bias
if not squared_channel_joining
else None,
multiscale_scales=multiscale_scales
if not squared_channel_joining
else None,
data_column=data_column,
no_mf_weighting=no_mf_weighting,
no_update_model_required=no_update_model_required,
Expand Down Expand Up @@ -377,7 +409,11 @@ def image_beam(

@task(name="Make Cube")
def make_cube(
pol: str, image_set: ImageSet, common_beam_pkl: Path, aux_mode: Optional[str] = None
pol: str,
image_set: ImageSet,
common_beam_pkl: Path,
pol_angle_deg: float,
aux_mode: Optional[str] = None,
) -> Tuple[Path, Path]:
"""Make a cube from the images"""
logger = get_run_logger()
Expand All @@ -392,6 +428,12 @@ def make_cube(
new_header = hdu_list[0].header
data_cube = hdu_list[0].data

# Add pol angle to header
new_header["INSTRUMENT_RECEPTOR_ANGLE"] = (
pol_angle_deg,
"Orig. pol. axis rotation angle in degrees",
)

tmp_header = new_header.copy()
# Need to swap NAXIS 3 and 4 to make LINMOS happy - booo
for a, b in ((3, 4), (4, 3)):
Expand Down Expand Up @@ -647,6 +689,8 @@ def main(
data_column: str = "CORRECTED_DATA",
skip_fix_ms: bool = False,
no_mf_weighting: bool = False,
disable_pol_local_rms: bool = False,
disable_pol_force_mask_rounds: bool = False,
):
"""Arrakis imager flow
Expand Down Expand Up @@ -685,6 +729,8 @@ def main(
data_column (str, optional): Data column to image. Defaults to "CORRECTED_DATA".
skip_fix_ms (bool, optional): Apply FixMS. Defaults to False.
no_mf_weighting (bool, optional): WSClean no_mf_weighting. Defaults to False.
disable_pol_local_rms (bool, optional): Disable local RMS for polarisation images. Defaults to False.
disable_pol_force_mask_rounds (bool, optional): Disable force mask rounds for polarisation images. Defaults to False.
"""

simage = get_wsclean(wsclean=wsclean_path)
Expand Down Expand Up @@ -726,8 +772,12 @@ def main(
ms_fix = fix_ms_askap_corrs(
ms=ms_fix, data_column="DATA", corrected_data_column=data_column
)
pol_angle_deg = (
get_pol_axis(ms_fix, col="INSTRUMENT_RECEPTOR_ANGLE").to(u.deg).value
)
else:
ms_fix = ms
pol_angle_deg = get_pol_axis(ms_fix, col="RECEPTOR_ANGLE").to(u.deg).value
# Image with wsclean
image_set = image_beam.submit(
ms=ms_fix,
Expand Down Expand Up @@ -760,6 +810,8 @@ def main(
absmem=absmem,
data_column=data_column,
no_mf_weighting=no_mf_weighting,
disable_pol_local_rms=disable_pol_local_rms,
disable_pol_force_mask_rounds=disable_pol_force_mask_rounds,
)

# Compute the smallest beam that all images can be convolved to.
Expand Down Expand Up @@ -790,6 +842,7 @@ def main(
pol=pol,
image_set=sm_image_set,
common_beam_pkl=common_beam_pkl,
pol_angle_deg=pol_angle_deg,
aux_mode=aux_mode,
wait_for=[sm_image_set],
)
Expand Down Expand Up @@ -1017,6 +1070,16 @@ def imager_parser(parent_parser: bool = False) -> argparse.ArgumentParser:
help="Number of beams to image",
default=36,
)
parser.add_argument(
"--disable_pol_local_rms",
action="store_true",
help="Disable local RMS for polarisation images",
)
parser.add_argument(
"--disable_pol_force_mask_rounds",
action="store_true",
help="Disable force mask rounds for polarisation images",
)

group = parser.add_argument_group("wsclean container options")
mxg = group.add_mutually_exclusive_group()
Expand Down Expand Up @@ -1062,6 +1125,9 @@ def cli():
scale=args.scale,
mgain=args.mgain,
niter=args.niter,
nmiter=args.nmiter,
local_rms=args.local_rms,
local_rms_window=args.local_rms_window,
auto_mask=args.auto_mask,
force_mask_rounds=args.force_mask_rounds,
auto_threshold=args.auto_threshold,
Expand All @@ -1080,6 +1146,8 @@ def cli():
data_column=args.data_column,
skip_fix_ms=args.skip_fix_ms,
no_mf_weighting=args.no_mf_weighting,
disable_pol_local_rms=args.disable_pol_local_rms,
disable_pol_force_mask_rounds=args.disable_pol_force_mask_rounds,
)


Expand Down
20 changes: 19 additions & 1 deletion arrakis/linmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,15 @@
from glob import glob
from pathlib import Path
from pprint import pformat
from typing import Dict, List, Optional, Tuple
from typing import Dict, List
from typing import NamedTuple as Struct
from typing import Optional, Tuple

import astropy.units as u
import numpy as np
import pandas as pd
import pymongo
from astropy.io import fits
from astropy.utils.exceptions import AstropyWarning
from prefect import flow, task
from racs_tools import beamcon_3D
Expand Down Expand Up @@ -160,6 +164,19 @@ def genparset(
"""
logger.setLevel(logging.INFO)

pol_angles_list: List[float] = []
for im in image_paths.images:
_pol_angle: float = fits.getheader(im)["INSTRUMENT_RECEPTOR_ANGLE"]
pol_angles_list.append(_pol_angle)
pol_angles: u.Quantity = pol_angles_list * u.deg

pol_0: u.Quantity = pol_angles[0]

assert np.allclose(pol_angles, pol_0), "Polarisation angles are not the same!"

logger.warning("Assuming holography was done at -45 degrees")
alpha = pol_0 - -45 * u.deg

image_string = f"[{','.join([im.resolve().with_suffix('').as_posix() for im in image_paths.images])}]"
weight_string = f"[{','.join([im.resolve().with_suffix('').as_posix() for im in image_paths.weights])}]"

Expand Down Expand Up @@ -188,6 +205,7 @@ def genparset(
parset += f"""
linmos.primarybeam = ASKAP_PB
linmos.primarybeam.ASKAP_PB.image = {holofile.resolve().as_posix()}
linmos.primarybeamASKAP_PB.alpha = {alpha.to(u.rad).value}
linmos.removeleakage = true
"""
else:
Expand Down
Loading

0 comments on commit 6999f2e

Please sign in to comment.