Skip to content

Commit

Permalink
Divergence (#79)
Browse files Browse the repository at this point in the history
* Add exception

* Add divergence retry
  • Loading branch information
AlecThomson authored Sep 18, 2024
1 parent 0966b7f commit e3e61ac
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 35 deletions.
92 changes: 57 additions & 35 deletions arrakis/imager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from astropy.visualization import (
SqrtStretch,
ImageNormalize,
MinMaxInterval,
)
from fitscube import combine_fits
from fixms.fix_ms_corrs import fix_ms_corrs
Expand Down Expand Up @@ -52,6 +51,8 @@
workdir_arg_parser,
)

from arrakis.utils.exceptions import DivergenceError

matplotlib.use("Agg")

TQDM_OUT = TqdmToLogger(logger, level=logging.INFO)
Expand Down Expand Up @@ -147,20 +148,15 @@ def get_mfs_image(
if pol == "I"
else f"{prefix_str}-MFS-{pol}-image.fits"
)
mfs_model_name = (
f"{prefix_str}-MFS-model.fits"
if pol == "I"
else f"{prefix_str}-MFS-{pol}-model.fits"
)
mfs_residual_name = (
f"{prefix_str}-MFS-residual.fits"
if pol == "I"
else f"{prefix_str}-MFS-{pol}-residual.fits"
)

big_image = fits.getdata(mfs_image_name).squeeze()
big_model = fits.getdata(mfs_model_name).squeeze()
big_residual = fits.getdata(mfs_residual_name).squeeze()
big_model = big_image - big_residual

small_image = resize(big_image, small_size)
small_model = resize(big_model, small_size)
Expand All @@ -181,14 +177,11 @@ def make_validation_plots(prefix: Path, pols: str) -> None:
for stokes in pols:
mfs_image = get_mfs_image(prefix_str, stokes)
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
for ax, sub_image, title in zip(axs, mfs_image, ("Image", "Model", "Residual")):
for ax, sub_image, title in zip(
axs, mfs_image, ("Image", "Model (conv.)", "Residual")
):
sub_image = np.abs(sub_image)
if title == "Model":
norm = ImageNormalize(
sub_image, interval=MinMaxInterval(), stretch=SqrtStretch()
)
else:
norm = ImageNormalize(mfs_image.residual, vmin=0, stretch=SqrtStretch())
norm = ImageNormalize(mfs_image.residual, vmin=0, stretch=SqrtStretch())
_ = ax.imshow(sub_image, origin="lower", norm=norm, cmap="cubehelix")
ax.set_title(title)
ax.get_yaxis().set_visible(False)
Expand Down Expand Up @@ -280,6 +273,38 @@ def get_prefix(
return out_dir / prefix


def run_wsclean_singuarlity(
command: str,
simage: Path,
out_dir: Path,
root_dir: Path,
) -> None:
logger.info(f"Running wsclean with command: {command}")
try:
output = sclient.execute(
image=simage.resolve(strict=True).as_posix(),
command=command.split(),
bind=f"{out_dir}:{out_dir}, {root_dir.resolve(strict=True).as_posix()}:{root_dir.resolve(strict=True).as_posix()}",
return_result=True,
quiet=False,
stream=True,
)
for line in output:
logger.info(line.rstrip())
# Catch divergence - look for the string 'KJy' in the output
if "KJy" in line:
raise DivergenceError(
f"Detected divergence in wsclean output: {line.rstrip()}"
)

except CalledProcessError as e:
logger.error(f"Failed to run wsclean with command: {command}")
logger.error(f"Stdout: {e.stdout}")
logger.error(f"Stderr: {e.stderr}")
logger.error(f"{e=}")
raise e


@task(name="Image Beam", persist_result=True)
def image_beam(
ms: Path,
Expand Down Expand Up @@ -423,29 +448,26 @@ def image_beam(
)

root_dir = ms.parent
logger.info(f"Running wsclean with command: {command}")
try:
output = sclient.execute(
image=simage.resolve(strict=True).as_posix(),
command=command.split(),
bind=f"{out_dir}:{out_dir}, {root_dir.resolve(strict=True).as_posix()}:{root_dir.resolve(strict=True).as_posix()}",
return_result=True,
quiet=False,
stream=True,
run_wsclean_singuarlity(
command=command,
simage=simage,
out_dir=out_dir,
root_dir=root_dir,
)
except DivergenceError as de:
logger.error(f"Detected divergence in wsclean output: {de}")
new_pix = npix + 1024
new_command = command.replace(f"{npix} {npix}", f"{new_pix} {new_pix}")
logger.critical(
f"Rerunning wsclean with larger image size: {new_pix}x{new_pix}"
)
run_wsclean_singuarlity(
command=new_command,
simage=simage,
out_dir=out_dir,
root_dir=root_dir,
)
for line in output:
logger.info(line.rstrip())
# Catch divergence - look for the string 'KJy' in the output
if "KJy" in line:
raise ValueError(
f"Detected divergence in wsclean output: {line.rstrip()}"
)
except CalledProcessError as e:
logger.error(f"Failed to run wsclean with command: {command}")
logger.error(f"Stdout: {e.stdout}")
logger.error(f"Stderr: {e.stderr}")
logger.error(f"{e=}")
raise e

# Purge ms_temp
shutil.rmtree(ms_temp)
Expand Down
4 changes: 4 additions & 0 deletions arrakis/utils/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,7 @@ class ReadError(OSError):
class RegistryError(Exception):
"""Raised when a registry operation with the archiving
and unpacking registeries fails"""


class DivergenceError(Exception):
"""Raised when a divergence is detected"""

0 comments on commit e3e61ac

Please sign in to comment.