From 6f29c5a29f4699774d4f50fa7ae4fff7b80cf5fe Mon Sep 17 00:00:00 2001 From: "Alec Thomson (S&A, Kensington WA)" Date: Mon, 11 Dec 2023 10:19:13 +0800 Subject: [PATCH 01/46] stuffs --- arrakis/cutout.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/arrakis/cutout.py b/arrakis/cutout.py index 7c0bc9bf..d50247a2 100644 --- a/arrakis/cutout.py +++ b/arrakis/cutout.py @@ -21,6 +21,7 @@ from dask import delayed from dask.distributed import Client, LocalCluster from distributed import get_client +from prefect import flow, task from spectral_cube import SpectralCube from spectral_cube.utils import SpectralCubeWarning @@ -42,7 +43,7 @@ logger.setLevel(logging.INFO) -@delayed +@task(name="Cutout island") def cutout( image: str, src_name: str, @@ -166,7 +167,7 @@ def cutout( return ret -@delayed +@task(name="Get cutout arguments") def get_args( island: Dict, comps: List[Dict], @@ -278,7 +279,7 @@ def get_args( return args -@delayed +@task(name="Find components") def find_comps(island_id: str, comp_col: pymongo.collection.Collection) -> List[Dict]: """Find components for a given island @@ -293,7 +294,7 @@ def find_comps(island_id: str, comp_col: pymongo.collection.Collection) -> List[ return comps -@delayed +@task(name="Unpack list") def unpack(list_sq: List[List[Dict]]) -> List[Dict]: """Unpack list of lists @@ -310,6 +311,7 @@ def unpack(list_sq: List[List[Dict]]) -> List[Dict]: return list_fl +@flow(name="Cutout islands") def cutout_islands( field: str, directory: str, From c52288137ec57074a72eb22780b27881163296bf Mon Sep 17 00:00:00 2001 From: "Alec Thomson (S&A, Kensington WA)" Date: Mon, 11 Dec 2023 10:58:19 +0800 Subject: [PATCH 02/46] Switch cutout to prefect --- arrakis/cutout.py | 204 +++++++++++++++++++++++----------------------- 1 file changed, 101 insertions(+), 103 deletions(-) diff --git a/arrakis/cutout.py b/arrakis/cutout.py index d50247a2..365584c4 100644 --- a/arrakis/cutout.py +++ b/arrakis/cutout.py @@ -7,7 +7,9 @@ from glob import glob from pprint import pformat from shutil import copyfile -from typing import Dict, List, Union +from typing import Dict, List +from typing import NamedTuple as Struct +from typing import Optional, TypeVar, Union import astropy.units as u import numpy as np @@ -21,7 +23,7 @@ from dask import delayed from dask.distributed import Client, LocalCluster from distributed import get_client -from prefect import flow, task +from prefect import flow, task, unmapped from spectral_cube import SpectralCube from spectral_cube.utils import SpectralCubeWarning @@ -42,21 +44,37 @@ logger.setLevel(logging.INFO) +T = TypeVar("T") + + +class CutoutArgs(Struct): + """Arguments for cutout function""" + + image: str + """Name of the image file""" + source_id: str + """Name of the source""" + ra_high: float + """Upper RA bound in degrees""" + ra_low: float + """Lower RA bound in degrees""" + dec_high: float + """Upper DEC bound in degrees""" + dec_low: float + """Lower DEC bound in degrees""" + outdir: str + """Output directory""" + beam: int + """Beam number""" + stoke: str + """Stokes parameter""" + @task(name="Cutout island") def cutout( - image: str, - src_name: str, - beam: int, - ra_hi: float, - ra_lo: float, - dec_hi: float, - dec_lo: float, - outdir: str, - stoke: str, + cutout_args: CutoutArgs, field: str, pad=3, - verbose=False, dryrun=False, ) -> List[pymongo.UpdateOne]: """Perform a cutout. 
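A note on the pattern these first commits introduce: CutoutArgs is a typed NamedTuple (aliased as Struct) that bundles what were previously nine separate parameters to cutout(). A minimal sketch of constructing one and exercising the task body outside the Prefect engine, via the .fn attribute Prefect 2 puts on every @task; all file and source names here are hypothetical:

    args = CutoutArgs(
        image="image.restored.i.SB1234.contcube.beam00.conv.fits",
        source_id="RACS_1234+56_0001",
        ra_high=187.3,
        ra_low=187.1,
        dec_high=-45.2,
        dec_low=-45.4,
        outdir="/data/cutouts",
        beam=0,
        stoke="i",
    )
    updates = cutout.fn(args, field="RACS_1234+56", pad=3, dryrun=True)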
@@ -65,10 +83,10 @@ def cutout( image (str): Name of the image file src_name (str): Name of the RACS source beam (int): Beam number - ra_hi (float): Upper RA bound - ra_lo (float): Lower RA bound - dec_hi (float): Upper DEC bound - dec_lo (float): Lower DEC bound + ra_high (float): Upper RA bound + ra_low (float): Lower RA bound + dec_high (float): Upper DEC bound + dec_low (float): Lower DEC bound outdir (str): Output directgory stoke (str): Stokes parameter field (str): RACS field name @@ -83,20 +101,20 @@ def cutout( # logger = logging.getLogger('distributed.worker') # logger = get_run_logger() - logger.info(f"Timwashere - {image=}") + logger.info(f"Timwashere - {cutout_args.image=}") - outdir = os.path.abspath(outdir) + outdir = os.path.abspath(cutout_args.outdir) ret = [] for imtype in ["image", "weight"]: - basename = os.path.basename(image) - outname = f"{src_name}.cutout.{basename}" + basename = os.path.basename(cutout_args.image) + outname = f"{cutout_args.source_id}.cutout.{basename}" outfile = os.path.join(outdir, outname) if imtype == "weight": - image = image.replace("image.restored", "weights.restored").replace( - ".fits", ".txt" - ) + image = cutout_args.image.replace( + "image.restored", "weights.restored" + ).replace(".fits", ".txt") outfile = outfile.replace("image.restored", "weights.restored").replace( ".fits", ".txt" ) @@ -110,10 +128,10 @@ def cutout( cube = SpectralCube.read(image) padder = cube.header["BMAJ"] * u.deg * pad - xlo = Longitude(ra_lo * u.deg) - Longitude(padder) - xhi = Longitude(ra_hi * u.deg) + Longitude(padder) - ylo = Latitude(dec_lo * u.deg) - Latitude(padder) - yhi = Latitude(dec_hi * u.deg) + Latitude(padder) + xlo = Longitude(cutout_args.ra_low * u.deg) - Longitude(padder) + xhi = Longitude(cutout_args.ra_high * u.deg) + Longitude(padder) + ylo = Latitude(cutout_args.dec_low * u.deg) - Latitude(padder) + yhi = Latitude(cutout_args.dec_high * u.deg) + Latitude(padder) xp_lo, yp_lo = skycoord_to_pixel(SkyCoord(xlo, ylo), cube.wcs) xp_hi, yp_hi = skycoord_to_pixel(SkyCoord(xhi, yhi), cube.wcs) @@ -141,7 +159,7 @@ def cutout( ] fixed_header = fix_header(new_header, old_header) # Add source name to header for CASDA - fixed_header["OBJECT"] = src_name + fixed_header["OBJECT"] = cutout_args.source_id if not dryrun: fits.writeto( outfile, @@ -153,13 +171,15 @@ def cutout( logger.info(f"Written to {outfile}") # Update database - myquery = {"Source_ID": src_name} + myquery = {"Source_ID": cutout_args.source_id} filename = os.path.join( os.path.basename(os.path.dirname(outfile)), os.path.basename(outfile) ) newvalues = { - "$set": {f"beams.{field}.{stoke}_beam{beam}_{imtype}_file": filename} + "$set": { + f"beams.{field}.{cutout_args.stoke}_beam{cutout_args.beam}_{imtype}_file": filename + } } ret += [pymongo.UpdateOne(myquery, newvalues, upsert=True)] @@ -178,7 +198,7 @@ def get_args( datadir: str, stokeslist: List[str], verbose=True, -) -> List[Dict]: +) -> Union[List[CutoutArgs], None]: """Get arguments for cutout function Args: @@ -197,7 +217,7 @@ def get_args( Exception: Problems with coordinates Returns: - List[Dict]: List of cutout arguments for cutout function + List[CutoutArgs]: List of cutout arguments for cutout function """ logger.setLevel(logging.INFO) @@ -205,6 +225,10 @@ def get_args( assert island["Source_ID"] == island_id assert beam["Source_ID"] == island_id + if len(comps) == 0: + logger.warning(f"Skipping island {island_id} -- no components found") + return None + beam_list = list(set(beam["beams"][field]["beam_list"])) outdir = 
f"{outdir}/{island['Source_ID']}" @@ -227,22 +251,22 @@ def get_args( ra_max = np.max(coords.ra) ra_i_max = np.argmax(coords.ra) ra_off = Longitude(majs[ra_i_max]) - ra_hi = ra_max + ra_off + ra_high = ra_max + ra_off ra_min = np.min(coords.ra) ra_i_min = np.argmin(coords.ra) ra_off = Longitude(majs[ra_i_min]) - ra_lo = ra_min - ra_off + ra_low = ra_min - ra_off dec_max = np.max(coords.dec) dec_i_max = np.argmax(coords.dec) dec_off = Longitude(majs[dec_i_max]) - dec_hi = dec_max + dec_off + dec_high = dec_max + dec_off dec_min = np.min(coords.dec) dec_i_min = np.argmin(coords.dec) dec_off = Longitude(majs[dec_i_min]) - dec_lo = dec_min - dec_off + dec_low = dec_min - dec_off except Exception as e: logger.debug(f"coords are {coords=}") logger.debug(f"comps are {comps=}") @@ -262,19 +286,17 @@ def get_args( for image in images: args.extend( - [ - { - "image": image, - "id": island["Source_ID"], - "ra_hi": ra_hi.deg, - "ra_lo": ra_lo.deg, - "dec_hi": dec_hi.deg, - "dec_lo": dec_lo.deg, - "outdir": outdir, - "beam": beam_num, - "stoke": stoke.lower(), - } - ] + CutoutArgs( + image=image, + id=island["Source_ID"], + ra_high=ra_high.deg, + ra_low=ra_low.deg, + dec_high=dec_high.deg, + dec_low=dec_low.deg, + outdir=outdir, + beam=beam_num, + stoke=stoke.lower(), + ) ) return args @@ -295,17 +317,19 @@ def find_comps(island_id: str, comp_col: pymongo.collection.Collection) -> List[ @task(name="Unpack list") -def unpack(list_sq: List[List[Dict]]) -> List[Dict]: - """Unpack list of lists +def unpack(list_sq: List[List[T]]) -> List[T]: + """Unpack list of lists of things into a list of things Args: - list_sq (List[List[Dict]]): List of lists of dicts + list_sq (List[List[T]]): List of lists of things Returns: - List[Dict]: List of dicts + List[T]: List of things """ list_fl = [] for i in list_sq: + if i is None: + continue for j in i: list_fl.append(j) return list_fl @@ -317,10 +341,10 @@ def cutout_islands( directory: str, host: str, epoch: int, - username: Union[str, None] = None, - password: Union[str, None] = None, + username: Optional[str] = None, + password: Optional[str] = None, pad: float = 3, - stokeslist: Union[List[str], None] = None, + stokeslist: Optional[List[str]] = None, verbose_worker: bool = False, dryrun: bool = True, ) -> None: @@ -379,55 +403,29 @@ def cutout_islands( # Create output dir if it doesn't exist try_mkdir(outdir) - args = [] - for island_id, island, comp, beam in zip(island_ids, islands, comps, beams): - if len(comp) == 0: - warnings.warn(f"Skipping island {island_id} -- no components found") - continue - else: - arg = get_args( - island, - comp, - beam, - island_id, - outdir, - field, - directory, - stokeslist, - verbose=verbose_worker, - ) - args.append(arg) - - flat_args = unpack(args) - flat_args = client.compute(flat_args) - tqdm_dask(flat_args, desc="Getting args", total=len(islands) + 1) - flat_args = flat_args.result() - cuts = [] - for arg in flat_args: - cut = cutout( - image=arg["image"], - src_name=arg["id"], - beam=arg["beam"], - ra_hi=arg["ra_hi"], - ra_lo=arg["ra_lo"], - dec_hi=arg["dec_hi"], - dec_lo=arg["dec_lo"], - outdir=arg["outdir"], - stoke=arg["stoke"], - field=field, - pad=pad, - verbose=verbose_worker, - dryrun=dryrun, - ) - cuts.append(cut) + args = get_args.map( + island=islands, + comps=comps, + beam=beams, + island_id=island_ids, + outdir=unmapped(outdir), + field=unmapped(field), + datadir=unmapped(directory), + stokeslist=unmapped(stokeslist), + verbose=unmapped(verbose_worker), + ) - futures = chunk_dask( - outputs=cuts, - 
task_name="cutouts", - progress_text="Cutting out", + flat_args = unpack.map(args) + + cuts = cutout.map( + cutout_args=flat_args, + field=unmapped(field), + pad=unmapped(pad), + dryrun=unmapped(dryrun), ) + if not dryrun: - _updates = [f.compute() for f in futures] + _updates = [f.result() for f in cuts] updates = [val for sublist in _updates for val in sublist] logger.info("Updating database...") db_res = beams_col.bulk_write(updates, ordered=False) @@ -436,7 +434,7 @@ def cutout_islands( logger.info("Cutouts Done!") -def main(args: argparse.Namespace, verbose=True) -> None: +def main(args: argparse.Namespace) -> None: """Main script Args: From 7066da80116e182744ef502f6e1a9864ae468f8c Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Mon, 11 Dec 2023 14:25:14 +1100 Subject: [PATCH 03/46] Fixes --- arrakis/cutout.py | 2 +- arrakis/process_spice.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/arrakis/cutout.py b/arrakis/cutout.py index 365584c4..fc9ebde5 100644 --- a/arrakis/cutout.py +++ b/arrakis/cutout.py @@ -288,7 +288,7 @@ def get_args( args.extend( CutoutArgs( image=image, - id=island["Source_ID"], + source_id=island["Source_ID"], ra_high=ra_high.deg, ra_low=ra_low.deg, dec_high=dec_high.deg, diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py index 81db0567..b920961e 100644 --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -30,7 +30,7 @@ from arrakis.utils.pipeline import logo_str, performance_report_prefect # Defining tasks -cut_task = task(cutout.cutout_islands, name="Cutout") +# cut_task = task(cutout.cutout_islands, name="Cutout") linmos_task = task(linmos.main, name="LINMOS") frion_task = task(frion.main, name="FRion") cleanup_task = task(cleanup.main, name="Clean up") @@ -53,9 +53,9 @@ def process_spice(args, host: str) -> None: with get_dask_client(): previous_future = None previous_future = ( - cut_task.submit( + cutout.cutout_islands( field=args.field, - directory=args.outdir, + directory=str(args.outdir), host=host, epoch=args.epoch, username=args.username, From 80a303e521daf4a64ab15c66d37f39e0eed13466 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Mon, 11 Dec 2023 14:25:48 +1100 Subject: [PATCH 04/46] Cleanup --- arrakis/process_spice.py | 1 - 1 file changed, 1 deletion(-) diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py index b920961e..5e625098 100644 --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -30,7 +30,6 @@ from arrakis.utils.pipeline import logo_str, performance_report_prefect # Defining tasks -# cut_task = task(cutout.cutout_islands, name="Cutout") linmos_task = task(linmos.main, name="LINMOS") frion_task = task(frion.main, name="FRion") cleanup_task = task(cleanup.main, name="Clean up") From 4d218769e85c44427fe1e46170e927d8e56c51b6 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Mon, 11 Dec 2023 14:49:00 +1100 Subject: [PATCH 05/46] Fix listy issues --- arrakis/cutout.py | 39 +++++++++++++++++++++++---------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/arrakis/cutout.py b/arrakis/cutout.py index fc9ebde5..2ddd49fd 100644 --- a/arrakis/cutout.py +++ b/arrakis/cutout.py @@ -272,7 +272,7 @@ def get_args( logger.debug(f"comps are {comps=}") raise e - args = [] + args: List[CutoutArgs] = [] for beam_num in beam_list: for stoke in stokeslist: wild = f"{datadir}/image.restored.{stoke.lower()}*contcube*beam{beam_num:02}.conv.fits" @@ -285,18 +285,20 @@ def get_args( ) for image in images: 
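# Worth spelling out the "listy issue" that PATCHes 05-07 circle around:
# CutoutArgs is a NamedTuple and therefore iterable, so list.extend()
# splats its fields into the list instead of adding the object itself.
# A minimal sketch:
#
#     args: List[CutoutArgs] = []
#     args.extend(CutoutArgs(image="x.fits", ...))  # bug: appends each field
#     args.append(CutoutArgs(image="x.fits", ...))  # appends one CutoutArgs
#
# The hunk below retries a bare extend; PATCH 07 settles on append.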
- args.extend( - CutoutArgs( - image=image, - source_id=island["Source_ID"], - ra_high=ra_high.deg, - ra_low=ra_low.deg, - dec_high=dec_high.deg, - dec_low=dec_low.deg, - outdir=outdir, - beam=beam_num, - stoke=stoke.lower(), - ) + args.append( + [ + CutoutArgs( + image=image, + source_id=island["Source_ID"], + ra_high=ra_high.deg, + ra_low=ra_low.deg, + dec_high=dec_high.deg, + dec_low=dec_low.deg, + outdir=outdir, + beam=beam_num, + stoke=stoke.lower(), + ) + ] ) return args @@ -317,21 +319,26 @@ def find_comps(island_id: str, comp_col: pymongo.collection.Collection) -> List[ @task(name="Unpack list") -def unpack(list_sq: List[List[T]]) -> List[T]: +def unpack(list_sq: List[List[T] | None]) -> List[T]: """Unpack list of lists of things into a list of things + Skips None entries Args: - list_sq (List[List[T]]): List of lists of things + list_sq (List[List[T] | None]): List of lists of things or Nones Returns: List[T]: List of things """ - list_fl = [] + list_fl: List[T] = [] for i in list_sq: if i is None: continue + if isinstance(i, list): + list_fl.extend(i) + continue for j in i: list_fl.append(j) + return list_fl From 09ba4438308c2961590b9b0e27e37f84b313ffa4 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Mon, 11 Dec 2023 16:00:42 +1100 Subject: [PATCH 06/46] Try an extend --- arrakis/cutout.py | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/arrakis/cutout.py b/arrakis/cutout.py index 2ddd49fd..775564cc 100644 --- a/arrakis/cutout.py +++ b/arrakis/cutout.py @@ -285,20 +285,18 @@ def get_args( ) for image in images: - args.append( - [ - CutoutArgs( - image=image, - source_id=island["Source_ID"], - ra_high=ra_high.deg, - ra_low=ra_low.deg, - dec_high=dec_high.deg, - dec_low=dec_low.deg, - outdir=outdir, - beam=beam_num, - stoke=stoke.lower(), - ) - ] + args.extend( + CutoutArgs( + image=image, + source_id=island["Source_ID"], + ra_high=ra_high.deg, + ra_low=ra_low.deg, + dec_high=dec_high.deg, + dec_low=dec_low.deg, + outdir=outdir, + beam=beam_num, + stoke=stoke.lower(), + ) ) return args @@ -319,7 +317,7 @@ def find_comps(island_id: str, comp_col: pymongo.collection.Collection) -> List[ @task(name="Unpack list") -def unpack(list_sq: List[List[T] | None]) -> List[T]: +def unpack(list_sq: List[Union[List[T], None]]) -> List[T]: """Unpack list of lists of things into a list of things Skips None entries From d9da00e53e4cb7241ff50207a8d8a4ee82e01b19 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Mon, 11 Dec 2023 16:20:41 +1100 Subject: [PATCH 07/46] Use append --- arrakis/cutout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrakis/cutout.py b/arrakis/cutout.py index 775564cc..6eca6420 100644 --- a/arrakis/cutout.py +++ b/arrakis/cutout.py @@ -285,7 +285,7 @@ def get_args( ) for image in images: - args.extend( + args.append( CutoutArgs( image=image, source_id=island["Source_ID"], From adf26d09a9b5c0f70a1ab52033eb60a294ab6c81 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Mon, 11 Dec 2023 16:21:05 +1100 Subject: [PATCH 08/46] Use prefect --- arrakis/linmos.py | 215 ++++++++++++++++++++++++++++------------------ 1 file changed, 131 insertions(+), 84 deletions(-) diff --git a/arrakis/linmos.py b/arrakis/linmos.py index 1c0ca380..b5296dc5 100644 --- a/arrakis/linmos.py +++ b/arrakis/linmos.py @@ -7,12 +7,15 @@ from glob import glob from pathlib import Path from pprint import pformat -from typing import Dict, List, Optional, Tuple, Union +from typing 
import Dict, List +from typing import NamedTuple as Struct +from typing import Optional, Tuple, Union import pymongo from astropy.utils.exceptions import AstropyWarning from dask import delayed from dask.distributed import Client, LocalCluster +from prefect import flow, task, unmapped from racs_tools import beamcon_3D from spectral_cube.utils import SpectralCubeWarning from spython.main import Client as sclient @@ -29,14 +32,22 @@ logger.setLevel(logging.INFO) -@delayed(nout=2) +class ImagePaths(Struct): + """Class to hold image paths""" + + images: List[Path] + """List of image paths""" + weights: List[Path] + """List of weight paths""" + + +@task(name="Find images") def find_images( field: str, - src_name: str, beams: dict, stoke: str, datadir: Path, -) -> Tuple[List[Path], List[Path]]: +) -> ImagePaths: """Find the images and weights for a given field and stokes parameter Args: @@ -50,9 +61,11 @@ def find_images( Exception: If no files are found. Returns: - Tuple[List[Path], List[Path]]: List of images and weights. + ImagePaths: List of images and weights. """ logger.setLevel(logging.INFO) + + src_name = beams["Source_ID"] field_beams = beams["beams"][field] # First check that the images exist @@ -82,29 +95,29 @@ def find_images( im.parent.name == wt.parent.name ), "Image and weight are in different areas!" - return image_list, weight_list + return ImagePaths(image_list, weight_list) -@delayed +@task(name="Smooth images") def smooth_images( - image_dict: Dict[str, List[Path]], -) -> Dict[str, List[Path]]: + image_dict: Dict[str, ImagePaths], +) -> Dict[str, ImagePaths]: """Smooth cubelets to a common resolution Args: - image_list (List[Path]): List of cubelets to smooth. + image_list (ImagePaths): List of cubelets to smooth. Returns: - List[Path]: Smoothed cubelets. + ImagePaths: Smoothed cubelets. """ - smooth_dict: Dict[str, List[Path]] = {} + smooth_dict: Dict[str, ImagePaths] = {} for stoke, image_list in image_dict.items(): infiles: List[str] = [] - for im in image_list: + for im in image_list.images: if im.suffix == ".fits": infiles.append(im.resolve().as_posix()) datadict = beamcon_3D.main( - infile=[im.resolve().as_posix() for im in image_list], + infile=[im.resolve().as_posix() for im in image_list.images], uselogs=False, mode="total", conv_mode="robust", @@ -113,15 +126,14 @@ def smooth_images( smooth_files: List[Path] = [] for key, val in datadict.items(): smooth_files.append(Path(val["outfile"])) - smooth_dict[stoke] = smooth_files + smooth_dict[stoke] = ImagePaths(smooth_files, image_list.weights) return smooth_dict -@delayed +@task(name="Generate parset") def genparset( - image_list: List[Path], - weight_list: List[Path], + image_paths: ImagePaths, stoke: str, datadir: Path, holofile: Optional[Path] = None, @@ -129,12 +141,10 @@ def genparset( """Generate parset for LINMOS Args: - field (str): RACS field name. - src_name (str): RACE source name. - beams (dict): Mongo entry for RACS beams. + image_paths (ImagePaths): List of images and weights. stoke (str): Stokes parameter. - datadir (str): Data directory. - holofile (str): Full path to holography file. + datadir (Path): Data directory. + holofile (Path, optional): Path to the holography file to include in the bind list. Defaults to None. Raises: Exception: If no files are found. 
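# Note on ImagePaths above: it replaces the old @delayed(nout=2) two-tuple
# return of find_images. A Dask delayed could be unpacked into two outputs,
# but a Prefect task hands back a single future, so the image/weight pair
# is bundled into one NamedTuple. A minimal sketch of the downstream
# access, assuming a call from inside a flow (paths are illustrative):
#
#     paths = find_images(field, beams, stoke="I", datadir=Path("/data/cutouts"))
#     first_image = paths.images[0]
#     first_weight = paths.weights[0]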
@@ -144,17 +154,13 @@ def genparset( """ logger.setLevel(logging.INFO) - image_string = ( - f"[{','.join([im.resolve().with_suffix('').as_posix() for im in image_list])}]" - ) - weight_string = ( - f"[{','.join([im.resolve().with_suffix('').as_posix() for im in weight_list])}]" - ) + image_string = f"[{','.join([im.resolve().with_suffix('').as_posix() for im in image_paths.images])}]" + weight_string = f"[{','.join([im.resolve().with_suffix('').as_posix() for im in image_paths.weights])}]" - parset_dir = datadir.resolve() / image_list[0].parent.name + parset_dir = datadir.resolve() / image_paths.images[0].parent.name - first_image = image_list[0].resolve().with_suffix("").as_posix() - first_weight = weight_list[0].resolve().with_suffix("").as_posix() + first_image = image_paths.images[0].resolve().with_suffix("").as_posix() + first_weight = image_paths.weights[0].resolve().with_suffix("").as_posix() linmos_image_str = f"{first_image[:first_image.find('beam')]}linmos" linmos_weight_str = f"{first_weight[:first_weight.find('beam')]}linmos" @@ -187,9 +193,9 @@ def genparset( return parset_file -@delayed +@task(name="Run linmos") def linmos( - parset: str, fieldname: str, image: str, holofile: Path + parset: Optional[str], fieldname: str, image: str, holofile: Path ) -> pymongo.UpdateOne: """Run linmos @@ -209,9 +215,11 @@ def linmos( """ logger.setLevel(logging.INFO) + if parset is None: + return + workdir = os.path.dirname(parset) rootdir = os.path.split(workdir)[0] - junk = os.path.split(workdir)[-1] parset_name = os.path.basename(parset) source = os.path.basename(workdir) stoke = parset_name[parset_name.find(".in") - 1] @@ -267,6 +275,46 @@ def get_yanda(version="1.3.0") -> str: return image +@task(name="Component worker") +def component_worker( + beams: dict, + comp: List[dict], + stokeslist: List[str], + field: str, + cutdir: Path, + holofile: Optional[Path] = None, +) -> Union[List[str], None]: + src = beams["Source_ID"] + if len(comp) == 0: + logger.warn(f"Skipping island {src} -- no components found") + return + + image_dict: Dict[str, ImagePaths] = {} + for stoke in stokeslist: + image_paths = find_images( + field=field, + src_name=src, + beams=beams, + stoke=stoke.capitalize(), + datadir=cutdir, + ) + image_dict[stoke] = image_paths + + smooth_dict = smooth_images(image_dict) + parfiles: List[str] = [] + for stoke in stokeslist: + parfile = genparset( + image_paths=smooth_dict[stoke], + stoke=stoke.capitalize(), + datadir=cutdir, + holofile=holofile, + ) + parfiles.append(parfile) + + return parfiles + + +@flow(name="LINMOS") def main( field: str, datadir: Path, @@ -277,8 +325,7 @@ def main( password: Optional[str] = None, yanda: str = "1.3.0", yanda_img: Optional[Path] = None, - stokeslist: Union[List[str], None] = None, - verbose=True, + stokeslist: Optional[List[str]] = None, ) -> None: """Main script @@ -313,15 +360,14 @@ def main( logger.info(f"The query is {query=}") - island_ids = sorted(beams_col.distinct("Source_ID", query)) - big_beams = list( + island_ids: List[str] = sorted(beams_col.distinct("Source_ID", query)) + big_beams: List[dict] = list( beams_col.find({"Source_ID": {"$in": island_ids}}).sort("Source_ID") ) - # files = sorted([name for name in glob(f"{cutdir}/*") if os.path.isdir(os.path.join(cutdir, name))]) - big_comps = list( + big_comps: List[dict] = list( comp_col.find({"Source_ID": {"$in": island_ids}}).sort("Source_ID") ) - comps = [] + comps: List[List[dict]] = [] for island_id in island_ids: _comps = [] for c in big_comps: @@ -331,49 +377,51 @@ def main( 
assert len(big_beams) == len(comps) - parfiles = [] - for beams, comp in zip(big_beams, comps): - src = beams["Source_ID"] - if len(comp) == 0: - warnings.warn(f"Skipping island {src} -- no components found") - continue - else: - image_dict: Dict[str, List[Path]] = {} - weight_dict: Dict[str, List[Path]] = {} - for stoke in stokeslist: - image_list, weight_list = find_images( - field=field, - src_name=src, - beams=beams, - stoke=stoke.capitalize(), - datadir=cutdir, - ) - image_dict[stoke] = image_list - weight_dict[stoke] = weight_list - - smooth_dict = smooth_images(image_dict) - for stoke in stokeslist: - parfile = genparset( - image_list=smooth_dict[stoke], - weight_list=weight_dict[stoke], - stoke=stoke.capitalize(), - datadir=cutdir, - holofile=holofile, - ) - parfiles.append(parfile) - - results = [] - for parset in parfiles: - results.append(linmos(parset, field, str(image), holofile=holofile)) - - futures = chunk_dask( - outputs=results, - task_name="LINMOS", - progress_text="Runing LINMOS", - verbose=verbose, + # parfiles = [] + # for beams, comp in zip(big_beams, comps): + # src = beams["Source_ID"] + # if len(comp) == 0: + # logger.warn(f"Skipping island {src} -- no components found") + # continue + + # image_dict: Dict[str, ImagePaths] = {} + # for stoke in stokeslist: + # image_paths = find_images.submit( + # field=field, + # src_name=src, + # beams=beams, + # stoke=stoke.capitalize(), + # datadir=cutdir, + # ) + # image_dict[stoke] = image_paths + + # smooth_dict = smooth_images(image_dict) + # for stoke in stokeslist: + # parfile = genparset.submit( + # image_paths=smooth_dict[stoke], + # stoke=stoke.capitalize(), + # datadir=cutdir, + # holofile=holofile, + # ) + # parfiles.append(parfile) + parfiles = component_worker.map( + beams=big_beams, + comp=comps, + stokeslist=unmapped(stokeslist), + field=unmapped(field), + cutdir=unmapped(cutdir), + holofile=unmapped(holofile), + ) + + results = linmos.map( + parfiles, + unmapped(field), + unmapped(str(image)), + unmapped(holofile), ) - updates = [f.compute() for f in futures] + updates = [f.result() for f in results] + updates = [u for u in updates if u is not None] logger.info("Updating database...") db_res = beams_col.bulk_write(updates, ordered=False) logger.info(pformat(db_res.bulk_api_result)) @@ -480,7 +528,6 @@ def cli(): yanda=args.yanda, yanda_img=args.yanda_image, stokeslist=args.stokeslist, - verbose=verbose, ) From fda6d5c1968b31f713d9d5bbf515cb062571e982 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Mon, 11 Dec 2023 16:21:16 +1100 Subject: [PATCH 09/46] Cleanup --- arrakis/utils/database.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/arrakis/utils/database.py b/arrakis/utils/database.py index 2d2d44df..8c9b4a40 100644 --- a/arrakis/utils/database.py +++ b/arrakis/utils/database.py @@ -7,6 +7,7 @@ import pymongo from astropy.utils.exceptions import AstropyWarning from pymongo.collection import Collection +from pymongo.database import Database from spectral_cube.utils import SpectralCubeWarning from arrakis.logger import logger @@ -57,7 +58,7 @@ def get_db( epoch: int, username: Union[str, None] = None, password: Union[str, None] = None, -) -> Tuple[Collection, Collection, Collection,]: +) -> Tuple[Collection, Collection, Collection]: """Get MongoDBs Args: From 13b808961f830e6cc232fd3e477a1b2768a6a91f Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Mon, 11 Dec 2023 16:24:32 +1100 Subject: [PATCH 10/46] Cleanup --- arrakis/linmos.py | 33 
+++------------------------------ 1 file changed, 3 insertions(+), 30 deletions(-) diff --git a/arrakis/linmos.py b/arrakis/linmos.py index b5296dc5..8b906aa9 100644 --- a/arrakis/linmos.py +++ b/arrakis/linmos.py @@ -196,14 +196,14 @@ def genparset( @task(name="Run linmos") def linmos( parset: Optional[str], fieldname: str, image: str, holofile: Path -) -> pymongo.UpdateOne: +) -> Optional[pymongo.UpdateOne]: """Run linmos Args: parset (str): Path to parset file. fieldname (str): Name of RACS field. image (str): Name of Yandasoft image. - holofile (Union[Path,str]): Path to the holography file to include in the bind list. + holofile (Path): Path to the holography file to include in the bind list. verbose (bool, optional): Verbose output. Defaults to False. Raises: @@ -283,7 +283,7 @@ def component_worker( field: str, cutdir: Path, holofile: Optional[Path] = None, -) -> Union[List[str], None]: +) -> Optional[List[str]]: src = beams["Source_ID"] if len(comp) == 0: logger.warn(f"Skipping island {src} -- no components found") @@ -377,33 +377,6 @@ def main( assert len(big_beams) == len(comps) - # parfiles = [] - # for beams, comp in zip(big_beams, comps): - # src = beams["Source_ID"] - # if len(comp) == 0: - # logger.warn(f"Skipping island {src} -- no components found") - # continue - - # image_dict: Dict[str, ImagePaths] = {} - # for stoke in stokeslist: - # image_paths = find_images.submit( - # field=field, - # src_name=src, - # beams=beams, - # stoke=stoke.capitalize(), - # datadir=cutdir, - # ) - # image_dict[stoke] = image_paths - - # smooth_dict = smooth_images(image_dict) - # for stoke in stokeslist: - # parfile = genparset.submit( - # image_paths=smooth_dict[stoke], - # stoke=stoke.capitalize(), - # datadir=cutdir, - # holofile=holofile, - # ) - # parfiles.append(parfile) parfiles = component_worker.map( beams=big_beams, comp=comps, From b3632d05b9ecee30f1559168f78846a4596a3c66 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Mon, 11 Dec 2023 16:29:56 +1100 Subject: [PATCH 11/46] Cleanup --- arrakis/linmos.py | 1 - 1 file changed, 1 deletion(-) diff --git a/arrakis/linmos.py b/arrakis/linmos.py index 8b906aa9..d4200097 100644 --- a/arrakis/linmos.py +++ b/arrakis/linmos.py @@ -238,7 +238,6 @@ def linmos( outstr = "\n".join(output["message"]) with open(log_file, "w") as f: f.write(outstr) - # f.write(output['message']) if output["return_code"] != 0: raise Exception(f"LINMOS failed! 
Check '{log_file}'") From 1b7bd4f8822f74cc51a8e87ae398d737381874ba Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Mon, 11 Dec 2023 17:44:27 +1100 Subject: [PATCH 12/46] Precommit --- arrakis/cutout.py | 44 ++++++++++++++++++++++++++++------------ arrakis/process_spice.py | 1 + 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/arrakis/cutout.py b/arrakis/cutout.py index 6eca6420..3ba47ba2 100644 --- a/arrakis/cutout.py +++ b/arrakis/cutout.py @@ -98,8 +98,6 @@ def cutout( pymongo.UpdateOne: Update query for MongoDB """ logger.setLevel(logging.INFO) - # logger = logging.getLogger('distributed.worker') - # logger = get_run_logger() logger.info(f"Timwashere - {cutout_args.image=}") @@ -122,10 +120,10 @@ def cutout( logger.info(f"Written to {outfile}") if imtype == "image": - logger.info(f"Reading {image}") + logger.info(f"Reading {cutout_args.image}") with warnings.catch_warnings(): warnings.simplefilter("ignore", AstropyWarning) - cube = SpectralCube.read(image) + cube = SpectralCube.read(cutout_args.image) padder = cube.header["BMAJ"] * u.deg * pad xlo = Longitude(cutout_args.ra_low * u.deg) - Longitude(padder) @@ -147,7 +145,9 @@ def cutout( new_header = cutout_cube.header with warnings.catch_warnings(): warnings.simplefilter("ignore", AstropyWarning) - with fits.open(image, memmap=True, mode="denywrite") as hdulist: + with fits.open( + cutout_args.image, memmap=True, mode="denywrite" + ) as hdulist: data = hdulist[0].data old_header = hdulist[0].header @@ -298,6 +298,7 @@ def get_args( stoke=stoke.lower(), ) ) + logger.info(f"{args=}") return args @@ -317,7 +318,7 @@ def find_comps(island_id: str, comp_col: pymongo.collection.Collection) -> List[ @task(name="Unpack list") -def unpack(list_sq: List[Union[List[T], None]]) -> List[T]: +def unpack(list_sq: List[Union[List[T], None, T]]) -> List[T]: """Unpack list of lists of things into a list of things Skips None entries @@ -327,16 +328,18 @@ def unpack(list_sq: List[Union[List[T], None]]) -> List[T]: Returns: List[T]: List of things """ + logger.setLevel(logging.DEBUG) + logger.debug(f"{list_sq=}") list_fl: List[T] = [] for i in list_sq: if i is None: continue - if isinstance(i, list): + elif isinstance(i, list): list_fl.extend(i) continue - for j in i: - list_fl.append(j) - + else: + list_fl.append(i) + logger.debug(f"{list_fl=}") return list_fl @@ -352,6 +355,7 @@ def cutout_islands( stokeslist: Optional[List[str]] = None, verbose_worker: bool = False, dryrun: bool = True, + limit: Optional[int] = None, ) -> None: """Perform cutouts of RACS islands in parallel. 
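The .map/unmapped idiom the following hunks lean on is Prefect's fan-out primitive: task.map launches one task run per element of each mapped iterable, while unmapped(x) broadcasts x unchanged to every run. A self-contained sketch of the semantics, separate from the pipeline code:

    from prefect import flow, task, unmapped

    @task
    def scale(x: int, factor: int) -> int:
        return x * factor

    @flow
    def demo() -> list:
        # x is mapped element-wise; factor is broadcast to every task run
        futures = scale.map([1, 2, 3], factor=unmapped(10))
        return [f.result() for f in futures]  # [10, 20, 30]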
@@ -408,6 +412,13 @@ def cutout_islands( # Create output dir if it doesn't exist try_mkdir(outdir) + if limit is not None: + logger.critical(f"Limiting to {limit} islands") + islands = islands[:limit] + island_ids = island_ids[:limit] + comps = comps[:limit] + beams = beams[:limit] + args = get_args.map( island=islands, comps=comps, @@ -419,9 +430,9 @@ def cutout_islands( stokeslist=unmapped(stokeslist), verbose=unmapped(verbose_worker), ) - - flat_args = unpack.map(args) - + # args = [a.result() for a in args] + # flat_args = unpack.map(args) + flat_args = unpack(args) cuts = cutout.map( cutout_args=flat_args, field=unmapped(field), @@ -457,6 +468,7 @@ def main(args: argparse.Namespace) -> None: stokeslist=args.stokeslist, verbose_worker=args.verbose_worker, dryrun=args.dryrun, + limit=args.limit, ) logger.info("Done!") @@ -546,6 +558,12 @@ def cutout_parser(parent_parser: bool = False) -> argparse.ArgumentParser: type=str, help="List of Stokes parameters to image [ALL]", ) + parser.add_argument( + "--limit", + type=Optional[int], + default=None, + help="Limit number of islands to process [None]", + ) return cut_parser diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py index 5e625098..a2d92fac 100644 --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -63,6 +63,7 @@ def process_spice(args, host: str) -> None: stokeslist=["I", "Q", "U"], verbose_worker=args.verbose_worker, dryrun=args.dryrun, + limit=args.limit, ) if not args.skip_cutout else previous_future From f4fe581945f2d5d496198fa70be6432c993205d5 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Mon, 11 Dec 2023 18:16:50 +1100 Subject: [PATCH 13/46] Do linmos --- arrakis/linmos.py | 86 +++++++++++++++++----------------------- arrakis/process_spice.py | 6 +-- 2 files changed, 39 insertions(+), 53 deletions(-) diff --git a/arrakis/linmos.py b/arrakis/linmos.py index d4200097..b2a9de48 100644 --- a/arrakis/linmos.py +++ b/arrakis/linmos.py @@ -52,7 +52,6 @@ def find_images( Args: field (str): Field name. - src_name (str): Source name. beams (dict): Beam information. stoke (str): Stokes parameter. datadir (Path): Data directory. 
@@ -274,45 +273,6 @@ def get_yanda(version="1.3.0") -> str: return image -@task(name="Component worker") -def component_worker( - beams: dict, - comp: List[dict], - stokeslist: List[str], - field: str, - cutdir: Path, - holofile: Optional[Path] = None, -) -> Optional[List[str]]: - src = beams["Source_ID"] - if len(comp) == 0: - logger.warn(f"Skipping island {src} -- no components found") - return - - image_dict: Dict[str, ImagePaths] = {} - for stoke in stokeslist: - image_paths = find_images( - field=field, - src_name=src, - beams=beams, - stoke=stoke.capitalize(), - datadir=cutdir, - ) - image_dict[stoke] = image_paths - - smooth_dict = smooth_images(image_dict) - parfiles: List[str] = [] - for stoke in stokeslist: - parfile = genparset( - image_paths=smooth_dict[stoke], - stoke=stoke.capitalize(), - datadir=cutdir, - holofile=holofile, - ) - parfiles.append(parfile) - - return parfiles - - @flow(name="LINMOS") def main( field: str, @@ -325,6 +285,7 @@ def main( yanda: str = "1.3.0", yanda_img: Optional[Path] = None, stokeslist: Optional[List[str]] = None, + limit: Optional[int] = None, ) -> None: """Main script @@ -376,17 +337,37 @@ def main( assert len(big_beams) == len(comps) - parfiles = component_worker.map( - beams=big_beams, - comp=comps, - stokeslist=unmapped(stokeslist), - field=unmapped(field), - cutdir=unmapped(cutdir), - holofile=unmapped(holofile), - ) + if limit is not None: + logger.critical(f"Limiting to {limit} islands") + big_beams = big_beams[:limit] + comps = comps[:limit] + + # parfiles = component_worker.map( + # beams=big_beams, + # comp=comps, + # stokeslist=unmapped(stokeslist), + # field=unmapped(field), + # cutdir=unmapped(cutdir), + # holofile=unmapped(holofile), + # ) + all_parfiles = [] + for stoke in stokeslist: + image_paths = find_images.map( + field=unmapped(field), + beams=big_beams, + stoke=unmapped(stoke.capitalize()), + datadir=unmapped(cutdir), + ) + parfiles = genparset.map( + image_paths=image_paths, + stoke=unmapped(stoke.capitalize()), + datadir=unmapped(cutdir), + holofile=unmapped(holofile), + ) + all_parfiles.extend(parfiles) results = linmos.map( - parfiles, + all_parfiles, unmapped(field), unmapped(str(image)), unmapped(holofile), @@ -478,6 +459,12 @@ def cli(): parser.add_argument( "--password", type=str, default=None, help="Password of mongodb." 
) + parser.add_argument( + "--limit", + type=Optional[int], + default=None, + help="Limit the number of islands to process.", + ) args = parser.parse_args() @@ -500,6 +487,7 @@ def cli(): yanda=args.yanda, yanda_img=args.yanda_image, stokeslist=args.stokeslist, + limit=args.limit, ) diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py index a2d92fac..d51c689f 100644 --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -30,7 +30,6 @@ from arrakis.utils.pipeline import logo_str, performance_report_prefect # Defining tasks -linmos_task = task(linmos.main, name="LINMOS") frion_task = task(frion.main, name="FRion") cleanup_task = task(cleanup.main, name="Clean up") rmsynth_task = task(rmsynth_oncuts.main, name="RM Synthesis") @@ -70,7 +69,7 @@ def process_spice(args, host: str) -> None: ) previous_future = ( - linmos_task.submit( + linmos.main( field=args.field, datadir=Path(args.outdir), host=host, @@ -81,8 +80,7 @@ def process_spice(args, host: str) -> None: yanda=args.yanda, yanda_img=args.yanda_image, stokeslist=["I", "Q", "U"], - verbose=True, - wait_for=[previous_future], + limit=args.limit, ) if not args.skip_linmos else previous_future From 0b129310e8337d477df820c0c5735f5c4f91e3fc Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Mon, 11 Dec 2023 18:17:14 +1100 Subject: [PATCH 14/46] Cleanup --- arrakis/linmos.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/arrakis/linmos.py b/arrakis/linmos.py index b2a9de48..71b143b0 100644 --- a/arrakis/linmos.py +++ b/arrakis/linmos.py @@ -342,14 +342,6 @@ def main( big_beams = big_beams[:limit] comps = comps[:limit] - # parfiles = component_worker.map( - # beams=big_beams, - # comp=comps, - # stokeslist=unmapped(stokeslist), - # field=unmapped(field), - # cutdir=unmapped(cutdir), - # holofile=unmapped(holofile), - # ) all_parfiles = [] for stoke in stokeslist: image_paths = find_images.map( From b886e559435e961529d7be1067db86fd7fded860 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Mon, 11 Dec 2023 18:22:18 +1100 Subject: [PATCH 15/46] Fix cleanup --- arrakis/cleanup.py | 37 +++++++++++++++---------------------- arrakis/process_spice.py | 4 +--- 2 files changed, 16 insertions(+), 25 deletions(-) diff --git a/arrakis/cleanup.py b/arrakis/cleanup.py index e15d3431..c48177b7 100644 --- a/arrakis/cleanup.py +++ b/arrakis/cleanup.py @@ -8,6 +8,7 @@ from dask import delayed from dask.distributed import Client, LocalCluster +from prefect import flow, task, unmapped from arrakis.logger import logger from arrakis.utils.pipeline import chunk_dask, logo_str @@ -15,26 +16,29 @@ logger.setLevel(logging.INFO) -@delayed -def cleanup(workdir: str, stoke: str) -> None: +@task(name="Cleanup directory") +def cleanup(workdir: str, stokeslist: List[str]) -> None: """Clean up beam images Args: workdir (str): Directory containing images stoke (str): Stokes parameter """ - # Clean up beam images - # old_files = glob(f"{workdir}/*.cutout.*.{stoke.lower()}.*beam[00-36]*.fits") - # for old in old_files: - # os.remove(old) + if os.path.basename(workdir) == "slurmFiles": + return + for stoke in stokeslist: + # Clean up beam images + # old_files = glob(f"{workdir}/*.cutout.*.{stoke.lower()}.*beam[00-36]*.fits") + # for old in old_files: + # os.remove(old) - pass + ... 
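# Note on the structure emerging here: the per-stage entry points
# (cutout.cutout_islands, linmos.main, and the cleanup main below) are now
# @flow functions that process_spice calls directly, i.e. Prefect subflows
# replacing the old task(...).submit() wrappers. A minimal sketch of that
# composition, independent of the pipeline:
#
#     @flow
#     def stage(field: str) -> None:
#         ...
#
#     @flow
#     def pipeline(field: str) -> None:
#         stage(field)  # runs as a blocking subflow of the parent run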
+@flow(name="Cleanup") def main( datadir: Path, stokeslist: Union[List[str], None] = None, - verbose=True, ) -> None: """Clean up beam images @@ -55,20 +59,9 @@ def main( if os.path.isdir(os.path.join(cutdir, name)) ] ) - - outputs = [] - for file in files: - if os.path.basename(file) == "slurmFiles": - continue - for stoke in stokeslist: - output = cleanup(file, stoke) - outputs.append(output) - - futures = chunk_dask( - outputs=outputs, - task_name="cleanup", - progress_text="Running cleanup", - verbose=verbose, + outputs = cleanup.map( + workdir=files, + stokeslist=unmapped(stokeslist), ) logger.info("Cleanup done!") diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py index d51c689f..b59f38d2 100644 --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -87,11 +87,9 @@ def process_spice(args, host: str) -> None: ) previous_future = ( - cleanup_task.submit( + cleanup.main( datadir=args.outdir, stokeslist=["I", "Q", "U"], - verbose=True, - wait_for=[previous_future], ) if not args.skip_cleanup else previous_future From 4f181c28f7e26ef0d376d17fff853c5f159f030a Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Mon, 11 Dec 2023 18:42:31 +1100 Subject: [PATCH 16/46] Fix FRION --- arrakis/frion.py | 104 +++++++++++++++++++++------------------ arrakis/process_spice.py | 1 - 2 files changed, 56 insertions(+), 49 deletions(-) diff --git a/arrakis/frion.py b/arrakis/frion.py index 5d185f25..e9ece24b 100644 --- a/arrakis/frion.py +++ b/arrakis/frion.py @@ -7,7 +7,9 @@ from pathlib import Path from pprint import pformat from shutil import copyfile -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Dict +from typing import NamedTuple as Struct +from typing import Optional, Tuple, Union import astropy.units as u import dask @@ -17,6 +19,7 @@ from dask import delayed from dask.distributed import Client, LocalCluster from FRion import correct, predict +from prefect import flow, task, unmapped from arrakis.logger import logger from arrakis.utils.database import get_db, get_field_db, test_db @@ -27,9 +30,16 @@ logger.setLevel(logging.INFO) -@delayed +class Prediction(Struct): + """FRion prediction""" + + predict_file: str + update: pymongo.UpdateOne + + +@task(name="FRion correction") def correct_worker( - beam: Dict, outdir: str, field: str, predict_file: str, island_id: str + beam: Dict, outdir: str, field: str, prediction: Prediction, island: dict ) -> pymongo.UpdateOne: """Apply FRion corrections to a single island @@ -43,6 +53,8 @@ def correct_worker( Returns: pymongo.UpdateOne: Pymongo update query """ + predict_file = prediction.predict_file + island_id = island["Source_ID"] qfile = os.path.join(outdir, beam["beams"][field]["q_file"]) ufile = os.path.join(outdir, beam["beams"][field]["u_file"]) @@ -67,7 +79,7 @@ def correct_worker( return pymongo.UpdateOne(myquery, newvalues) -@delayed(nout=2) +@task(name="FRion predction") def predict_worker( island: Dict, field: str, @@ -82,7 +94,7 @@ def predict_worker( formatter: Optional[Union[str, Callable]] = None, proxy_server: Optional[str] = None, pre_download: bool = False, -) -> Tuple[str, pymongo.UpdateOne]: +) -> Prediction: """Make FRion prediction for a single island Args: @@ -163,9 +175,18 @@ def predict_worker( update = pymongo.UpdateOne(myquery, newvalues) - return predict_file, update + return Prediction(predict_file, update) + + +@task(name="Index beams") +def index_beams(island: dict, beams: list[dict]) -> dict: + island_id = island["Source_ID"] + beam_idx = [i 
for i, b in enumerate(beams) if b["Source_ID"] == island_id][0] + beam = beams[beam_idx] + return beam +@flow(name="FRion") def main( field: str, outdir: Path, @@ -240,59 +261,46 @@ def main( os.path.join(cutdir, f"{beams[0]['beams'][f'{field}']['q_file']}"), ) # Type: u.Quantity - # Loop over islands in parallel - outputs = [] - updates_arrays = [] for island in islands: island_id = island["Source_ID"] beam_idx = [i for i, b in enumerate(beams) if b["Source_ID"] == island_id][0] beam = beams[beam_idx] - # Get FRion predictions - predict_file, update = predict_worker( - island=island, - field=field, - beam=beam, - start_time=start_time, - end_time=end_time, - freq=freq.to(u.Hz).value, - cutdir=cutdir, - plotdir=plotdir, - server=ionex_server, - prefix=ionex_prefix, - proxy_server=ionex_proxy_server, - formatter=ionex_formatter, - pre_download=ionex_predownload, - ) - updates_arrays.append(update) - # Apply FRion predictions - output = correct_worker( - beam=beam, - outdir=cutdir, - field=field, - predict_file=predict_file, - island_id=island_id, - ) - outputs.append(output) - # Wait for IONEX data I guess... - _ = outputs[0].compute() - time.sleep(10) - # Execute - futures, future_arrays = dask.persist(outputs, updates_arrays) - # dumb solution for https://github.com/dask/distributed/issues/4831 - time.sleep(10) - tqdm_dask( - futures, desc="Running FRion", disable=(not verbose), total=len(islands) * 3 + + beams_cor = index_beams.map(island=islands, beams=unmapped(beams)) + + predictions = predict_worker.map( + island=islands, + field=unmapped(field), + beam=beams_cor, + start_time=unmapped(start_time), + end_time=unmapped(end_time), + freq=unmapped(freq.to(u.Hz).value), + cutdir=unmapped(cutdir), + plotdir=unmapped(plotdir), + server=unmapped(ionex_server), + prefix=unmapped(ionex_prefix), + proxy_server=unmapped(ionex_proxy_server), + formatter=unmapped(ionex_formatter), + pre_download=unmapped(ionex_predownload), + ) + + corrections = correct_worker.map( + beam=beams_cor, + outdir=unmapped(cutdir), + field=unmapped(field), + predict_file=predictions, + island_id=islands, ) + + updates_arrays = [p.result().update for p in predictions] + updates = [c.result() for c in corrections] if database: logger.info("Updating beams database...") - updates = [f.compute() for f in futures] db_res = beams_col.bulk_write(updates, ordered=False) logger.info(pformat(db_res.bulk_api_result)) logger.info("Updating island database...") - updates_arrays_cmp = [f.compute() for f in future_arrays] - - db_res = island_col.bulk_write(updates_arrays_cmp, ordered=False) + db_res = island_col.bulk_write(updates_arrays, ordered=False) logger.info(pformat(db_res.bulk_api_result)) diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py index b59f38d2..b51eb4ce 100644 --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -31,7 +31,6 @@ # Defining tasks frion_task = task(frion.main, name="FRion") -cleanup_task = task(cleanup.main, name="Clean up") rmsynth_task = task(rmsynth_oncuts.main, name="RM Synthesis") rmclean_task = task(rmclean_oncuts.main, name="RM-CLEAN") cat_task = task(makecat.main, name="Catalogue") From 66b6ffea649e6d5e841de86333d81279a120147b Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Mon, 11 Dec 2023 18:43:20 +1100 Subject: [PATCH 17/46] Fix frion --- arrakis/frion.py | 2 +- arrakis/process_spice.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/arrakis/frion.py b/arrakis/frion.py index e9ece24b..5974bc3a 100644 --- 
a/arrakis/frion.py +++ b/arrakis/frion.py @@ -179,7 +179,7 @@ def predict_worker( @task(name="Index beams") -def index_beams(island: dict, beams: list[dict]) -> dict: +def index_beams(island: dict, beams: List[dict]) -> dict: island_id = island["Source_ID"] beam_idx = [i for i, b in enumerate(beams) if b["Source_ID"] == island_id][0] beam = beams[beam_idx] diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py index b51eb4ce..6584bed1 100644 --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -30,7 +30,6 @@ from arrakis.utils.pipeline import logo_str, performance_report_prefect # Defining tasks -frion_task = task(frion.main, name="FRion") rmsynth_task = task(rmsynth_oncuts.main, name="RM Synthesis") rmclean_task = task(rmclean_oncuts.main, name="RM-CLEAN") cat_task = task(makecat.main, name="Catalogue") @@ -95,7 +94,7 @@ def process_spice(args, host: str) -> None: ) previous_future = ( - frion_task.submit( + frion.main( field=args.field, outdir=args.outdir, host=host, @@ -108,7 +107,6 @@ def process_spice(args, host: str) -> None: ionex_proxy_server=args.ionex_proxy_server, ionex_formatter=args.ionex_formatter, ionex_predownload=args.ionex_predownload, - wait_for=[previous_future], ) if not args.skip_frion else previous_future From 4fe08f29cbba19c1f486edd2587853335ba2305a Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Mon, 11 Dec 2023 18:46:49 +1100 Subject: [PATCH 18/46] No verb --- arrakis/frion.py | 1 - arrakis/process_spice.py | 1 - 2 files changed, 2 deletions(-) diff --git a/arrakis/frion.py b/arrakis/frion.py index 5974bc3a..18f87bf7 100644 --- a/arrakis/frion.py +++ b/arrakis/frion.py @@ -195,7 +195,6 @@ def main( username: Optional[str] = None, password: Optional[str] = None, database=False, - verbose=True, ionex_server: str = "ftp://ftp.aiub.unibe.ch/CODE/", ionex_prefix: str = "codg", ionex_proxy_server: Optional[str] = None, diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py index 6584bed1..d609e109 100644 --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -102,7 +102,6 @@ def process_spice(args, host: str) -> None: username=args.username, password=args.password, database=args.database, - verbose=args.verbose, ionex_server=args.ionex_server, ionex_proxy_server=args.ionex_proxy_server, ionex_formatter=args.ionex_formatter, From ee50491790bf54d0a4dc525f70fb1712674d7ba3 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Mon, 11 Dec 2023 18:50:11 +1100 Subject: [PATCH 19/46] Ruff --- arrakis/cleanup.py | 3 +-- arrakis/cutout.py | 3 +-- arrakis/frion.py | 9 +++------ arrakis/linmos.py | 4 +--- arrakis/rmsynth_oncuts.py | 3 ++- arrakis/utils/database.py | 1 - 6 files changed, 8 insertions(+), 15 deletions(-) diff --git a/arrakis/cleanup.py b/arrakis/cleanup.py index c48177b7..d9c9d9b5 100644 --- a/arrakis/cleanup.py +++ b/arrakis/cleanup.py @@ -6,12 +6,11 @@ from pathlib import Path from typing import List, Union -from dask import delayed from dask.distributed import Client, LocalCluster from prefect import flow, task, unmapped from arrakis.logger import logger -from arrakis.utils.pipeline import chunk_dask, logo_str +from arrakis.utils.pipeline import logo_str logger.setLevel(logging.INFO) diff --git a/arrakis/cutout.py b/arrakis/cutout.py index 3ba47ba2..906f9906 100644 --- a/arrakis/cutout.py +++ b/arrakis/cutout.py @@ -20,7 +20,6 @@ from astropy.utils import iers from astropy.utils.exceptions import AstropyWarning from astropy.wcs.utils import skycoord_to_pixel -from dask import 
delayed from dask.distributed import Client, LocalCluster from distributed import get_client from prefect import flow, task, unmapped @@ -31,7 +30,7 @@ from arrakis.utils.database import get_db, test_db from arrakis.utils.fitsutils import fix_header from arrakis.utils.io import try_mkdir -from arrakis.utils.pipeline import chunk_dask, logo_str, tqdm_dask +from arrakis.utils.pipeline import logo_str iers.conf.auto_download = False warnings.filterwarnings( diff --git a/arrakis/frion.py b/arrakis/frion.py index 18f87bf7..8bebb9c7 100644 --- a/arrakis/frion.py +++ b/arrakis/frion.py @@ -2,21 +2,18 @@ """Correct for the ionosphere in parallel""" import logging import os -import time from glob import glob from pathlib import Path from pprint import pformat from shutil import copyfile -from typing import Callable, Dict +from typing import Callable, Dict, List from typing import NamedTuple as Struct -from typing import Optional, Tuple, Union +from typing import Optional, Union import astropy.units as u -import dask import numpy as np import pymongo from astropy.time import Time, TimeDelta -from dask import delayed from dask.distributed import Client, LocalCluster from FRion import correct, predict from prefect import flow, task, unmapped @@ -25,7 +22,7 @@ from arrakis.utils.database import get_db, get_field_db, test_db from arrakis.utils.fitsutils import getfreq from arrakis.utils.io import try_mkdir -from arrakis.utils.pipeline import logo_str, tqdm_dask +from arrakis.utils.pipeline import logo_str logger.setLevel(logging.INFO) diff --git a/arrakis/linmos.py b/arrakis/linmos.py index 71b143b0..17dfd8d0 100644 --- a/arrakis/linmos.py +++ b/arrakis/linmos.py @@ -9,11 +9,10 @@ from pprint import pformat from typing import Dict, List from typing import NamedTuple as Struct -from typing import Optional, Tuple, Union +from typing import Optional import pymongo from astropy.utils.exceptions import AstropyWarning -from dask import delayed from dask.distributed import Client, LocalCluster from prefect import flow, task, unmapped from racs_tools import beamcon_3D @@ -22,7 +21,6 @@ from arrakis.logger import logger from arrakis.utils.database import get_db, test_db -from arrakis.utils.pipeline import chunk_dask warnings.filterwarnings(action="ignore", category=SpectralCubeWarning, append=True) warnings.simplefilter("ignore", category=AstropyWarning) diff --git a/arrakis/rmsynth_oncuts.py b/arrakis/rmsynth_oncuts.py index 12de66f5..8682fddb 100644 --- a/arrakis/rmsynth_oncuts.py +++ b/arrakis/rmsynth_oncuts.py @@ -25,6 +25,7 @@ from astropy.wcs.utils import proj_plane_pixel_scales from dask import delayed from dask.distributed import Client, LocalCluster +from prefect import task from radio_beam import Beam from RMtools_1D import do_RMsynth_1D from RMtools_3D import do_RMsynth_3D @@ -85,7 +86,7 @@ class StokesIFitResult(Struct): """The dictionary of the fit results""" -@delayed +@task(name="3D RM-synthesis") def rmsynthoncut3d( island_id: str, beam: dict, diff --git a/arrakis/utils/database.py b/arrakis/utils/database.py index 8c9b4a40..af1ac864 100644 --- a/arrakis/utils/database.py +++ b/arrakis/utils/database.py @@ -7,7 +7,6 @@ import pymongo from astropy.utils.exceptions import AstropyWarning from pymongo.collection import Collection -from pymongo.database import Database from spectral_cube.utils import SpectralCubeWarning from arrakis.logger import logger From da1416d96c1812cf32bf393bf9526a772debfa63 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Mon, 11 Dec 2023 19:03:39 +1100 
Subject: [PATCH 20/46] Make flow happy --- arrakis/process_spice.py | 6 +-- arrakis/rmsynth_oncuts.py | 84 +++++++++++++++------------------------ 2 files changed, 34 insertions(+), 56 deletions(-) diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py index d609e109..ff2b507a 100644 --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -30,7 +30,6 @@ from arrakis.utils.pipeline import logo_str, performance_report_prefect # Defining tasks -rmsynth_task = task(rmsynth_oncuts.main, name="RM Synthesis") rmclean_task = task(rmclean_oncuts.main, name="RM-CLEAN") cat_task = task(makecat.main, name="Catalogue") @@ -112,7 +111,7 @@ def process_spice(args, host: str) -> None: ) previous_future = ( - rmsynth_task.submit( + rmsynth_oncuts.main( field=args.field, outdir=args.outdir, host=host, @@ -122,7 +121,7 @@ def process_spice(args, host: str) -> None: dimension=args.dimension, verbose=args.verbose, database=args.database, - validate=args.validate, + do_validate=args.validate, limit=args.limit, savePlots=args.savePlots, weightType=args.weightType, @@ -141,7 +140,6 @@ def process_spice(args, host: str) -> None: tt1=args.tt1, ion=True, do_own_fit=args.do_own_fit, - wait_for=[previous_future], ) if not args.skip_rmsynth else previous_future diff --git a/arrakis/rmsynth_oncuts.py b/arrakis/rmsynth_oncuts.py index 8682fddb..ce9ecb80 100644 --- a/arrakis/rmsynth_oncuts.py +++ b/arrakis/rmsynth_oncuts.py @@ -25,7 +25,7 @@ from astropy.wcs.utils import proj_plane_pixel_scales from dask import delayed from dask.distributed import Client, LocalCluster -from prefect import task +from prefect import flow, task, unmapped from radio_beam import Beam from RMtools_1D import do_RMsynth_1D from RMtools_3D import do_RMsynth_3D @@ -101,7 +101,7 @@ def rmsynthoncut3d( not_RMSF: bool = False, rm_verbose: bool = False, ion: bool = False, -): +) -> pymongo.UpdateOne: """3D RM-synthesis Args: @@ -457,7 +457,7 @@ def update_rmtools_dict( mDict[f"fit_flag_{key}"] = val -@delayed +@task(name="1D RM-synthesis") def rmsynthoncut1d( comp: dict, beam: dict, @@ -694,7 +694,6 @@ def rmsynthoncut1d( return pymongo.UpdateOne(myquery, newvalues) -@delayed def rmsynthoncut_i( comp_id: str, outdir: str, @@ -831,6 +830,7 @@ def rmsynthoncut_i( do_RMsynth_1D.saveOutput(mDict, aDict, prefix, verbose=verbose) +@flow(name="RMsynth on cutouts") def main( field: str, outdir: Path, @@ -841,7 +841,7 @@ def main( dimension: str = "1d", verbose: bool = True, database: bool = False, - validate: bool = False, + do_validate: bool = False, limit: Union[int, None] = None, savePlots: bool = False, weightType: str = "variance", @@ -915,6 +915,7 @@ def main( n_island = limit island_ids = island_ids[:limit] component_ids = component_ids[:limit] + components = components.iloc[:limit] # Make frequency file freq, freqfile = getfreq( @@ -926,7 +927,7 @@ def main( outputs = [] - if validate: + if do_validate: logger.info(f"Running RMsynth on {n_comp} components") # We don't run this in parallel! 
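# The 1D branch below fans out with rmsynthoncut1d.map and then gathers
# the per-component futures for the MongoDB write-back. The gather idiom
# used throughout this migration -- resolve each future, then drop the
# Nones from skipped sources -- is, as a sketch (collection name
# hypothetical):
#
#     updates = [f.result() for f in outputs]
#     updates = [u for u in updates if u is not None]
#     comp_col.bulk_write(updates, ordered=False)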
for i, comp_id in enumerate(component_ids): @@ -944,44 +945,32 @@ def main( verbose=verbose, rm_verbose=rm_verbose, ) - output.compute() elif dimension == "1d": logger.info(f"Running RMsynth on {n_comp} components") - for i, (_, comp) in tqdm( - enumerate(components.iterrows()), - total=n_comp, - disable=(not verbose), - desc="Constructing 1D RMsynth jobs", - ): - if i > n_comp + 1: - break - else: - beam = dict(beams.loc[comp["Source_ID"]]) - output = rmsynthoncut1d( - comp=comp, - beam=beam, - outdir=outdir, - freq=freq, - field=field, - polyOrd=polyOrd, - phiMax_radm2=phiMax_radm2, - dPhi_radm2=dPhi_radm2, - nSamples=nSamples, - weightType=weightType, - fitRMSF=fitRMSF, - noStokesI=noStokesI, - showPlots=showPlots, - savePlots=savePlots, - debug=debug, - rm_verbose=rm_verbose, - fit_function=fit_function, - tt0=tt0, - tt1=tt1, - ion=ion, - do_own_fit=do_own_fit, - ) - outputs.append(output) + outputs = rmsynthoncut1d.map( + comp=components.iterrows(), + beam=beams.loc[components["Source_ID"]], + outdir=unmapped(outdir), + freq=unmapped(freq), + field=unmapped(field), + polyOrd=unmapped(polyOrd), + phiMax_radm2=unmapped(phiMax_radm2), + dPhi_radm2=unmapped(dPhi_radm2), + nSamples=unmapped(nSamples), + weightType=unmapped(weightType), + fitRMSF=unmapped(fitRMSF), + noStokesI=unmapped(noStokesI), + showPlots=unmapped(showPlots), + savePlots=unmapped(savePlots), + debug=unmapped(debug), + rm_verbose=unmapped(rm_verbose), + fit_function=unmapped(fit_function), + tt0=unmapped(tt0), + tt1=unmapped(tt1), + ion=unmapped(ion), + do_own_fit=unmapped(do_own_fit), + ) elif dimension == "3d": logger.info(f"Running RMsynth on {n_island} islands") @@ -1010,18 +999,9 @@ def main( else: raise ValueError("An incorrect RMSynth mode has been configured. ") - futures = chunk_dask( - outputs=outputs, - task_name="RMsynth", - progress_text="Running RMsynth", - verbose=verbose, - ) - if database: logger.info("Updating database...") - updates = [f.compute() for f in futures] - # Remove None values - updates = [u for u in updates if u is not None] + updates = [u.result() for u in outputs if u.result() is not None] logger.info("Sending updates to database...") if dimension == "1d": db_res = comp_col.bulk_write(updates, ordered=False) @@ -1254,7 +1234,7 @@ def cli(): dimension=args.dimension, verbose=verbose, database=args.database, - validate=args.validate, + do_validate=args.validate, limit=args.limit, savePlots=args.savePlots, weightType=args.weightType, From e4b54823aaa1fbce27231c5d7a2983cddddc471e Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Tue, 12 Dec 2023 14:25:32 +1100 Subject: [PATCH 21/46] aslkdfjds --- arrakis/frion.py | 30 ++++++++++++++++++++---------- arrakis/process_spice.py | 8 ++++---- arrakis/rmsynth_oncuts.py | 10 ++++++---- environment.yml | 6 +++--- pyproject.toml | 2 +- 5 files changed, 34 insertions(+), 22 deletions(-) diff --git a/arrakis/frion.py b/arrakis/frion.py index 8bebb9c7..109ea6e7 100644 --- a/arrakis/frion.py +++ b/arrakis/frion.py @@ -11,6 +11,7 @@ from typing import Optional, Union import astropy.units as u +import matplotlib.pyplot as plt import numpy as np import pymongo from astropy.time import Time, TimeDelta @@ -144,11 +145,15 @@ def predict_worker( predict_file = os.path.join(i_dir, f"{iname}_ion.txt") predict.write_modulation(freq_array=freq, theta=theta, filename=predict_file) - plot_file = os.path.join(i_dir, f"{iname}_ion.pdf") - predict.generate_plots( - times, RMs, theta, freq, position=[ra, dec], savename=plot_file - ) - plot_files = 
glob(os.path.join(i_dir, "*ion.pdf")) + plot_file = os.path.join(i_dir, f"{iname}_ion.png") + try: + predict.generate_plots( + times, RMs, theta, freq, position=[ra, dec], savename=plot_file + ) + except Exception as e: + logger.error(f"Failed to generate plot: {e}") + + plot_files = glob(os.path.join(i_dir, "*ion.png")) logger.info(f"Plotting files: {plot_files=}") for src in plot_files: base = os.path.basename(src) @@ -197,6 +202,7 @@ def main( ionex_proxy_server: Optional[str] = None, ionex_formatter: Optional[Union[str, Callable]] = "ftp.aiub.unibe.ch", ionex_predownload: bool = False, + limit: Optional[int] = None, ): """Main script @@ -255,14 +261,18 @@ def main( freq = getfreq( os.path.join(cutdir, f"{beams[0]['beams'][f'{field}']['q_file']}"), - ) # Type: u.Quantity + ) + if limit is not None: + logger.info(f"Limiting to {limit} islands") + islands = islands[:limit] + + beams_cor = [] for island in islands: island_id = island["Source_ID"] beam_idx = [i for i, b in enumerate(beams) if b["Source_ID"] == island_id][0] beam = beams[beam_idx] - - beams_cor = index_beams.map(island=islands, beams=unmapped(beams)) + beams_cor.append(beam) predictions = predict_worker.map( island=islands, @@ -284,8 +294,8 @@ def main( beam=beams_cor, outdir=unmapped(cutdir), field=unmapped(field), - predict_file=predictions, - island_id=islands, + prediction=predictions, + island=islands, ) updates_arrays = [p.result().update for p in predictions] diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py index ff2b507a..31fa871e 100644 --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -45,6 +45,8 @@ def process_spice(args, host: str) -> None: # TODO: Fix the type assigned to args. The `configargparse.Namespace` was causing issues # with the pydantic validation used by prefect / flow. + outfile = f"{args.field}.pipe.test.fits" if args.outfile is None else args.outfile + with get_dask_client(): previous_future = None previous_future = ( @@ -105,6 +107,7 @@ def process_spice(args, host: str) -> None: ionex_proxy_server=args.ionex_proxy_server, ionex_formatter=args.ionex_formatter, ionex_predownload=args.ionex_predownload, + limit=args.limit, ) if not args.skip_frion else previous_future @@ -178,7 +181,7 @@ def process_spice(args, host: str) -> None: username=args.username, password=args.password, verbose=args.verbose, - outfile=args.outfile, + outfile=outfile, wait_for=[previous_future], ) if not args.skip_cat @@ -300,9 +303,6 @@ def main(args: configargparse.Namespace) -> None: password=args.password, ) - if args.outfile is None: - outfile = f"{args.field}.pipe.test.fits" - if not args.skip_imager: # This is the client for the imager component of the arrakis # pipeline. diff --git a/arrakis/rmsynth_oncuts.py b/arrakis/rmsynth_oncuts.py index ce9ecb80..d64783a7 100644 --- a/arrakis/rmsynth_oncuts.py +++ b/arrakis/rmsynth_oncuts.py @@ -459,8 +459,8 @@ def update_rmtools_dict( @task(name="1D RM-synthesis") def rmsynthoncut1d( - comp: dict, - beam: dict, + comp_tuple: Tuple[str, pd.Series], + beams: pd.DataFrame, outdir: str, freq: np.ndarray, field: str, @@ -503,6 +503,8 @@ def rmsynthoncut1d( rm_verbose (bool, optional): Verbose RMsynth. Defaults to False. 
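+        comp_tuple (Tuple[str, pd.Series]): Index and component row, as yielded by components.iterrows()
+        beams (pd.DataFrame): Beam table indexed by Source_ID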
""" logger.setLevel(logging.INFO) + comp = comp_tuple[1] + beam = dict(beams.loc[comp["Source_ID"]]) iname = comp["Source_ID"] cname = comp["Gaussian_ID"] @@ -949,8 +951,8 @@ def main( elif dimension == "1d": logger.info(f"Running RMsynth on {n_comp} components") outputs = rmsynthoncut1d.map( - comp=components.iterrows(), - beam=beams.loc[components["Source_ID"]], + comp_tuple=components.iterrows(), + beams=unmapped(beams), outdir=unmapped(outdir), freq=unmapped(freq), field=unmapped(field), diff --git a/environment.yml b/environment.yml index e09e83a9..274a6f16 100644 --- a/environment.yml +++ b/environment.yml @@ -1,17 +1,17 @@ -name: arrakis +name: arrakis310 channels: - astropy - conda-forge - defaults - pkgw-forge dependencies: -- python=3.8 +- python=3.10 - pip - future - numpy - scipy - pandas -- matplotlib +- matplotlib>=3.8 - setuptools - ipython - astropy>=4.3 diff --git a/pyproject.toml b/pyproject.toml index e48aa3ff..6fe2199b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dask_mpi = "*" FRion = {git = "https://github.com/CIRADA-Tools/FRion.git" } h5py = "*" ipython = "*" -matplotlib = "*" +matplotlib = "^3.8" numba = "*" numba_progress = "*" #mpi4py = "*" From 26a5209bc1ea013cd278a2f4287b71ab9235da92 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Tue, 12 Dec 2023 14:29:07 +1100 Subject: [PATCH 22/46] Update 3D --- arrakis/rmsynth_oncuts.py | 45 ++++++++++++++++----------------------- 1 file changed, 18 insertions(+), 27 deletions(-) diff --git a/arrakis/rmsynth_oncuts.py b/arrakis/rmsynth_oncuts.py index d64783a7..e61ff458 100644 --- a/arrakis/rmsynth_oncuts.py +++ b/arrakis/rmsynth_oncuts.py @@ -89,7 +89,7 @@ class StokesIFitResult(Struct): @task(name="3D RM-synthesis") def rmsynthoncut3d( island_id: str, - beam: dict, + beams: pd.DataFrame, outdir: str, freq: np.ndarray, field: str, @@ -118,7 +118,7 @@ def rmsynthoncut3d( not_RMSF (bool, optional): Skip calculation of RMSF. Defaults to False. rm_verbose (bool, optional): Verbose RMsynth. Defaults to False. """ - + beam = dict(beams.loc[island_id]) iname = island_id ifile = os.path.join(outdir, beam["beams"][field]["i_file"]) @@ -927,13 +927,11 @@ def main( ) freq = np.array(freq) - outputs = [] - if do_validate: logger.info(f"Running RMsynth on {n_comp} components") # We don't run this in parallel! for i, comp_id in enumerate(component_ids): - output = rmsynthoncut_i( + _ = rmsynthoncut_i( comp_id=comp_id, outdir=outdir, freq=freq, @@ -976,28 +974,21 @@ def main( elif dimension == "3d": logger.info(f"Running RMsynth on {n_island} islands") - - for i, island_id in enumerate(island_ids): - if i > n_island + 1: - break - else: - beam = dict(beams.loc[island_id]) - output = rmsynthoncut3d( - island_id=island_id, - beam=beam, - outdir=outdir, - freq=freq, - field=field, - phiMax_radm2=phiMax_radm2, - dPhi_radm2=dPhi_radm2, - nSamples=nSamples, - weightType=weightType, - fitRMSF=fitRMSF, - not_RMSF=not_RMSF, - rm_verbose=rm_verbose, - ion=ion, - ) - outputs.append(output) + outputs = rmsynthoncut3d.map( + island_id=island_ids, + beams=unmapped(beams), + outdir=unmapped(outdir), + freq=unmapped(freq), + field=unmapped(field), + phiMax_radm2=unmapped(phiMax_radm2), + dPhi_radm2=unmapped(dPhi_radm2), + nSamples=unmapped(nSamples), + weightType=unmapped(weightType), + fitRMSF=unmapped(fitRMSF), + not_RMSF=unmapped(not_RMSF), + rm_verbose=unmapped(rm_verbose), + ion=unmapped(ion), + ) else: raise ValueError("An incorrect RMSynth mode has been configured. 
") From b80c363632081a72b84ee7300be611d0a9c6c26b Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Tue, 12 Dec 2023 14:38:39 +1100 Subject: [PATCH 23/46] Add RM-CLEAN --- arrakis/process_spice.py | 5 +-- arrakis/rmclean_oncuts.py | 86 ++++++++++++++------------------------- 2 files changed, 32 insertions(+), 59 deletions(-) diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py index 31fa871e..ea20512e 100644 --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -30,7 +30,6 @@ from arrakis.utils.pipeline import logo_str, performance_report_prefect # Defining tasks -rmclean_task = task(rmclean_oncuts.main, name="RM-CLEAN") cat_task = task(makecat.main, name="Catalogue") @@ -149,7 +148,7 @@ def process_spice(args, host: str) -> None: ) previous_future = ( - rmclean_task.submit( + rmclean_oncuts.main( field=args.field, outdir=args.outdir, host=host, @@ -157,9 +156,7 @@ def process_spice(args, host: str) -> None: username=args.username, password=args.password, dimension=args.dimension, - verbose=args.verbose, database=args.database, - validate=args.validate, limit=args.limit, cutoff=args.cutoff, maxIter=args.maxIter, diff --git a/arrakis/rmclean_oncuts.py b/arrakis/rmclean_oncuts.py index f8502177..664fd716 100644 --- a/arrakis/rmclean_oncuts.py +++ b/arrakis/rmclean_oncuts.py @@ -7,13 +7,13 @@ from pathlib import Path from pprint import pformat from shutil import copyfile -from typing import Union +from typing import Optional import matplotlib.pyplot as plt import numpy as np import pymongo -from dask import delayed from dask.distributed import Client, LocalCluster +from prefect import flow, task, unmapped from RMtools_1D import do_RMclean_1D from RMtools_3D import do_RMclean_3D from tqdm import tqdm @@ -23,7 +23,7 @@ from arrakis.utils.pipeline import chunk_dask, logo_str -@delayed +@task(name="1D RM-CLEAN") def rmclean1d( comp: dict, outdir: str, @@ -135,7 +135,7 @@ def rmclean1d( return pymongo.UpdateOne(myquery, newvalues) -@delayed +@task(name="3D RM-CLEAN") def rmclean3d( island: dict, outdir: str, @@ -203,19 +203,18 @@ def rmclean3d( return pymongo.UpdateOne(myquery, newvalues) +@flow(name="RM-CLEAN on cutouts") def main( field: str, outdir: Path, host: str, epoch: int, - username: Union[str, None] = None, - password: Union[str, None] = None, + username: Optional[str] = None, + password: Optional[str] = None, dimension="1d", - verbose=True, database=False, savePlots=True, - validate=False, - limit: Union[int, None] = None, + limit: Optional[int] = None, cutoff: float = -3, maxIter=10000, gain=0.1, @@ -296,52 +295,37 @@ def main( n_island = count # component_ids = component_ids[:count] - outputs = [] if dimension == "1d": logger.info(f"Running RM-CLEAN on {n_comp} components") - for i, comp in enumerate(tqdm(components, total=n_comp)): - if i > n_comp + 1: - break - else: - output = rmclean1d( - comp=comp, - outdir=outdir, - cutoff=cutoff, - maxIter=maxIter, - gain=gain, - showPlots=showPlots, - savePlots=savePlots, - rm_verbose=rm_verbose, - window=window, - ) - outputs.append(output) - + outputs = rmclean1d.map( + comp=components, + outdir=unmapped(outdir), + cutoff=unmapped(cutoff), + maxIter=unmapped(maxIter), + gain=unmapped(gain), + showPlots=unmapped(showPlots), + savePlots=unmapped(savePlots), + rm_verbose=unmapped(rm_verbose), + window=unmapped(window), + ) elif dimension == "3d": logger.info(f"Running RM-CLEAN on {n_island} islands") - for i, island in enumerate(islands): - if i > n_island + 1: - break - else: - output = rmclean3d( - 
island=island, - outdir=outdir, - cutoff=cutoff, - maxIter=maxIter, - gain=gain, - rm_verbose=rm_verbose, - ) - outputs.append(output) - futures = chunk_dask( - outputs=outputs, - task_name="RM-CLEAN", - progress_text="Running RM-CLEAN", - verbose=verbose, - ) + outputs = rmclean3d.map( + island=islands, + outdir=unmapped(outdir), + cutoff=unmapped(cutoff), + maxIter=unmapped(maxIter), + gain=unmapped(gain), + rm_verbose=unmapped(rm_verbose), + ) + + else: + raise ValueError(f"Dimension {dimension} not supported.") if database: logger.info("Updating database...") - updates = [f.compute() for f in futures] + updates = [f.result() for f in outputs] if dimension == "1d": db_res = comp_col.bulk_write(updates, ordered=False) logger.info(pformat(db_res.bulk_api_result)) @@ -425,12 +409,6 @@ def cli(): parser.add_argument( "-sp", "--savePlots", action="store_true", help="save the plots [False]." ) - parser.add_argument( - "--validate", - dest="validate", - action="store_true", - help="Run on Stokes I [False].", - ) parser.add_argument( "--limit", @@ -499,10 +477,8 @@ def cli(): username=args.username, password=args.password, dimension=args.dimension, - verbose=verbose, database=args.database, savePlots=args.savePlots, - validate=args.validate, limit=args.limit, cutoff=args.cutoff, maxIter=args.maxIter, From 4e373ffe1d84efb52763205ce6d6a0e9c8f02980 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Tue, 12 Dec 2023 14:51:37 +1100 Subject: [PATCH 24/46] Add cats --- arrakis/makecat.py | 54 +++++++++------------------------------- arrakis/process_spice.py | 5 +--- 2 files changed, 13 insertions(+), 46 deletions(-) diff --git a/arrakis/makecat.py b/arrakis/makecat.py index 9ecae793..4fbf03ff 100644 --- a/arrakis/makecat.py +++ b/arrakis/makecat.py @@ -17,6 +17,7 @@ from astropy.stats import sigma_clip from astropy.table import Column, Table from dask.diagnostics import ProgressBar +from prefect import flow, task, unmapped from rmtable import RMTable from scipy.stats import lognorm, norm from tqdm import tqdm @@ -218,6 +219,7 @@ def lognorm_from_percentiles(x1, p1, x2, p2): return scale, np.exp(mean) +@task(name="Fix sigma_add") def sigma_add_fix(tab): sigma_Q_low = np.array(tab["sigma_add_Q"] - tab["sigma_add_Q_err_minus"]) sigma_Q_high = np.array(tab["sigma_add_Q"] + tab["sigma_add_Q_err_plus"]) @@ -292,7 +294,7 @@ def get_fit_func( degree: int = 2, do_plot: bool = False, high_snr_cut: float = 30.0, -): +) -> Tuple[np.polynomial.Polynomial.fit, plt.Figure]: """Fit an envelope to define leakage sources Args: @@ -313,6 +315,10 @@ def get_fit_func( logger.info(f"{np.sum(hi_snr)} sources with Stokes I SNR above {high_snr_cut=}.") + if len(hi_i_tab) < 100: + logger.critcal("Not enough high SNR sources to fit leakage envelope.") + return np.polynomial.Polynomial.fit([0], [0], deg=1), plt.figure() + # Get fractional pol frac_P = np.array(hi_i_tab["fracpol"].value) # Bin sources by separation from tile centre @@ -340,7 +346,6 @@ def get_fit_func( # Plot the fit latexify(columns=2) - figure = plt.figure(facecolor="w") fig = plt.figure(facecolor="w") color = "tab:green" stoke = { @@ -360,17 +365,6 @@ def get_fit_func( zorder=0, rasterized=True, ) - - # is_finite = np.logical_and( - # np.isfinite(hi_i_tab["beamdist"].to(u.deg).value), np.isfinite(frac_P) - # ) - # hist2d( - # np.array(hi_i_tab["beamdist"].to(u.deg).value)[is_finite, np.newaxis], - # np.array(frac_P)[is_finite, np.newaxis], - # bins=(nbins, nbins), - # range=[[0, 5], [0, 0.05]], - # # color=color, - # ) plt.plot(bins_c, 
meds, alpha=1, c=color, label="Median", linewidth=2) for s, ls in zip((1, 2), ("--", ":")): for r in ("ups", "los"): @@ -507,6 +501,7 @@ def masker(x): return cat_out +@task(name="Add cuts and flags") def cuts_and_flags(cat: RMTable) -> RMTable: """Cut out bad sources, and add flag columns @@ -564,6 +559,7 @@ def cuts_and_flags(cat: RMTable) -> RMTable: return cat_out, fit +@task(name="Get spectral indices") def get_alpha(cat): coefs_str = cat["stokesI_model_coef"] coefs_err_str = cat["stokesI_model_coef_err"] @@ -592,6 +588,7 @@ def get_alpha(cat): ) +@task(name="Get integration times") def get_integration_time(cat, field_col): logger.warn("Will be stripping the trailing field character prefix. ") field_names = [ @@ -620,31 +617,6 @@ def get_integration_time(cat, field_col): return np.array(tints) * u.s -# Stolen from GASKAP pipeline -# Credit to J. Dempsey -# https://github.com/GASKAP/GASKAP-HI-Absorption-Pipeline/ -# https://github.com/GASKAP/GASKAP-HI-Absorption-Pipeline/blob/ -# def add_col_metadata(vo_table, col_name, description, units=None, ucd=None, datatype=None): -# """Add metadata to a VO table column. - -# Args: -# vo_table (vot.): VO Table -# col_name (str): Column name -# description (str): Long description of the column -# units (u.Unit, optional): Unit of column. Defaults to None. -# ucd (str, optional): UCD string. Defaults to None. -# datatype (_type_, optional): _description_. Defaults to None. -# """ -# col = vo_table.get_first_table().get_field_by_id(col_name) -# col.description = description -# if units: -# col.unit = units -# if ucd: -# col.ucd = ucd -# if datatype: -# col.datatype = datatype - - def add_metadata(vo_table: vot.tree.Table, filename: str): """Add metadata to VO Table for CASDA @@ -732,6 +704,7 @@ def fix_blank_units(rmtab: TableLike) -> TableLike: return rmtab +@task(name="Write votable") def write_votable(rmtab: TableLike, outfile: str) -> None: # Replace bad column names fix_columns = { @@ -752,6 +725,7 @@ def write_votable(rmtab: TableLike, outfile: str) -> None: replace_nans(outfile) +@flow(name="Make catalogue") def main( field: str, host: str, @@ -862,10 +836,6 @@ def main( alpha_dict = get_alpha(rmtab) rmtab.add_column(Column(data=alpha_dict["alphas"], name="spectral_index")) rmtab.add_column(Column(data=alpha_dict["alphas_err"], name="spectral_index_err")) - # rmtab.add_column(Column(data=alpha_dict["betas"], name="spectral_curvature")) - # rmtab.add_column( - # Column(data=alpha_dict["betas_err"], name="spectral_curvature_err") - # ) # Add integration time field_col = get_field_db( diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py index ea20512e..d6987c86 100644 --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -29,9 +29,6 @@ from arrakis.utils.database import test_db from arrakis.utils.pipeline import logo_str, performance_report_prefect -# Defining tasks -cat_task = task(makecat.main, name="Catalogue") - @flow(name="Combining+Synthesis on Arrakis") def process_spice(args, host: str) -> None: @@ -171,7 +168,7 @@ def process_spice(args, host: str) -> None: ) previous_future = ( - cat_task.submit( + makecat.main( field=args.field, host=host, epoch=args.epoch, From eed30c0b96775bf5d889fa279ade96d0e24aad26 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Tue, 12 Dec 2023 15:01:27 +1100 Subject: [PATCH 25/46] Fix fitting --- arrakis/makecat.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/arrakis/makecat.py b/arrakis/makecat.py index 4fbf03ff..2bf7c7e6 100644 --- 
a/arrakis/makecat.py +++ b/arrakis/makecat.py @@ -316,8 +316,13 @@ def get_fit_func( logger.info(f"{np.sum(hi_snr)} sources with Stokes I SNR above {high_snr_cut=}.") if len(hi_i_tab) < 100: - logger.critcal("Not enough high SNR sources to fit leakage envelope.") - return np.polynomial.Polynomial.fit([0], [0], deg=1), plt.figure() + logger.critical("Not enough high SNR sources to fit leakage envelope.") + return ( + np.polynomial.Polynomial.fit( + np.array([0, 1]), np.array([0, 0]), deg=0, full=False + ), + plt.figure(), + ) # Get fractional pol frac_P = np.array(hi_i_tab["fracpol"].value) From be1fe0514ffc507ea4d4babfc276f8519917e1fe Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Tue, 12 Dec 2023 15:02:57 +1100 Subject: [PATCH 26/46] Ruff --- arrakis/frion.py | 1 - arrakis/makecat.py | 2 +- arrakis/process_spice.py | 2 +- arrakis/rmclean_oncuts.py | 3 +-- arrakis/rmsynth_oncuts.py | 4 +--- 5 files changed, 4 insertions(+), 8 deletions(-) diff --git a/arrakis/frion.py b/arrakis/frion.py index 109ea6e7..0ee0f673 100644 --- a/arrakis/frion.py +++ b/arrakis/frion.py @@ -11,7 +11,6 @@ from typing import Optional, Union import astropy.units as u -import matplotlib.pyplot as plt import numpy as np import pymongo from astropy.time import Time, TimeDelta diff --git a/arrakis/makecat.py b/arrakis/makecat.py index 2bf7c7e6..38f2a191 100644 --- a/arrakis/makecat.py +++ b/arrakis/makecat.py @@ -17,7 +17,7 @@ from astropy.stats import sigma_clip from astropy.table import Column, Table from dask.diagnostics import ProgressBar -from prefect import flow, task, unmapped +from prefect import flow, task from rmtable import RMTable from scipy.stats import lognorm, norm from tqdm import tqdm diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py index d6987c86..cdc31945 100644 --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -12,7 +12,7 @@ from dask.distributed import Client from dask_jobqueue import SLURMCluster from dask_mpi import initialize -from prefect import flow, task +from prefect import flow from prefect_dask import DaskTaskRunner, get_dask_client from arrakis import ( diff --git a/arrakis/rmclean_oncuts.py b/arrakis/rmclean_oncuts.py index 664fd716..b697ad3d 100644 --- a/arrakis/rmclean_oncuts.py +++ b/arrakis/rmclean_oncuts.py @@ -16,11 +16,10 @@ from prefect import flow, task, unmapped from RMtools_1D import do_RMclean_1D from RMtools_3D import do_RMclean_3D -from tqdm import tqdm from arrakis.logger import logger from arrakis.utils.database import get_db, test_db -from arrakis.utils.pipeline import chunk_dask, logo_str +from arrakis.utils.pipeline import logo_str @task(name="1D RM-CLEAN") diff --git a/arrakis/rmsynth_oncuts.py b/arrakis/rmsynth_oncuts.py index e61ff458..5296b199 100644 --- a/arrakis/rmsynth_oncuts.py +++ b/arrakis/rmsynth_oncuts.py @@ -23,7 +23,6 @@ from astropy.stats import mad_std, sigma_clip from astropy.wcs import WCS from astropy.wcs.utils import proj_plane_pixel_scales -from dask import delayed from dask.distributed import Client, LocalCluster from prefect import flow, task, unmapped from radio_beam import Beam @@ -31,14 +30,13 @@ from RMtools_3D import do_RMsynth_3D from RMutils.util_misc import create_frac_spectra from scipy.stats import norm -from tqdm import tqdm from arrakis.logger import logger from arrakis.utils.database import get_db, test_db from arrakis.utils.fitsutils import getfreq from arrakis.utils.fitting import fit_pl from arrakis.utils.io import try_mkdir -from arrakis.utils.pipeline import chunk_dask, 
logo_str +from arrakis.utils.pipeline import logo_str logger.setLevel(logging.INFO) From 0f39fc02f65a6ba14d3b257f3a02846bab993566 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Tue, 12 Dec 2023 15:11:56 +1100 Subject: [PATCH 27/46] Fix voronoi fail --- arrakis/makecat.py | 51 +++++++++++++++++++++++++--------------------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/arrakis/makecat.py b/arrakis/makecat.py index 38f2a191..71480770 100644 --- a/arrakis/makecat.py +++ b/arrakis/makecat.py @@ -17,7 +17,7 @@ from astropy.stats import sigma_clip from astropy.table import Column, Table from dask.diagnostics import ProgressBar -from prefect import flow, task +from prefect import flow, task, unmapped from rmtable import RMTable from scipy.stats import lognorm, norm from tqdm import tqdm @@ -405,6 +405,15 @@ def compute_local_rm_flag(good_cat: Table, big_cat: Table) -> Table: logger.info("Computing voronoi bins and finding bad RMs") logger.info(f"Number of available sources: {len(good_cat)}.") + df = good_cat.to_pandas() + df.reset_index(inplace=True) + df.set_index("cat_id", inplace=True) + + df_out = big_cat.to_pandas() + df_out.reset_index(inplace=True) + df_out.set_index("cat_id", inplace=True) + df_out["local_rm_flag"] = False + def sn_func(index, signal=None, noise=None): try: sn = len(np.array(index)) @@ -414,6 +423,7 @@ def sn_func(index, signal=None, noise=None): target_sn = 30 target_bins = 6 + fail = True while target_sn > 1: logger.debug( f"Trying to find Voroni bins with RMs per bin={target_sn}, Number of bins={target_bins}" @@ -463,32 +473,27 @@ def sn_func(index, signal=None, noise=None): fail_msg = "Failed to converge towards a Voronoi binning solution. " logger.error(fail_msg) - raise ValueError(fail_msg) + fail = True - logger.info(f"Found {len(set(bin_number))} bins") - df = good_cat.to_pandas() - df.reset_index(inplace=True) - df.set_index("cat_id", inplace=True) - df["bin_number"] = bin_number - # Use sigma clipping to find outliers + if not fail: + logger.info(f"Found {len(set(bin_number))} bins") + df["bin_number"] = bin_number - def masker(x): - return pd.Series( - sigma_clip(x["rm"], sigma=3, maxiters=None, cenfunc=np.median).mask, - index=x.index, + # Use sigma clipping to find outliers + def masker(x): + return pd.Series( + sigma_clip(x["rm"], sigma=3, maxiters=None, cenfunc=np.median).mask, + index=x.index, + ) + + perc_g = df.groupby("bin_number").apply( + masker, ) + # Put flag into the catalogue + df["local_rm_flag"] = perc_g.reset_index().set_index("cat_id")[0] + df.drop(columns=["bin_number"], inplace=True) + df_out.update(df["local_rm_flag"]) - perc_g = df.groupby("bin_number").apply( - masker, - ) - # Put flag into the catalogue - df["local_rm_flag"] = perc_g.reset_index().set_index("cat_id")[0] - df.drop(columns=["bin_number"], inplace=True) - df_out = big_cat.to_pandas() - df_out.reset_index(inplace=True) - df_out.set_index("cat_id", inplace=True) - df_out["local_rm_flag"] = False - df_out.update(df["local_rm_flag"]) df_out["local_rm_flag"] = df_out["local_rm_flag"].astype(bool) cat_out = RMTable.from_pandas(df_out.reset_index()) cat_out["local_rm_flag"].meta["ucd"] = "meta.code" From 55080ab76fce5c6839f4139c210a8baa3f72aca4 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Tue, 12 Dec 2023 15:27:50 +1100 Subject: [PATCH 28/46] Update logs --- arrakis/cutout.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/arrakis/cutout.py b/arrakis/cutout.py index 906f9906..68e68acd 100644 --- 
a/arrakis/cutout.py +++ b/arrakis/cutout.py @@ -98,8 +98,6 @@ def cutout( """ logger.setLevel(logging.INFO) - logger.info(f"Timwashere - {cutout_args.image=}") - outdir = os.path.abspath(cutout_args.outdir) ret = [] @@ -297,7 +295,6 @@ def get_args( stoke=stoke.lower(), ) ) - logger.info(f"{args=}") return args From a1e6640a1624db2587fc51cf8ddf843df8580094 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Wed, 13 Dec 2023 00:25:11 +1100 Subject: [PATCH 29/46] Fix linmos --- arrakis/linmos.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/arrakis/linmos.py b/arrakis/linmos.py index 17dfd8d0..b45cbecb 100644 --- a/arrakis/linmos.py +++ b/arrakis/linmos.py @@ -15,6 +15,7 @@ from astropy.utils.exceptions import AstropyWarning from dask.distributed import Client, LocalCluster from prefect import flow, task, unmapped +from prefect.utilities.annotations import quote from racs_tools import beamcon_3D from spectral_cube.utils import SpectralCubeWarning from spython.main import Client as sclient @@ -357,7 +358,7 @@ def main( all_parfiles.extend(parfiles) results = linmos.map( - all_parfiles, + quote(all_parfiles), unmapped(field), unmapped(str(image)), unmapped(holofile), From 928cbd129f99357d3e58ae65ff1dcb187908efa6 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Wed, 13 Dec 2023 00:25:24 +1100 Subject: [PATCH 30/46] Fix cluster --- arrakis/process_spice.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py index cdc31945..91e7b541 100644 --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -209,6 +209,7 @@ def create_client( minimum: int = 1, maximum: int = 38, mode: str = "adapt", + overload: bool = False, ) -> Client: logger.info("Creating a Client") if dask_config is None: @@ -220,16 +221,12 @@ def create_client( logger.info(f"Loading {dask_config}") config = yaml.safe_load(f) - logger.info("Overwriting config attributes.") - config["job_cpu"] = config["cores"] - config["cores"] = 1 - config["processes"] = 1 + if overload: + logger.info("Overwriting config attributes.") + config["job_cpu"] = config["cores"] + config["cores"] = 1 + config["processes"] = 1 - # config.update( - # { - # "log_directory": f"{field}_{Time.now().fits}_spice_logs/" - # } - # ) if use_mpi: initialize( interface=config["interface"], @@ -306,6 +303,7 @@ def main(args: configargparse.Namespace) -> None: port_forward=args.port_forward, minimum=1, maximum=38, + overload=True, ) logger.info("Obtained DaskTaskRunner, executing the imager workflow. 
") @@ -359,8 +357,8 @@ def main(args: configargparse.Namespace) -> None: dask_config=args.dask_config, use_mpi=args.use_mpi, port_forward=args.port_forward, - minimum=64, - maximum=64, + minimum=1, + maximum=256, ) # Define flow From 96dfa1bb2b6e18e2b5c24d646bb4be5ac1ab157e Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Wed, 13 Dec 2023 13:30:30 +1100 Subject: [PATCH 31/46] Fix linmos --- arrakis/linmos.py | 2 +- arrakis/process_spice.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/arrakis/linmos.py b/arrakis/linmos.py index b45cbecb..165c6990 100644 --- a/arrakis/linmos.py +++ b/arrakis/linmos.py @@ -358,7 +358,7 @@ def main( all_parfiles.extend(parfiles) results = linmos.map( - quote(all_parfiles), + all_parfiles, unmapped(field), unmapped(str(image)), unmapped(holofile), diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py index 91e7b541..acf01ee0 100644 --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -245,9 +245,7 @@ def create_client( cluster.adapt(minimum=minimum, maximum=maximum) elif mode == "scale": cluster.scale(maximum) - # cluster.scale(36) - # cluster = LocalCluster(n_workers=10, processes=True, threads_per_worker=1, local_directory="/dev/shm",dashboard_address=f":{args.port}") client = Client(cluster) port = client.scheduler_info()["services"]["dashboard"] @@ -359,6 +357,7 @@ def main(args: configargparse.Namespace) -> None: port_forward=args.port_forward, minimum=1, maximum=256, + mode="scale", ) # Define flow From a847030254d1bcef7849a10f8568266392460df1 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Wed, 13 Dec 2023 13:50:09 +1100 Subject: [PATCH 32/46] Use config --- arrakis/configs/petrichor.yaml | 52 ++++++++++++++++------------------ arrakis/process_spice.py | 45 ++++++++++++++--------------- arrakis/utils/meta.py | 19 +++++++++++++ 3 files changed, 65 insertions(+), 51 deletions(-) diff --git a/arrakis/configs/petrichor.yaml b/arrakis/configs/petrichor.yaml index 6f4a51d1..9f79823a 100644 --- a/arrakis/configs/petrichor.yaml +++ b/arrakis/configs/petrichor.yaml @@ -1,28 +1,26 @@ # Set up for Petrichor -cores: 8 -processes: 8 -name: 'spice-worker' -memory: "64GiB" -account: 'OD-217087' -#queue: 'workq' -walltime: '0-8:00:00' -job_extra_directives: ['--qos express'] -# interface for the workers -interface: "ib0" -log_directory: 'spice_logs' -job_script_prologue: [ - 'module load singularity', -] -# job_script_prologue: [ -# 'export OMP_NUM_THREADS=1', -# 'source /home/$(whoami)/.bashrc', -# 'conda activate spice' -# ] -# python: 'srun -n 1 -c 64 python' -#worker_extra_args: [ -# "--lifetime", "23h", -# "--lifetime-stagger", "5m", -#] -death_timeout: 1000 -local_directory: $LOCALDIR -silence_logs: 'info' +cluster_class: "dask_jobqueue.SLURMCluster" +cluster_kwargs: + cores: 8 + processes: 8 + name: 'spice-worker' + memory: "64GiB" + account: 'OD-217087' + #queue: 'workq' + walltime: '0-01:00:00' + job_extra_directives: ['--qos express'] + # interface for the workers + interface: "ib0" + log_directory: 'spice_logs' + job_script_prologue: [ + 'module load singularity', + 'unset SINGULARITY_BINDPATH' + ] + local_directory: $LOCALDIR + silence_logs: 'info' +adapt_kwargs: + minimum: 1 + maximum: 36 + wait_count: 20 + target_duration: "300s" + interval: "30s" diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py index acf01ee0..a221a658 100644 --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -9,8 +9,7 @@ import pkg_resources import yaml from 
astropy.time import Time -from dask.distributed import Client -from dask_jobqueue import SLURMCluster +from dask.distributed import Client, LocalCluster from dask_mpi import initialize from prefect import flow from prefect_dask import DaskTaskRunner, get_dask_client @@ -27,6 +26,7 @@ ) from arrakis.logger import logger from arrakis.utils.database import test_db +from arrakis.utils.meta import class_for_name from arrakis.utils.pipeline import logo_str, performance_report_prefect @@ -206,9 +206,6 @@ def create_client( dask_config: str, use_mpi: bool, port_forward: Any, - minimum: int = 1, - maximum: int = 38, - mode: str = "adapt", overload: bool = False, ) -> Client: logger.info("Creating a Client") @@ -219,34 +216,39 @@ def create_client( # Following https://github.com/dask/dask-jobqueue/issues/499 with open(dask_config) as f: logger.info(f"Loading {dask_config}") - config = yaml.safe_load(f) + config: dict = yaml.safe_load(f) + + cluster_class_str = config.get("cluster_class", "distributed.LocalCluster") + cluster_kwargs = config.get("cluster_kwargs", {}) + adapt_kwargs = config.get("adapt_kwargs", {}) if overload: logger.info("Overwriting config attributes.") - config["job_cpu"] = config["cores"] - config["cores"] = 1 - config["processes"] = 1 + cluster_kwargs["job_cpu"] = cluster_kwargs["cores"] + cluster_kwargs["cores"] = 1 + cluster_kwargs["processes"] = 1 if use_mpi: initialize( - interface=config["interface"], - local_directory=config["local_directory"], - nthreads=config["cores"] / config["processes"], + interface=cluster_kwargs["interface"], + local_directory=cluster_kwargs["local_directory"], + nthreads=cluster_kwargs["cores"] / cluster_kwargs["processes"], ) client = Client() else: - # TODO: load the cluster type and initialise it from field specified in the loaded config - cluster = SLURMCluster( - **config, + cluster_class = class_for_name( + module_name=cluster_class_str.rsplit(".", 1)[0], + class_name=cluster_class_str.rsplit(".", 1)[1], + ) + cluster = cluster_class( + **cluster_kwargs, ) logger.debug(f"Submitted scripts will look like: \n {cluster.job_script()}") - if mode == "adapt": - cluster.adapt(minimum=minimum, maximum=maximum) - elif mode == "scale": - cluster.scale(maximum) + cluster.adapt(**adapt_kwargs) client = Client(cluster) + port = client.scheduler_info()["services"]["dashboard"] # Forward ports @@ -299,8 +301,6 @@ def main(args: configargparse.Namespace) -> None: dask_config=args.imager_dask_config, use_mpi=args.use_mpi, port_forward=args.port_forward, - minimum=1, - maximum=38, overload=True, ) @@ -355,9 +355,6 @@ def main(args: configargparse.Namespace) -> None: dask_config=args.dask_config, use_mpi=args.use_mpi, port_forward=args.port_forward, - minimum=1, - maximum=256, - mode="scale", ) # Define flow diff --git a/arrakis/utils/meta.py b/arrakis/utils/meta.py index 00b109fe..7c08dd2e 100644 --- a/arrakis/utils/meta.py +++ b/arrakis/utils/meta.py @@ -1,6 +1,7 @@ #!/usr/bin/env python """Generic program utilities""" +import importlib import warnings from itertools import zip_longest @@ -11,6 +12,24 @@ warnings.simplefilter("ignore", category=AstropyWarning) +# From https://stackoverflow.com/questions/1176136/convert-string-to-python-class-object +def class_for_name(module_name: str, class_name: str) -> object: + """Returns a class object given a module name and class name + + Args: + module_name (str): Module name + class_name (str): Class name + + Returns: + object: Class object + """ + # load the module, will raise ImportError if module cannot be loaded 
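+    # e.g. class_for_name("dask_jobqueue", "SLURMCluster") returns the
+    # SLURMCluster class object, which the caller instantiates from the
+    # "cluster_class" string in the YAML config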
+ m = importlib.import_module(module_name) + # get the class, will raise AttributeError if class cannot be found + c = getattr(m, class_name) + return c + + # stolen from https://stackoverflow.com/questions/32954486/zip-iterators-asserting-for-equal-length-in-python def zip_equal(*iterables): sentinel = object() From 9778d9f6d1819277e5f1edda413ead8ff87806bc Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Wed, 13 Dec 2023 13:58:39 +1100 Subject: [PATCH 33/46] Update default --- arrakis/configs/default.yaml | 38 ++++++++++++------------------------ 1 file changed, 13 insertions(+), 25 deletions(-) diff --git a/arrakis/configs/default.yaml b/arrakis/configs/default.yaml index ca734acc..7e064aee 100644 --- a/arrakis/configs/default.yaml +++ b/arrakis/configs/default.yaml @@ -1,25 +1,13 @@ -# Set up for Magnus -cores: 24 -processes: 12 -name: 'spice-worker' -memory: "60GB" -project: 'ja3' -queue: 'workq' -n_workers: 1000 -walltime: '6:00:00' -job_extra: ['-M magnus'] -# interface for the workers -interface: "ipogif0" -log_directory: 'spice_logs' -env_extra: [ - 'export OMP_NUM_THREADS=1', - 'source /home/$(whoami)/.bashrc', - 'conda activate spice' -] -python: 'srun -n 1 -c 24 python' -extra: [ - "--lifetime", "11h", - "--lifetime-stagger", "5m", -] -death_timeout: 300 -local_directory: '/dev/shm' +# Set up for local mahine +cluster_class: "distributed.LocalCluster" +cluster_kwargs: + cores: 1 + processes: 1 + name: 'spice-worker' + memory: "8GB" +adapt_kwargs: + minimum: 1 + maximum: 8 + wait_count: 20 + target_duration: "300s" + interval: "30s" From ce4fc6dddfb1b807542ec02d2971ecef88a1a275 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Wed, 13 Dec 2023 13:59:06 +1100 Subject: [PATCH 34/46] These don't exist anymore --- arrakis/configs/galaxy.yaml | 25 ------------------------- arrakis/configs/magnus.yaml | 25 ------------------------- 2 files changed, 50 deletions(-) delete mode 100644 arrakis/configs/galaxy.yaml delete mode 100644 arrakis/configs/magnus.yaml diff --git a/arrakis/configs/galaxy.yaml b/arrakis/configs/galaxy.yaml deleted file mode 100644 index 1077a2ab..00000000 --- a/arrakis/configs/galaxy.yaml +++ /dev/null @@ -1,25 +0,0 @@ -# Set up for Galaxy -cores: 20 -processes: 20 -name: 'spice-worker' -memory: "60GB" -project: 'askaprt' -queue: 'workq' -walltime: '6:00:00' -n_workers: 1000 -job_extra: ['-M galaxy'] -# interface for the workers -interface: "ipogif0" -log_directory: 'spice_logs' -env_extra: [ - 'export OMP_NUM_THREADS=1', - 'source /home/$(whoami)/.bashrc', - 'conda activate spice' -] -python: 'srun -n 1 -c 20 python' -extra: [ - "--lifetime", "5h", - "--lifetime-stagger", "5m", -] -death_timeout: 300 -local_directory: '/dev/shm' diff --git a/arrakis/configs/magnus.yaml b/arrakis/configs/magnus.yaml deleted file mode 100644 index ca734acc..00000000 --- a/arrakis/configs/magnus.yaml +++ /dev/null @@ -1,25 +0,0 @@ -# Set up for Magnus -cores: 24 -processes: 12 -name: 'spice-worker' -memory: "60GB" -project: 'ja3' -queue: 'workq' -n_workers: 1000 -walltime: '6:00:00' -job_extra: ['-M magnus'] -# interface for the workers -interface: "ipogif0" -log_directory: 'spice_logs' -env_extra: [ - 'export OMP_NUM_THREADS=1', - 'source /home/$(whoami)/.bashrc', - 'conda activate spice' -] -python: 'srun -n 1 -c 24 python' -extra: [ - "--lifetime", "11h", - "--lifetime-stagger", "5m", -] -death_timeout: 300 -local_directory: '/dev/shm' From a5d0d3fdefa2ce6953907d9e98601501765f88ec Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, 
Kensington)" Date: Wed, 13 Dec 2023 16:49:46 +1100 Subject: [PATCH 35/46] Le bug --- arrakis/cutout.py | 1 - 1 file changed, 1 deletion(-) diff --git a/arrakis/cutout.py b/arrakis/cutout.py index 68e68acd..dd8b6c2d 100644 --- a/arrakis/cutout.py +++ b/arrakis/cutout.py @@ -324,7 +324,6 @@ def unpack(list_sq: List[Union[List[T], None, T]]) -> List[T]: Returns: List[T]: List of things """ - logger.setLevel(logging.DEBUG) logger.debug(f"{list_sq=}") list_fl: List[T] = [] for i in list_sq: From 137dd103b65e9de09ecf6f0b6d6670708e20c0ff Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Thu, 14 Dec 2023 14:08:36 +1100 Subject: [PATCH 36/46] Cut --- arrakis/cutout.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/arrakis/cutout.py b/arrakis/cutout.py index dd8b6c2d..e5ba0eb8 100644 --- a/arrakis/cutout.py +++ b/arrakis/cutout.py @@ -358,7 +358,6 @@ def cutout_islands( field (str): RACS field name. directory (str): Directory to store cutouts. host (str): MongoDB host. - client (Client): Dask client. username (str, optional): Mongo username. Defaults to None. password (str, optional): Mongo password. Defaults to None. verbose (bool, optional): Verbose output. Defaults to True. @@ -369,8 +368,7 @@ def cutout_islands( """ if stokeslist is None: stokeslist = ["I", "Q", "U", "V"] - client = get_client() - logger.debug(f"Client is {client}") + directory = os.path.abspath(directory) outdir = os.path.join(directory, "cutouts") @@ -427,7 +425,7 @@ def cutout_islands( ) # args = [a.result() for a in args] # flat_args = unpack.map(args) - flat_args = unpack(args) + flat_args = unpack.submit(args) cuts = cutout.map( cutout_args=flat_args, field=unmapped(field), From 8257d9164ad0b04377bf470b9e6ad873e5eea25b Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Thu, 14 Dec 2023 14:13:00 +1100 Subject: [PATCH 37/46] Get dask working --- arrakis/.default_config.txt | 2 - arrakis/process_spice.py | 461 ++++++++++++++++-------------------- 2 files changed, 201 insertions(+), 262 deletions(-) diff --git a/arrakis/.default_config.txt b/arrakis/.default_config.txt index 85cd91ba..d62e4980 100644 --- a/arrakis/.default_config.txt +++ b/arrakis/.default_config.txt @@ -3,8 +3,6 @@ # host: # Host of mongodb. # username: # Username of mongodb. # password: # Password of mongodb. -port: 8787 # Port to run Dask dashboard on. -# port_forward: # Platform to fowards dask port [None]. # dask_config: # Config file for Dask SlurmCLUSTER. # holofile: yanda: "1.3.0" diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py index a221a658..c36fb851 100644 --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -3,7 +3,7 @@ import logging import os from pathlib import Path -from typing import Any +from typing import Any, Optional, Tuple import configargparse import pkg_resources @@ -11,7 +11,9 @@ from astropy.time import Time from dask.distributed import Client, LocalCluster from dask_mpi import initialize +from distributed.deploy.cluster import Cluster from prefect import flow +from prefect.task_runners import BaseTaskRunner from prefect_dask import DaskTaskRunner, get_dask_client from arrakis import ( @@ -31,156 +33,156 @@ @flow(name="Combining+Synthesis on Arrakis") -def process_spice(args, host: str) -> None: +def process_spice(args, host: str, task_runner: BaseTaskRunner) -> None: """Workflow to process the SPIRCE-RACS data Args: args (configargparse.Namespace): Configuration parameters for this run host (str): Host address of the mongoDB. 
""" - # TODO: Fix the type assigned to args. The `configargparse.Namespace` was causing issues - # with the pydantic validation used by prefect / flow. - outfile = f"{args.field}.pipe.test.fits" if args.outfile is None else args.outfile - with get_dask_client(): - previous_future = None - previous_future = ( - cutout.cutout_islands( - field=args.field, - directory=str(args.outdir), - host=host, - epoch=args.epoch, - username=args.username, - password=args.password, - pad=args.pad, - stokeslist=["I", "Q", "U"], - verbose_worker=args.verbose_worker, - dryrun=args.dryrun, - limit=args.limit, - ) - if not args.skip_cutout - else previous_future + previous_future = None + previous_future = ( + cutout.cutout_islands.with_options( + task_runner=task_runner, + )( + field=args.field, + directory=str(args.outdir), + host=host, + epoch=args.epoch, + username=args.username, + password=args.password, + pad=args.pad, + stokeslist=["I", "Q", "U"], + verbose_worker=args.verbose_worker, + dryrun=args.dryrun, + limit=args.limit, ) - - previous_future = ( - linmos.main( - field=args.field, - datadir=Path(args.outdir), - host=host, - epoch=args.epoch, - holofile=Path(args.holofile), - username=args.username, - password=args.password, - yanda=args.yanda, - yanda_img=args.yanda_image, - stokeslist=["I", "Q", "U"], - limit=args.limit, - ) - if not args.skip_linmos - else previous_future + if not args.skip_cutout + else previous_future + ) + + previous_future = ( + linmos.main.with_options( + task_runner=task_runner, + )( + field=args.field, + datadir=Path(args.outdir), + host=host, + epoch=args.epoch, + holofile=Path(args.holofile), + username=args.username, + password=args.password, + yanda=args.yanda, + yanda_img=args.yanda_image, + stokeslist=["I", "Q", "U"], + limit=args.limit, ) + if not args.skip_linmos + else previous_future + ) - previous_future = ( - cleanup.main( - datadir=args.outdir, - stokeslist=["I", "Q", "U"], - ) - if not args.skip_cleanup - else previous_future + previous_future = ( + cleanup.main.with_options(task_run)( + datadir=args.outdir, + stokeslist=["I", "Q", "U"], ) - - previous_future = ( - frion.main( - field=args.field, - outdir=args.outdir, - host=host, - epoch=args.epoch, - username=args.username, - password=args.password, - database=args.database, - ionex_server=args.ionex_server, - ionex_proxy_server=args.ionex_proxy_server, - ionex_formatter=args.ionex_formatter, - ionex_predownload=args.ionex_predownload, - limit=args.limit, - ) - if not args.skip_frion - else previous_future + if not args.skip_cleanup + else previous_future + ) + + previous_future = ( + frion.main.with_options(task_runner=task_runner)( + field=args.field, + outdir=args.outdir, + host=host, + epoch=args.epoch, + username=args.username, + password=args.password, + database=args.database, + ionex_server=args.ionex_server, + ionex_proxy_server=args.ionex_proxy_server, + ionex_formatter=args.ionex_formatter, + ionex_predownload=args.ionex_predownload, + limit=args.limit, ) - - previous_future = ( - rmsynth_oncuts.main( - field=args.field, - outdir=args.outdir, - host=host, - epoch=args.epoch, - username=args.username, - password=args.password, - dimension=args.dimension, - verbose=args.verbose, - database=args.database, - do_validate=args.validate, - limit=args.limit, - savePlots=args.savePlots, - weightType=args.weightType, - fitRMSF=args.fitRMSF, - phiMax_radm2=args.phiMax_radm2, - dPhi_radm2=args.dPhi_radm2, - nSamples=args.nSamples, - polyOrd=args.polyOrd, - noStokesI=args.noStokesI, - showPlots=args.showPlots, 
- not_RMSF=args.not_RMSF, - rm_verbose=args.rm_verbose, - debug=args.debug, - fit_function=args.fit_function, - tt0=args.tt0, - tt1=args.tt1, - ion=True, - do_own_fit=args.do_own_fit, - ) - if not args.skip_rmsynth - else previous_future + if not args.skip_frion + else previous_future + ) + + previous_future = ( + rmsynth_oncuts.main.with_options(task_runner=task_runner)( + field=args.field, + outdir=args.outdir, + host=host, + epoch=args.epoch, + username=args.username, + password=args.password, + dimension=args.dimension, + verbose=args.verbose, + database=args.database, + do_validate=args.validate, + limit=args.limit, + savePlots=args.savePlots, + weightType=args.weightType, + fitRMSF=args.fitRMSF, + phiMax_radm2=args.phiMax_radm2, + dPhi_radm2=args.dPhi_radm2, + nSamples=args.nSamples, + polyOrd=args.polyOrd, + noStokesI=args.noStokesI, + showPlots=args.showPlots, + not_RMSF=args.not_RMSF, + rm_verbose=args.rm_verbose, + debug=args.debug, + fit_function=args.fit_function, + tt0=args.tt0, + tt1=args.tt1, + ion=True, + do_own_fit=args.do_own_fit, ) - - previous_future = ( - rmclean_oncuts.main( - field=args.field, - outdir=args.outdir, - host=host, - epoch=args.epoch, - username=args.username, - password=args.password, - dimension=args.dimension, - database=args.database, - limit=args.limit, - cutoff=args.cutoff, - maxIter=args.maxIter, - gain=args.gain, - window=args.window, - showPlots=args.showPlots, - rm_verbose=args.rm_verbose, - wait_for=[previous_future], - ) - if not args.skip_rmclean - else previous_future + if not args.skip_rmsynth + else previous_future + ) + + previous_future = ( + rmclean_oncuts.main.with_options(task_runner=task_runner)( + field=args.field, + outdir=args.outdir, + host=host, + epoch=args.epoch, + username=args.username, + password=args.password, + dimension=args.dimension, + database=args.database, + limit=args.limit, + cutoff=args.cutoff, + maxIter=args.maxIter, + gain=args.gain, + window=args.window, + showPlots=args.showPlots, + rm_verbose=args.rm_verbose, + wait_for=[previous_future], ) - - previous_future = ( - makecat.main( - field=args.field, - host=host, - epoch=args.epoch, - username=args.username, - password=args.password, - verbose=args.verbose, - outfile=outfile, - wait_for=[previous_future], - ) - if not args.skip_cat - else previous_future + if not args.skip_rmclean + else previous_future + ) + + previous_future = ( + makecat.main.with_options(task_runner=task_runner)( + field=args.field, + host=host, + epoch=args.epoch, + username=args.username, + password=args.password, + verbose=args.verbose, + outfile=outfile, + wait_for=[previous_future], ) + if not args.skip_cat + else previous_future + ) def save_args(args: configargparse.Namespace) -> Path: @@ -202,25 +204,32 @@ def save_args(args: configargparse.Namespace) -> Path: return Path(args_yaml_f) -def create_client( +def create_dask_runner( dask_config: str, - use_mpi: bool, - port_forward: Any, overload: bool = False, -) -> Client: - logger.info("Creating a Client") +) -> DaskTaskRunner: + """Create a DaskTaskRunner + + Args: + dask_config (str): Configuraiton file for the DaskTaskRunner + overload (bool, optional): Overload the options for threadded work. Defaults to False. 
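+
+    Example config (values here mirror arrakis/configs/petrichor.yaml):
+
+        cluster_class: "dask_jobqueue.SLURMCluster"
+        cluster_kwargs:
+            cores: 8
+            memory: "64GiB"
+        adapt_kwargs:
+            minimum: 1
+            maximum: 36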
+ + Returns: + DaskTaskRunner: The prefect DaskTaskRunner instance + """ + logger.setLevel(logging.INFO) + logger.info("Creating a Dask Task Runner.") if dask_config is None: config_dir = pkg_resources.resource_filename("arrakis", "configs") dask_config = f"{config_dir}/default.yaml" - # Following https://github.com/dask/dask-jobqueue/issues/499 with open(dask_config) as f: logger.info(f"Loading {dask_config}") - config: dict = yaml.safe_load(f) + yaml_config: dict = yaml.safe_load(f) - cluster_class_str = config.get("cluster_class", "distributed.LocalCluster") - cluster_kwargs = config.get("cluster_kwargs", {}) - adapt_kwargs = config.get("adapt_kwargs", {}) + cluster_class_str = yaml_config.get("cluster_class", "distributed.LocalCluster") + cluster_kwargs = yaml_config.get("cluster_kwargs", {}) + adapt_kwargs = yaml_config.get("adapt_kwargs", {}) if overload: logger.info("Overwriting config attributes.") @@ -228,51 +237,13 @@ def create_client( cluster_kwargs["cores"] = 1 cluster_kwargs["processes"] = 1 - if use_mpi: - initialize( - interface=cluster_kwargs["interface"], - local_directory=cluster_kwargs["local_directory"], - nthreads=cluster_kwargs["cores"] / cluster_kwargs["processes"], - ) - client = Client() - else: - cluster_class = class_for_name( - module_name=cluster_class_str.rsplit(".", 1)[0], - class_name=cluster_class_str.rsplit(".", 1)[1], - ) - cluster = cluster_class( - **cluster_kwargs, - ) - logger.debug(f"Submitted scripts will look like: \n {cluster.job_script()}") - - cluster.adapt(**adapt_kwargs) - - client = Client(cluster) + config = { + "cluster_class": cluster_class_str, + "cluster_kwargs": cluster_kwargs, + "adapt_kwargs": adapt_kwargs, + } - port = client.scheduler_info()["services"]["dashboard"] - - # Forward ports - if port_forward is not None: - for p in port_forward: - port_forward(port, p) - - # Prin out Dask client info - logger.info(client.scheduler_info()["services"]) - - return client - - -def create_dask_runner(*args, **kwargs) -> DaskTaskRunner: - """Internally creates a Client object via `create_client`, - and then initialises a DaskTaskRunner. - - Returns: - DaskTaskRunner: A Prefect dask based task runner - """ - client = create_client(*args, **kwargs) - - logger.info("Creating DaskTaskRunner") - return DaskTaskRunner(address=client.scheduler.address), client + return DaskTaskRunner(**config) def main(args: configargparse.Namespace) -> None: @@ -297,52 +268,49 @@ def main(args: configargparse.Namespace) -> None: if not args.skip_imager: # This is the client for the imager component of the arrakis # pipeline. - dask_runner, client = create_dask_runner( + dask_runner = create_dask_runner( dask_config=args.imager_dask_config, - use_mpi=args.use_mpi, - port_forward=args.port_forward, overload=True, ) logger.info("Obtained DaskTaskRunner, executing the imager workflow. 
") - with performance_report_prefect( - f"arrakis-imaging-{args.field}-report-{Time.now().fits}.html" - ): - imager.main.with_options( - name=f"Arrakis Imaging -- {args.field}", task_runner=dask_runner - )( - msdir=args.msdir, - out_dir=args.outdir, - cutoff=args.psf_cutoff, - robust=args.robust, - pols=args.pols, - nchan=args.nchan, - local_rms=args.local_rms, - local_rms_window=args.local_rms_window, - size=args.size, - scale=args.scale, - mgain=args.mgain, - niter=args.niter, - nmiter=args.nmiter, - auto_mask=args.auto_mask, - force_mask_rounds=args.force_mask_rounds, - auto_threshold=args.auto_threshold, - minuv=args.minuv, - purge=args.purge, - taper=args.taper, - parallel_deconvolution=args.parallel, - gridder=args.gridder, - wsclean_path=Path(args.local_wsclean) - if args.local_wsclean - else args.hosted_wsclean, - multiscale=args.multiscale, - multiscale_scale_bias=args.multiscale_scale_bias, - absmem=args.absmem, - ms_glob_pattern=args.ms_glob_pattern, - data_column=args.data_column, - ) + imager.main.with_options( + name=f"Arrakis Imaging -- {args.field}", task_runner=dask_runner + )( + msdir=args.msdir, + out_dir=args.outdir, + cutoff=args.psf_cutoff, + robust=args.robust, + pols=args.pols, + nchan=args.nchan, + local_rms=args.local_rms, + local_rms_window=args.local_rms_window, + size=args.size, + scale=args.scale, + mgain=args.mgain, + niter=args.niter, + nmiter=args.nmiter, + auto_mask=args.auto_mask, + force_mask_rounds=args.force_mask_rounds, + auto_threshold=args.auto_threshold, + minuv=args.minuv, + purge=args.purge, + taper=args.taper, + parallel_deconvolution=args.parallel, + gridder=args.gridder, + wsclean_path=Path(args.local_wsclean) + if args.local_wsclean + else args.hosted_wsclean, + multiscale=args.multiscale, + multiscale_scale_bias=args.multiscale_scale_bias, + absmem=args.absmem, + ms_glob_pattern=args.ms_glob_pattern, + data_column=args.data_column, + ) + client = dask_runner._client + if client is not None: client.close() - del dask_runner + del dask_runner else: logger.warn("Skipping the image creation step. ") @@ -351,23 +319,14 @@ def main(args: configargparse.Namespace) -> None: return # This is the client and pipeline for the RM extraction - dask_runner_2, client = create_dask_runner( + dask_runner_2 = create_dask_runner( dask_config=args.dask_config, - use_mpi=args.use_mpi, - port_forward=args.port_forward, ) # Define flow - with performance_report_prefect( - f"arrakis-synthesis-{args.field}-report-{Time.now().fits}.html" - ): - process_spice.with_options( - name=f"Arrakis Synthesis -- {args.field}", task_runner=dask_runner_2 - )(args, host) - - # TODO: Access the client via the `dask_runner`. Perhaps a - # way to do this is to extend the DaskTaskRunner's - # destructor and have it create it then. + process_spice.with_options( + name=f"Arrakis Synthesis -- {args.field}", task_runner=dask_runner_2 + )(args, host, dask_runner_2) def cli(): @@ -420,24 +379,6 @@ def cli(): "--password", type=str, default=None, help="Password of mongodb." ) - # parser.add_argument( - # '--port', - # type=int, - # default=9999, - # help="Port to run Dask dashboard on." 
From edd8ff2fa988c76d78fcf52662b27fa0706d1837 Mon Sep 17 00:00:00 2001
From: "Thomson, Alec (CASS, Kensington)"
Date: Thu, 14 Dec 2023 14:15:29 +1100
Subject: [PATCH 38/46] Fixup

---
 arrakis/process_spice.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py
index c36fb851..641c67cc 100644
--- a/arrakis/process_spice.py
+++ b/arrakis/process_spice.py
@@ -84,7 +84,7 @@ def process_spice(args, host: str, task_runner: BaseTaskRunner) -> None:
     )
 
     previous_future = (
-        cleanup.main.with_options(task_run)(
+        cleanup.main.with_options(task_runner=task_runner)(
             datadir=args.outdir,
             stokeslist=["I", "Q", "U"],
         )
@@ -641,8 +641,7 @@ def cli():
         "--outfile", default=None, type=str, help="File to save table to [None]."
     )
     args = parser.parse_args()
-    if not args.use_mpi:
-        parser.print_values()
+    parser.print_values()
 
     verbose = args.verbose
     if verbose:

From 5423e0efe3cbcff89bf89ef48cf8070d849697e2 Mon Sep 17 00:00:00 2001
From: "Thomson, Alec (CASS, Kensington)"
Date: Thu, 14 Dec 2023 14:16:05 +1100
Subject: [PATCH 39/46] ruff it up

---
 arrakis/cutout.py        | 1 -
 arrakis/linmos.py        | 1 -
 arrakis/makecat.py       | 2 +-
 arrakis/process_spice.py | 9 ++-------
 4 files changed, 3 insertions(+), 10 deletions(-)

diff --git a/arrakis/cutout.py b/arrakis/cutout.py
index e5ba0eb8..fa9705e3 100644
--- a/arrakis/cutout.py
+++ b/arrakis/cutout.py
@@ -21,7 +21,6 @@
 from astropy.utils.exceptions import AstropyWarning
 from astropy.wcs.utils import skycoord_to_pixel
 from dask.distributed import Client, LocalCluster
-from distributed import get_client
 from prefect import flow, task, unmapped
 from spectral_cube import SpectralCube
 from spectral_cube.utils import SpectralCubeWarning
diff --git a/arrakis/linmos.py b/arrakis/linmos.py
index 165c6990..17dfd8d0 100644
--- a/arrakis/linmos.py
+++ b/arrakis/linmos.py
@@ -15,7 +15,6 @@
 from astropy.utils.exceptions import AstropyWarning
 from dask.distributed import Client, LocalCluster
 from prefect import flow, task, unmapped
-from prefect.utilities.annotations import quote
 from racs_tools import beamcon_3D
 from spectral_cube.utils import SpectralCubeWarning
 from spython.main import Client as sclient
diff --git a/arrakis/makecat.py b/arrakis/makecat.py
index 71480770..1283d627 100644
--- a/arrakis/makecat.py
+++ b/arrakis/makecat.py
@@ -17,7 +17,7 @@
 from astropy.stats import sigma_clip
 from astropy.table import Column, Table
 from dask.diagnostics import ProgressBar
-from prefect import flow, task, unmapped
+from prefect import flow, task
 from rmtable import RMTable
 from scipy.stats import lognorm, norm
 from tqdm import tqdm
diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py
index 641c67cc..2951451a 100644
--- a/arrakis/process_spice.py
+++ b/arrakis/process_spice.py
@@ -3,18 +3,14 @@
 import logging
 import os
 from pathlib import Path
-from typing import Any, Optional, Tuple
 
 import configargparse
 import pkg_resources
 import yaml
 from astropy.time import Time
-from dask.distributed import Client, LocalCluster
-from dask_mpi import initialize
-from distributed.deploy.cluster import Cluster
 from prefect import flow
 from prefect.task_runners import BaseTaskRunner
-from prefect_dask import 
DaskTaskRunner, get_dask_client
+from prefect_dask import DaskTaskRunner
 
 from arrakis import (
     cleanup,
@@ -28,8 +24,7 @@
 )
 from arrakis.logger import logger
 from arrakis.utils.database import test_db
-from arrakis.utils.meta import class_for_name
-from arrakis.utils.pipeline import logo_str, performance_report_prefect
+from arrakis.utils.pipeline import logo_str
 
 
 @flow(name="Combining+Synthesis on Arrakis")

From 4de34f546c05041489d5a897f6fe61cc637bccb1 Mon Sep 17 00:00:00 2001
From: "Thomson, Alec (CASS, Kensington)"
Date: Thu, 14 Dec 2023 14:17:13 +1100
Subject: [PATCH 40/46] Ruff

---
 arrakis/cutout.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/arrakis/cutout.py b/arrakis/cutout.py
index fa9705e3..cf79fb76 100644
--- a/arrakis/cutout.py
+++ b/arrakis/cutout.py
@@ -14,7 +14,6 @@
 import astropy.units as u
 import numpy as np
 import pymongo
-from astropy import units as u
 from astropy.coordinates import Latitude, Longitude, SkyCoord
 from astropy.io import fits
 from astropy.utils import iers

From 5e0070ea567facd3e894dc624be7f8cd6c7ccdc2 Mon Sep 17 00:00:00 2001
From: "Thomson, Alec (CASS, Kensington)"
Date: Thu, 14 Dec 2023 14:58:52 +1100
Subject: [PATCH 41/46] Add process region

---
 arrakis/merge_fields.py   | 198 ++++++++++++++++++++++++--------------
 arrakis/process_region.py |  65 +++----------
 arrakis/process_spice.py  |   2 -
 3 files changed, 135 insertions(+), 130 deletions(-)

diff --git a/arrakis/merge_fields.py b/arrakis/merge_fields.py
index 0e4786da..603539b4 100644
--- a/arrakis/merge_fields.py
+++ b/arrakis/merge_fields.py
@@ -3,11 +3,12 @@
 import os
 from pprint import pformat
 from shutil import copyfile
-from typing import Dict, List, Union
+from typing import Dict, List, Optional, Union
 
 import pymongo
 from dask import delayed
 from dask.distributed import Client, LocalCluster
+from prefect import flow, task, unmapped
 
 from arrakis.linmos import get_yanda, linmos
 from arrakis.logger import logger
@@ -23,49 +24,73 @@ def make_short_name(name: str) -> str:
     return short
 
 
-@delayed
+@task(name="Copy singleton island")
 def copy_singleton(
-    beam: dict, vals: dict, merge_name: str, field_dir: str, data_dir: str
-) -> pymongo.UpdateOne:
-    try:
-        i_file_old = os.path.join(field_dir, vals["i_file"])
-        q_file_old = os.path.join(field_dir, vals["q_file_ion"])
-        u_file_old = os.path.join(field_dir, vals["u_file_ion"])
-    except KeyError:
-        raise KeyError("Ion files not found. 
Have you run FRion?") - new_dir = os.path.join(data_dir, beam["Source_ID"]) + beam: dict, field_dict: Dict[str, str], merge_name: str, data_dir: str +) -> List[pymongo.UpdateOne]: + """Copy an island within a single field to the merged field - try_mkdir(new_dir, verbose=False) + Args: + beam (dict): Beam document + field_dict (Dict[str, str]): Field dictionary + merge_name (str): Merged field name + data_dir (str): Output directory - i_file_new = os.path.join(new_dir, os.path.basename(i_file_old)).replace( - ".fits", ".edge.linmos.fits" - ) - q_file_new = os.path.join(new_dir, os.path.basename(q_file_old)).replace( - ".fits", ".edge.linmos.fits" - ) - u_file_new = os.path.join(new_dir, os.path.basename(u_file_old)).replace( - ".fits", ".edge.linmos.fits" - ) + Raises: + KeyError: If ion files not found - for src, dst in zip( - [i_file_old, q_file_old, u_file_old], [i_file_new, q_file_new, u_file_new] - ): - copyfile(src, dst) - src_weight = src.replace(".image.restored.", ".weights.").replace(".ion", "") - dst_weight = dst.replace(".image.restored.", ".weights.").replace(".ion", "") - copyfile(src_weight, dst_weight) - - query = {"Source_ID": beam["Source_ID"]} - newvalues = { - "$set": { - f"beams.{merge_name}.i_file": make_short_name(i_file_new), - f"beams.{merge_name}.q_file": make_short_name(q_file_new), - f"beams.{merge_name}.u_file": make_short_name(u_file_new), - f"beams.{merge_name}.DR1": True, + Returns: + List[pymongo.UpdateOne]: Database updates + """ + updates = [] + for field, vals in beam["beams"].items(): + if field not in field_dict.keys(): + continue + field_dir = field_dict[field] + try: + i_file_old = os.path.join(field_dir, vals["i_file"]) + q_file_old = os.path.join(field_dir, vals["q_file_ion"]) + u_file_old = os.path.join(field_dir, vals["u_file_ion"]) + except KeyError: + raise KeyError("Ion files not found. 
Have you run FRion?") + new_dir = os.path.join(data_dir, beam["Source_ID"]) + + try_mkdir(new_dir, verbose=False) + + i_file_new = os.path.join(new_dir, os.path.basename(i_file_old)).replace( + ".fits", ".edge.linmos.fits" + ) + q_file_new = os.path.join(new_dir, os.path.basename(q_file_old)).replace( + ".fits", ".edge.linmos.fits" + ) + u_file_new = os.path.join(new_dir, os.path.basename(u_file_old)).replace( + ".fits", ".edge.linmos.fits" + ) + + for src, dst in zip( + [i_file_old, q_file_old, u_file_old], [i_file_new, q_file_new, u_file_new] + ): + copyfile(src, dst) + src_weight = src.replace(".image.restored.", ".weights.").replace( + ".ion", "" + ) + dst_weight = dst.replace(".image.restored.", ".weights.").replace( + ".ion", "" + ) + copyfile(src_weight, dst_weight) + + query = {"Source_ID": beam["Source_ID"]} + newvalues = { + "$set": { + f"beams.{merge_name}.i_file": make_short_name(i_file_new), + f"beams.{merge_name}.q_file": make_short_name(q_file_new), + f"beams.{merge_name}.u_file": make_short_name(u_file_new), + f"beams.{merge_name}.DR1": True, + } } - } - return pymongo.UpdateOne(query, newvalues) + updates.append(pymongo.UpdateOne(query, newvalues)) + return updates def copy_singletons( @@ -73,7 +98,18 @@ def copy_singletons( data_dir: str, beams_col: pymongo.collection.Collection, merge_name: str, -) -> list: +) -> List[pymongo.UpdateOne]: + """Copy islands that don't overlap other fields + + Args: + field_dict (Dict[str, str]): Field dictionary + data_dir (str): Data directory + beams_col (pymongo.collection.Collection): Beams collection + merge_name (str): Merged field name + + Returns: + List[pymongo.UpdateOne]: Database updates + """ # Find all islands with the given fields that DON'T overlap another field query = { "$or": [ @@ -92,18 +128,15 @@ def copy_singletons( big_beams = list( beams_col.find({"Source_ID": {"$in": island_ids}}).sort("Source_ID") ) - updates = [] - for beam in big_beams: - for field, vals in beam["beams"].items(): - if field not in field_dict.keys(): - continue - field_dir = field_dict[field] - update = copy_singleton(beam, vals, merge_name, field_dir, data_dir) - updates.append(update) + updates = copy_singleton.map( + beam=big_beams, + field_dict=unmapped(field_dict), + merge_name=unmapped(merge_name), + data_dir=unmapped(data_dir), + ) return updates -@delayed def genparset( old_ims: list, stokes: str, @@ -139,10 +172,24 @@ def genparset( return parset_file -# @delayed(nout=3) def merge_multiple_field( beam: dict, field_dict: dict, merge_name: str, data_dir: str, image: str -) -> list: +) -> List[pymongo.UpdateOne]: + """Merge an island that overlaps multiple fields + + Args: + beam (dict): Beam document + field_dict (dict): Field dictionary + merge_name (str): Merged field name + data_dir (str): Data directory + image (str): Yandasoft image + + Raises: + KeyError: If ion files not found + + Returns: + List[pymongo.UpdateOne]: Database updates + """ i_files_old = [] q_files_old = [] u_files_old = [] @@ -167,19 +214,32 @@ def merge_multiple_field( for stokes, imlist in zip(["I", "Q", "U"], [i_files_old, q_files_old, u_files_old]): parset_file = genparset(imlist, stokes, new_dir) - update = linmos(parset_file, merge_name, image) + update = linmos.fn(parset_file, merge_name, image) updates.append(update) return updates +@task(name="Merge multiple islands") def merge_multiple_fields( field_dict: Dict[str, str], data_dir: str, beams_col: pymongo.collection.Collection, merge_name: str, image: str, -) -> list: +) -> List[pymongo.UpdateOne]: + """Merge 
multiple islands that overlap multiple fields + + Args: + field_dict (Dict[str, str]): Field dictionary + data_dir (str): Data directory + beams_col (pymongo.collection.Collection): Beams collection + merge_name (str): Merged field name + image (str): Yandasoft image + + Returns: + List[pymongo.UpdateOne]: Database updates + """ # Find all islands with the given fields that overlap another field query = { "$or": [ @@ -199,14 +259,18 @@ def merge_multiple_fields( beams_col.find({"Source_ID": {"$in": island_ids}}).sort("Source_ID") ) - updates = [] - for beam in big_beams: - update = merge_multiple_field(beam, field_dict, merge_name, data_dir, image) - updates.extend(update) + updates = merge_multiple_field.map( + beam=big_beams, + field_dict=unmapped(field_dict), + merge_name=unmapped(merge_name), + data_dir=unmapped(data_dir), + image=unmapped(image), + ) return updates +@flow(name="Merge fields") def main( fields: List[str], field_dirs: List[str], @@ -214,10 +278,9 @@ def main( output_dir: str, host: str, epoch: int, - username: Union[str, None] = None, - password: Union[str, None] = None, + username: Optional[str] = None, + password: Optional[str] = None, yanda="1.3.0", - verbose: bool = True, ) -> str: logger.debug(f"{fields=}") @@ -257,21 +320,8 @@ def main( image=image, ) - singleton_futures = chunk_dask( - outputs=singleton_updates, - task_name="singleton islands", - progress_text="Copying singleton islands", - verbose=verbose, - ) - singleton_comp = [f.compute() for f in singleton_futures] - - multiple_futures = chunk_dask( - outputs=mutilple_updates, - task_name="overlapping islands", - progress_text="Running LINMOS on overlapping islands", - verbose=verbose, - ) - multiple_comp = [f.compute() for f in multiple_futures] + singleton_comp = [f.result() for f in singleton_updates] + multiple_comp = [f.result() for f in mutilple_updates] for m in multiple_comp: m._doc["$set"].update({f"beams.{merge_name}.DR1": True}) diff --git a/arrakis/process_region.py b/arrakis/process_region.py index aacea0f5..125a379e 100644 --- a/arrakis/process_region.py +++ b/arrakis/process_region.py @@ -10,18 +10,17 @@ from dask_jobqueue import SLURMCluster from dask_mpi import initialize from prefect import flow, task +from prefect.task_runners import BaseTaskRunner from prefect_dask import DaskTaskRunner -from arrakis import merge_fields, process_spice +from arrakis import makecat, merge_fields, process_spice, rmclean_oncuts, rmsynth_oncuts from arrakis.logger import logger from arrakis.utils.database import test_db from arrakis.utils.pipeline import logo_str, port_forward -merge_task = task(merge_fields.main, name="Merge fields") - @flow -def process_merge(args, host: str, inter_dir: str) -> None: +def process_merge(args, host: str, inter_dir: str, task_runner) -> None: """Workflow to merge spectra from overlapping fields together Args: @@ -31,7 +30,7 @@ def process_merge(args, host: str, inter_dir: str) -> None: """ previous_future = None previous_future = ( - merge_task.submit( + merge_fields.with_options(task_runner=task_runner)( fields=args.fields, field_dirs=args.datadirs, merge_name=args.merge_name, @@ -48,7 +47,7 @@ def process_merge(args, host: str, inter_dir: str) -> None: ) previous_future = ( - process_spice.rmsynth_task.submit( + rmsynth_oncuts.main.with_options(task_runner=task_runner)( field=args.merge_name, outdir=inter_dir, host=host, @@ -77,14 +76,13 @@ def process_merge(args, host: str, inter_dir: str) -> None: tt1=args.tt1, ion=False, do_own_fit=args.do_own_fit, - 
wait_for=[previous_future],
         )
         if not args.skip_rmsynth
         else previous_future
     )
 
     previous_future = (
-        process_spice.rmclean_task.submit(
+        rmclean_oncuts.main.with_options(task_runner=task_runner)(
             field=args.merge_name,
             outdir=inter_dir,
             host=host,
@@ -102,14 +100,13 @@ def process_merge(args, host: str, inter_dir: str) -> None:
             window=args.window,
             showPlots=args.showPlots,
             rm_verbose=args.rm_verbose,
-            wait_for=[previous_future],
         )
         if not args.skip_rmclean
         else previous_future
     )
 
     previous_future = (
-        process_spice.cat_task.submit(
+        makecat.main.with_options(task_runner=task_runner)(
             field=args.merge_name,
             host=host,
             epoch=args.epoch,
@@ -117,7 +114,6 @@ def process_merge(args, host: str, inter_dir: str) -> None:
             password=args.password,
             verbose=args.verbose,
             outfile=args.outfile,
-            wait_for=[previous_future],
         )
         if not args.skip_cat
         else previous_future
@@ -130,8 +126,6 @@ def main(args: configargparse.Namespace) -> None:
     Args:
         args (configargparse.Namespace): Command line arguments.
     """
-    host = args.host
-
     if args.dask_config is None:
         config_dir = pkg_resources.resource_filename("arrakis", "configs")
         args.dask_config = os.path.join(config_dir, "default.yaml")
@@ -139,33 +133,6 @@ def main(args: configargparse.Namespace) -> None:
     if args.outfile is None:
         args.outfile = f"{args.merge_name}.pipe.test.fits"
 
-    # Following https://github.com/dask/dask-jobqueue/issues/499
-    with open(args.dask_config) as f:
-        config = yaml.safe_load(f)
-
-    config.update(
-        {
-            # 'scheduler_options': {
-            #     "dashboard_address": f":{args.port}"
-            # },
-            "log_directory": f"{args.merge_name}_{Time.now().fits}_spice_logs/"
-        }
-    )
-    if args.use_mpi:
-        initialize(
-            interface=config["interface"],
-            local_directory=config["local_directory"],
-            nthreads=config["cores"] / config["processes"],
-        )
-        client = Client()
-    else:
-        cluster = SLURMCluster(
-            **config,
-        )
-        logger.debug(f"Submitted scripts will look like: \n {cluster.job_script()}")
-
-        client = Client(cluster)
-
     test_db(
         host=args.host,
         username=args.username,
@@ -178,25 +145,15 @@ def main(args: configargparse.Namespace) -> None:
     with open(args_yaml_f, "w") as f:
         f.write(args_yaml)
 
-    port = client.scheduler_info()["services"]["dashboard"]
-
-    # Forward ports
-    if args.port_forward is not None:
-        for p in args.port_forward:
-            port_forward(port, p)
-
-    # Prin out Dask client info
-    logger.info(client.scheduler_info()["services"])
-
-    dask_runner = DaskTaskRunner(address=client.scheduler.address)
+    dask_runner = process_spice.create_dask_runner(
+        dask_config=args.dask_config,
+    )
 
     inter_dir = os.path.join(os.path.abspath(args.output_dir), args.merge_name)
 
     process_merge.with_options(
         name=f"SPICE-RACS: {args.merge_name}", task_runner=dask_runner
-    )(args, host, inter_dir)
-
-    client.close()
+    )(args, args.host, inter_dir, dask_runner)
 
 
 def cli():
diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py
index 2951451a..d9ebe382 100644
--- a/arrakis/process_spice.py
+++ b/arrakis/process_spice.py
@@ -158,7 +158,6 @@ def process_spice(args, host: str, task_runner: BaseTaskRunner) -> None:
             window=args.window,
             showPlots=args.showPlots,
             rm_verbose=args.rm_verbose,
-            wait_for=[previous_future],
         )
         if not args.skip_rmclean
         else previous_future
@@ -173,7 +172,6 @@ def process_spice(args, host: str, task_runner: BaseTaskRunner) -> None:
             password=args.password,
             verbose=args.verbose,
             outfile=outfile,
-            wait_for=[previous_future],
         )
         if not args.skip_cat
         else previous_future
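
A note on the mapping pattern PATCH 41 introduces in `merge_fields`: `task.map` fans a task out over an iterable (one task run per element), while `unmapped(...)` marks arguments to broadcast unchanged to every run, replacing the hand-rolled `dask.delayed` loops and `chunk_dask` batching. A minimal sketch of the same shape, with hypothetical names (`copy_one`/`copy_all` are illustrative, not the arrakis API):

    from prefect import flow, task, unmapped

    @task
    def copy_one(beam: dict, merge_name: str) -> str:
        # One task run per beam; merge_name arrives identically in every run
        return f"{merge_name}/{beam['Source_ID']}"

    @flow
    def copy_all(beams: list, merge_name: str) -> list:
        futures = copy_one.map(beam=beams, merge_name=unmapped(merge_name))
        # .result() blocks until the corresponding task run finishes
        return [f.result() for f in futures]

    if __name__ == "__main__":
        print(copy_all([{"Source_ID": "source_A"}], "merged_field"))

This is also why `main` can now gather with `f.result()` where it previously called `f.compute()` on Dask futures.
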
Kensington)" Date: Thu, 14 Dec 2023 14:59:24 +1100 Subject: [PATCH 42/46] Ruff --- arrakis/merge_fields.py | 4 +--- arrakis/process_region.py | 9 ++------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/arrakis/merge_fields.py b/arrakis/merge_fields.py index 603539b4..2e215355 100644 --- a/arrakis/merge_fields.py +++ b/arrakis/merge_fields.py @@ -3,10 +3,9 @@ import os from pprint import pformat from shutil import copyfile -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional import pymongo -from dask import delayed from dask.distributed import Client, LocalCluster from prefect import flow, task, unmapped @@ -14,7 +13,6 @@ from arrakis.logger import logger from arrakis.utils.database import get_db, test_db from arrakis.utils.io import try_mkdir -from arrakis.utils.pipeline import chunk_dask def make_short_name(name: str) -> str: diff --git a/arrakis/process_region.py b/arrakis/process_region.py index 125a379e..14cfdaf6 100644 --- a/arrakis/process_region.py +++ b/arrakis/process_region.py @@ -6,17 +6,12 @@ import pkg_resources import yaml from astropy.time import Time -from dask.distributed import Client -from dask_jobqueue import SLURMCluster -from dask_mpi import initialize -from prefect import flow, task -from prefect.task_runners import BaseTaskRunner -from prefect_dask import DaskTaskRunner +from prefect import flow from arrakis import makecat, merge_fields, process_spice, rmclean_oncuts, rmsynth_oncuts from arrakis.logger import logger from arrakis.utils.database import test_db -from arrakis.utils.pipeline import logo_str, port_forward +from arrakis.utils.pipeline import logo_str @flow From 65bf1455bf9ad6dfc00b950f937704fa0f2a088c Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Thu, 14 Dec 2023 15:09:28 +1100 Subject: [PATCH 43/46] Fix args --- arrakis/merge_fields.py | 5 ----- arrakis/process_region.py | 5 +---- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/arrakis/merge_fields.py b/arrakis/merge_fields.py index 2e215355..dd4701ea 100644 --- a/arrakis/merge_fields.py +++ b/arrakis/merge_fields.py @@ -393,10 +393,6 @@ def cli(): help="Epoch of observation.", ) - parser.add_argument( - "-v", dest="verbose", action="store_true", help="Verbose output [False]." - ) - parser.add_argument( "--username", type=str, default=None, help="Username of mongodb." 
) @@ -425,7 +421,6 @@ def cli(): username=args.username, password=args.password, yanda=args.yanda, - verbose=verbose, ) client.close() diff --git a/arrakis/process_region.py b/arrakis/process_region.py index 14cfdaf6..768620ad 100644 --- a/arrakis/process_region.py +++ b/arrakis/process_region.py @@ -35,7 +35,6 @@ def process_merge(args, host: str, inter_dir: str, task_runner) -> None: username=args.username, password=args.password, yanda=args.yanda, - verbose=args.verbose, ) if not args.skip_merge else previous_future @@ -52,7 +51,7 @@ def process_merge(args, host: str, inter_dir: str, task_runner) -> None: dimension=args.dimension, verbose=args.verbose, database=args.database, - validate=args.validate, + do_validate=args.validate, limit=args.limit, savePlots=args.savePlots, weightType=args.weightType, @@ -85,9 +84,7 @@ def process_merge(args, host: str, inter_dir: str, task_runner) -> None: username=args.username, password=args.password, dimension=args.dimension, - verbose=args.verbose, database=args.database, - validate=args.validate, limit=args.limit, cutoff=args.cutoff, maxIter=args.maxIter, From 0ff3bfcc1238f731f456ffd16bd38020e6ecebfd Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Thu, 14 Dec 2023 18:42:06 +1100 Subject: [PATCH 44/46] Use real RM-Tools --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6fe2199b..45da152a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ bokeh = "<3" prefect = "^2" prefect-dask = "*" RMTable = { git = "https://github.com/CIRADA-Tools/RMTable" } -rm-tools = {git = "https://github.com/AlecThomson/RM-Tools.git", branch="spiceracs_dev"} +rm-tools = {git = "https://github.com/CIRADA-Tools/RM-Tools"} PolSpectra = { git = "https://github.com/AlecThomson/PolSpectra.git", branch="spiceracs"} setuptools = "*" fixms = "^0.1.2" From bf28764254e4cb48f8db030542ae2cda5fcea371 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Thu, 14 Dec 2023 18:44:04 +1100 Subject: [PATCH 45/46] Version bump --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 45da152a..2fb7a7dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "arrakis" -version = "2.0.0" +version = "2.1.0" description = "Processing the SPICE." 
homepage = "https://research.csiro.au/racs/" repository = "https://github.com/AlecThomson/arrakis" From c2132bf728dfb9f113a614776d82b77c0098dc50 Mon Sep 17 00:00:00 2001 From: "Thomson, Alec (CASS, Kensington)" Date: Thu, 14 Dec 2023 18:47:43 +1100 Subject: [PATCH 46/46] Remove dask --- arrakis/cleanup.py | 6 ---- arrakis/cutout.py | 6 ---- arrakis/frion.py | 6 ---- arrakis/imager.py | 68 +++++++++++++++----------------------- arrakis/linmos.py | 3 -- arrakis/merge_fields.py | 7 ---- arrakis/rmclean_oncuts.py | 6 ---- arrakis/rmsynth_oncuts.py | 7 ---- scripts/casda_prepare.py | 41 ++++++----------------- scripts/compare_leakage.py | 9 ----- submit/test_image.py | 3 -- 11 files changed, 38 insertions(+), 124 deletions(-) diff --git a/arrakis/cleanup.py b/arrakis/cleanup.py index d9c9d9b5..7e8cc601 100644 --- a/arrakis/cleanup.py +++ b/arrakis/cleanup.py @@ -99,14 +99,8 @@ def cli(): if verbose: logger.setLevel(logging.DEBUG) - cluster = LocalCluster(n_workers=20) - client = Client(cluster) - main(datadir=Path(args.outdir), stokeslist=None, verbose=verbose) - client.close() - cluster.close() - if __name__ == "__main__": cli() diff --git a/arrakis/cutout.py b/arrakis/cutout.py index cf79fb76..92b4a231 100644 --- a/arrakis/cutout.py +++ b/arrakis/cutout.py @@ -569,12 +569,6 @@ def cli() -> None: if verbose: logger.setLevel(logging.INFO) - cluster = LocalCluster( - n_workers=12, threads_per_worker=1, dashboard_address=":9898" - ) - client = Client(cluster) - logger.info(client) - test_db( host=args.host, username=args.username, diff --git a/arrakis/frion.py b/arrakis/frion.py index 0ee0f673..392fc08f 100644 --- a/arrakis/frion.py +++ b/arrakis/frion.py @@ -418,12 +418,6 @@ def cli(): if verbose: logger.setLevel(logging.INFO) - cluster = LocalCluster( - n_workers=10, processes=True, threads_per_worker=1, local_directory="/dev/shm" - ) - client = Client(cluster) - logger.info(client) - test_db(host=args.host, username=args.username, password=args.password) main( diff --git a/arrakis/imager.py b/arrakis/imager.py index c7cb1f87..06997354 100644 --- a/arrakis/imager.py +++ b/arrakis/imager.py @@ -914,48 +914,34 @@ def imager_parser(parent_parser: bool = False) -> argparse.ArgumentParser: def cli(): """Command-line interface""" parser = imager_parser() - args = parser.parse_args() - - if args.mpi: - initialize(interface="ipogif0") - cluster = None - - else: - cluster = LocalCluster( - threads_per_worker=1, - ) - - with Client(cluster) as client: - logger.debug(f"{cluster=}") - logger.debug(f"{client=}") - main( - msdir=args.msdir, - out_dir=args.outdir, - cutoff=args.psf_cutoff, - robust=args.robust, - pols=args.pols, - nchan=args.nchan, - size=args.size, - scale=args.scale, - mgain=args.mgain, - niter=args.niter, - auto_mask=args.auto_mask, - force_mask_rounds=args.force_mask_rounds, - auto_threshold=args.auto_threshold, - minuv=args.minuv, - purge=args.purge, - taper=args.taper, - parallel_deconvolution=args.parallel, - gridder=args.gridder, - wsclean_path=Path(args.local_wsclean) - if args.local_wsclean - else args.hosted_wsclean, - multiscale=args.multiscale, - ms_glob_pattern=args.ms_glob_pattern, - data_column=args.data_column, - skip_fix_ms=args.skip_fix_ms, - ) + main( + msdir=args.msdir, + out_dir=args.outdir, + cutoff=args.psf_cutoff, + robust=args.robust, + pols=args.pols, + nchan=args.nchan, + size=args.size, + scale=args.scale, + mgain=args.mgain, + niter=args.niter, + auto_mask=args.auto_mask, + force_mask_rounds=args.force_mask_rounds, + auto_threshold=args.auto_threshold, + 
minuv=args.minuv, + purge=args.purge, + taper=args.taper, + parallel_deconvolution=args.parallel, + gridder=args.gridder, + wsclean_path=Path(args.local_wsclean) + if args.local_wsclean + else args.hosted_wsclean, + multiscale=args.multiscale, + ms_glob_pattern=args.ms_glob_pattern, + data_column=args.data_column, + skip_fix_ms=args.skip_fix_ms, + ) if __name__ == "__main__": diff --git a/arrakis/linmos.py b/arrakis/linmos.py index 17dfd8d0..e68cd110 100644 --- a/arrakis/linmos.py +++ b/arrakis/linmos.py @@ -458,9 +458,6 @@ def cli(): args = parser.parse_args() - cluster = LocalCluster(n_workers=1) - client = Client(cluster) - verbose = args.verbose test_db( host=args.host, username=args.username, password=args.password, verbose=verbose diff --git a/arrakis/merge_fields.py b/arrakis/merge_fields.py index dd4701ea..675ca6c4 100644 --- a/arrakis/merge_fields.py +++ b/arrakis/merge_fields.py @@ -402,10 +402,6 @@ def cli(): ) args = parser.parse_args() - - cluster = LocalCluster() - client = Client(cluster) - verbose = args.verbose test_db( host=args.host, username=args.username, password=args.password, verbose=verbose @@ -423,9 +419,6 @@ def cli(): yanda=args.yanda, ) - client.close() - cluster.close() - if __name__ == "__main__": cli() diff --git a/arrakis/rmclean_oncuts.py b/arrakis/rmclean_oncuts.py index b697ad3d..0f8e5173 100644 --- a/arrakis/rmclean_oncuts.py +++ b/arrakis/rmclean_oncuts.py @@ -451,9 +451,6 @@ def cli(): args = parser.parse_args() - cluster = LocalCluster(n_workers=20) - client = Client(cluster) - verbose = args.verbose rmv = args.rm_verbose host = args.host @@ -487,9 +484,6 @@ def cli(): rm_verbose=args.rm_verbose, ) - client.close() - cluster.close() - if __name__ == "__main__": cli() diff --git a/arrakis/rmsynth_oncuts.py b/arrakis/rmsynth_oncuts.py index 5296b199..e2023d98 100644 --- a/arrakis/rmsynth_oncuts.py +++ b/arrakis/rmsynth_oncuts.py @@ -1205,13 +1205,6 @@ def cli(): elif verbose: logger.setLevel(logger.INFO) - cluster = LocalCluster( - # n_workers=12, processes=True, threads_per_worker=1, - local_directory="/dev/shm" - ) - client = Client(cluster) - logger.debug(client) - test_db( host=args.host, username=args.username, password=args.password, verbose=verbose ) diff --git a/scripts/casda_prepare.py b/scripts/casda_prepare.py index 3d0c166f..8933832c 100755 --- a/scripts/casda_prepare.py +++ b/scripts/casda_prepare.py @@ -802,36 +802,17 @@ def cli(): logger.setLevel(logging.INFO) elif args.debug: logger.setLevel(logging.DEBUG) - - if args.mpi: - initialize( - interface=args.interface, - local_directory="/dev/shm", - ) - cluster = None - else: - cluster = LocalCluster( - n_workers=12, - processes=True, - threads_per_worker=1, - local_directory="/dev/shm", - ) - - with Client( - cluster, - ) as client: - logger.debug(f"{client=}") - main( - polcatf=args.polcat, - data_dir=args.data_dir, - prep_type=args.prep_type, - do_update_cubes=args.convert_cubes, - do_convert_spectra=args.convert_spectra, - do_convert_plots=args.convert_plots, - verbose=args.verbose, - batch_size=args.batch_size, - outdir=args.outdir, - ) + main( + polcatf=args.polcat, + data_dir=args.data_dir, + prep_type=args.prep_type, + do_update_cubes=args.convert_cubes, + do_convert_spectra=args.convert_spectra, + do_convert_plots=args.convert_plots, + verbose=args.verbose, + batch_size=args.batch_size, + outdir=args.outdir, + ) if __name__ == "__main__": diff --git a/scripts/compare_leakage.py b/scripts/compare_leakage.py index 1ae86b73..8b7dfa62 100644 --- a/scripts/compare_leakage.py +++ 
b/scripts/compare_leakage.py @@ -265,12 +265,6 @@ def cli(): args = parser.parse_args() - cluster = LocalCluster( - n_workers=10, - threads_per_worker=1, - ) - client = Client(cluster) - if args.verbose: logger.setLevel(logging.INFO) @@ -285,9 +279,6 @@ def cli(): snr_cut=args.snr, ) - client.close() - cluster.close() - if __name__ == "__main__": cli() diff --git a/submit/test_image.py b/submit/test_image.py index 59b0a4bb..974617c8 100755 --- a/submit/test_image.py +++ b/submit/test_image.py @@ -37,9 +37,6 @@ def main(): ) cluster.scale(72) logger.debug(f"Submitted scripts will look like: \n {cluster.job_script()}") - # # exit() - # cluster = LocalCluster(n_workers=10, threads_per_worker=1) - # cluster.adapt(minimum=1, maximum=36) client = Client(cluster)
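
With PATCH 46 the entry points stop building Dask clusters by hand: execution is owned by each flow's task runner, so the `LocalCluster`/`Client` scaffolding and its close/teardown bookkeeping disappear from the `cli()` functions. A minimal sketch of the resulting shape -- `process` and `pipeline` are hypothetical names, not the arrakis API, and `DaskTaskRunner()` with no arguments stands up a temporary local cluster for the duration of the flow run:

    from prefect import flow, task
    from prefect_dask import DaskTaskRunner

    @task
    def process(item: str) -> str:
        return item.upper()

    @flow(task_runner=DaskTaskRunner())
    def pipeline(items: list) -> list:
        # Mapped runs execute on the runner's temporary Dask cluster
        return process.map(items)

    def cli() -> None:
        # No cluster = LocalCluster(); client = Client(cluster) here:
        # calling the flow is enough, and teardown is automatic.
        pipeline(["a", "b", "c"])

    if __name__ == "__main__":
        cli()
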