From e3e61acaa7b659b0934be6f4a6e1dc2a6fbc75b2 Mon Sep 17 00:00:00 2001 From: Alec Thomson Date: Wed, 18 Sep 2024 15:52:22 +0800 Subject: [PATCH] Divergence (#79) * Add exception * Add divergence retry --- arrakis/imager.py | 92 +++++++++++++++++++++++-------------- arrakis/utils/exceptions.py | 4 ++ 2 files changed, 61 insertions(+), 35 deletions(-) diff --git a/arrakis/imager.py b/arrakis/imager.py index 13f7433e..7360e48b 100644 --- a/arrakis/imager.py +++ b/arrakis/imager.py @@ -24,7 +24,6 @@ from astropy.visualization import ( SqrtStretch, ImageNormalize, - MinMaxInterval, ) from fitscube import combine_fits from fixms.fix_ms_corrs import fix_ms_corrs @@ -52,6 +51,8 @@ workdir_arg_parser, ) +from arrakis.utils.exceptions import DivergenceError + matplotlib.use("Agg") TQDM_OUT = TqdmToLogger(logger, level=logging.INFO) @@ -147,11 +148,6 @@ def get_mfs_image( if pol == "I" else f"{prefix_str}-MFS-{pol}-image.fits" ) - mfs_model_name = ( - f"{prefix_str}-MFS-model.fits" - if pol == "I" - else f"{prefix_str}-MFS-{pol}-model.fits" - ) mfs_residual_name = ( f"{prefix_str}-MFS-residual.fits" if pol == "I" @@ -159,8 +155,8 @@ def get_mfs_image( ) big_image = fits.getdata(mfs_image_name).squeeze() - big_model = fits.getdata(mfs_model_name).squeeze() big_residual = fits.getdata(mfs_residual_name).squeeze() + big_model = big_image - big_residual small_image = resize(big_image, small_size) small_model = resize(big_model, small_size) @@ -181,14 +177,11 @@ def make_validation_plots(prefix: Path, pols: str) -> None: for stokes in pols: mfs_image = get_mfs_image(prefix_str, stokes) fig, axs = plt.subplots(1, 3, figsize=(15, 5)) - for ax, sub_image, title in zip(axs, mfs_image, ("Image", "Model", "Residual")): + for ax, sub_image, title in zip( + axs, mfs_image, ("Image", "Model (conv.)", "Residual") + ): sub_image = np.abs(sub_image) - if title == "Model": - norm = ImageNormalize( - sub_image, interval=MinMaxInterval(), stretch=SqrtStretch() - ) - else: - norm = ImageNormalize(mfs_image.residual, vmin=0, stretch=SqrtStretch()) + norm = ImageNormalize(mfs_image.residual, vmin=0, stretch=SqrtStretch()) _ = ax.imshow(sub_image, origin="lower", norm=norm, cmap="cubehelix") ax.set_title(title) ax.get_yaxis().set_visible(False) @@ -280,6 +273,38 @@ def get_prefix( return out_dir / prefix +def run_wsclean_singuarlity( + command: str, + simage: Path, + out_dir: Path, + root_dir: Path, +) -> None: + logger.info(f"Running wsclean with command: {command}") + try: + output = sclient.execute( + image=simage.resolve(strict=True).as_posix(), + command=command.split(), + bind=f"{out_dir}:{out_dir}, {root_dir.resolve(strict=True).as_posix()}:{root_dir.resolve(strict=True).as_posix()}", + return_result=True, + quiet=False, + stream=True, + ) + for line in output: + logger.info(line.rstrip()) + # Catch divergence - look for the string 'KJy' in the output + if "KJy" in line: + raise DivergenceError( + f"Detected divergence in wsclean output: {line.rstrip()}" + ) + + except CalledProcessError as e: + logger.error(f"Failed to run wsclean with command: {command}") + logger.error(f"Stdout: {e.stdout}") + logger.error(f"Stderr: {e.stderr}") + logger.error(f"{e=}") + raise e + + @task(name="Image Beam", persist_result=True) def image_beam( ms: Path, @@ -423,29 +448,26 @@ def image_beam( ) root_dir = ms.parent - logger.info(f"Running wsclean with command: {command}") try: - output = sclient.execute( - image=simage.resolve(strict=True).as_posix(), - command=command.split(), - bind=f"{out_dir}:{out_dir}, {root_dir.resolve(strict=True).as_posix()}:{root_dir.resolve(strict=True).as_posix()}", - return_result=True, - quiet=False, - stream=True, + run_wsclean_singuarlity( + command=command, + simage=simage, + out_dir=out_dir, + root_dir=root_dir, + ) + except DivergenceError as de: + logger.error(f"Detected divergence in wsclean output: {de}") + new_pix = npix + 1024 + new_command = command.replace(f"{npix} {npix}", f"{new_pix} {new_pix}") + logger.critical( + f"Rerunning wsclean with larger image size: {new_pix}x{new_pix}" + ) + run_wsclean_singuarlity( + command=new_command, + simage=simage, + out_dir=out_dir, + root_dir=root_dir, ) - for line in output: - logger.info(line.rstrip()) - # Catch divergence - look for the string 'KJy' in the output - if "KJy" in line: - raise ValueError( - f"Detected divergence in wsclean output: {line.rstrip()}" - ) - except CalledProcessError as e: - logger.error(f"Failed to run wsclean with command: {command}") - logger.error(f"Stdout: {e.stdout}") - logger.error(f"Stderr: {e.stderr}") - logger.error(f"{e=}") - raise e # Purge ms_temp shutil.rmtree(ms_temp) diff --git a/arrakis/utils/exceptions.py b/arrakis/utils/exceptions.py index 49b22de7..356e8224 100644 --- a/arrakis/utils/exceptions.py +++ b/arrakis/utils/exceptions.py @@ -34,3 +34,7 @@ class ReadError(OSError): class RegistryError(Exception): """Raised when a registry operation with the archiving and unpacking registeries fails""" + + +class DivergenceError(Exception): + """Raised when a divergence is detected"""