diff --git a/arrakis/.default_config.txt b/arrakis/.default_config.txt index 85cd91ba..d62e4980 100644 --- a/arrakis/.default_config.txt +++ b/arrakis/.default_config.txt @@ -3,8 +3,6 @@ # host: # Host of mongodb. # username: # Username of mongodb. # password: # Password of mongodb. -port: 8787 # Port to run Dask dashboard on. -# port_forward: # Platform to fowards dask port [None]. # dask_config: # Config file for Dask SlurmCLUSTER. # holofile: yanda: "1.3.0" diff --git a/arrakis/cleanup.py b/arrakis/cleanup.py index e15d3431..7e8cc601 100644 --- a/arrakis/cleanup.py +++ b/arrakis/cleanup.py @@ -6,35 +6,38 @@ from pathlib import Path from typing import List, Union -from dask import delayed from dask.distributed import Client, LocalCluster +from prefect import flow, task, unmapped from arrakis.logger import logger -from arrakis.utils.pipeline import chunk_dask, logo_str +from arrakis.utils.pipeline import logo_str logger.setLevel(logging.INFO) -@delayed -def cleanup(workdir: str, stoke: str) -> None: +@task(name="Cleanup directory") +def cleanup(workdir: str, stokeslist: List[str]) -> None: """Clean up beam images Args: workdir (str): Directory containing images stoke (str): Stokes parameter """ - # Clean up beam images - # old_files = glob(f"{workdir}/*.cutout.*.{stoke.lower()}.*beam[00-36]*.fits") - # for old in old_files: - # os.remove(old) + if os.path.basename(workdir) == "slurmFiles": + return + for stoke in stokeslist: + # Clean up beam images + # old_files = glob(f"{workdir}/*.cutout.*.{stoke.lower()}.*beam[00-36]*.fits") + # for old in old_files: + # os.remove(old) - pass + ... 
+@flow(name="Cleanup") def main( datadir: Path, stokeslist: Union[List[str], None] = None, - verbose=True, ) -> None: """Clean up beam images @@ -55,20 +58,9 @@ def main( if os.path.isdir(os.path.join(cutdir, name)) ] ) - - outputs = [] - for file in files: - if os.path.basename(file) == "slurmFiles": - continue - for stoke in stokeslist: - output = cleanup(file, stoke) - outputs.append(output) - - futures = chunk_dask( - outputs=outputs, - task_name="cleanup", - progress_text="Running cleanup", - verbose=verbose, + outputs = cleanup.map( + workdir=files, + stokeslist=unmapped(stokeslist), ) logger.info("Cleanup done!") @@ -107,14 +99,8 @@ def cli(): if verbose: logger.setLevel(logging.DEBUG) - cluster = LocalCluster(n_workers=20) - client = Client(cluster) - main(datadir=Path(args.outdir), stokeslist=None, verbose=verbose) - client.close() - cluster.close() - if __name__ == "__main__": cli() diff --git a/arrakis/configs/default.yaml b/arrakis/configs/default.yaml index ca734acc..7e064aee 100644 --- a/arrakis/configs/default.yaml +++ b/arrakis/configs/default.yaml @@ -1,25 +1,13 @@ -# Set up for Magnus -cores: 24 -processes: 12 -name: 'spice-worker' -memory: "60GB" -project: 'ja3' -queue: 'workq' -n_workers: 1000 -walltime: '6:00:00' -job_extra: ['-M magnus'] -# interface for the workers -interface: "ipogif0" -log_directory: 'spice_logs' -env_extra: [ - 'export OMP_NUM_THREADS=1', - 'source /home/$(whoami)/.bashrc', - 'conda activate spice' -] -python: 'srun -n 1 -c 24 python' -extra: [ - "--lifetime", "11h", - "--lifetime-stagger", "5m", -] -death_timeout: 300 -local_directory: '/dev/shm' +# Set up for local mahine +cluster_class: "distributed.LocalCluster" +cluster_kwargs: + cores: 1 + processes: 1 + name: 'spice-worker' + memory: "8GB" +adapt_kwargs: + minimum: 1 + maximum: 8 + wait_count: 20 + target_duration: "300s" + interval: "30s" diff --git a/arrakis/configs/galaxy.yaml b/arrakis/configs/galaxy.yaml deleted file mode 100644 index 1077a2ab..00000000 --- 
a/arrakis/configs/galaxy.yaml +++ /dev/null @@ -1,25 +0,0 @@ -# Set up for Galaxy -cores: 20 -processes: 20 -name: 'spice-worker' -memory: "60GB" -project: 'askaprt' -queue: 'workq' -walltime: '6:00:00' -n_workers: 1000 -job_extra: ['-M galaxy'] -# interface for the workers -interface: "ipogif0" -log_directory: 'spice_logs' -env_extra: [ - 'export OMP_NUM_THREADS=1', - 'source /home/$(whoami)/.bashrc', - 'conda activate spice' -] -python: 'srun -n 1 -c 20 python' -extra: [ - "--lifetime", "5h", - "--lifetime-stagger", "5m", -] -death_timeout: 300 -local_directory: '/dev/shm' diff --git a/arrakis/configs/magnus.yaml b/arrakis/configs/magnus.yaml deleted file mode 100644 index ca734acc..00000000 --- a/arrakis/configs/magnus.yaml +++ /dev/null @@ -1,25 +0,0 @@ -# Set up for Magnus -cores: 24 -processes: 12 -name: 'spice-worker' -memory: "60GB" -project: 'ja3' -queue: 'workq' -n_workers: 1000 -walltime: '6:00:00' -job_extra: ['-M magnus'] -# interface for the workers -interface: "ipogif0" -log_directory: 'spice_logs' -env_extra: [ - 'export OMP_NUM_THREADS=1', - 'source /home/$(whoami)/.bashrc', - 'conda activate spice' -] -python: 'srun -n 1 -c 24 python' -extra: [ - "--lifetime", "11h", - "--lifetime-stagger", "5m", -] -death_timeout: 300 -local_directory: '/dev/shm' diff --git a/arrakis/configs/petrichor.yaml b/arrakis/configs/petrichor.yaml index 6f4a51d1..9f79823a 100644 --- a/arrakis/configs/petrichor.yaml +++ b/arrakis/configs/petrichor.yaml @@ -1,28 +1,26 @@ # Set up for Petrichor -cores: 8 -processes: 8 -name: 'spice-worker' -memory: "64GiB" -account: 'OD-217087' -#queue: 'workq' -walltime: '0-8:00:00' -job_extra_directives: ['--qos express'] -# interface for the workers -interface: "ib0" -log_directory: 'spice_logs' -job_script_prologue: [ - 'module load singularity', -] -# job_script_prologue: [ -# 'export OMP_NUM_THREADS=1', -# 'source /home/$(whoami)/.bashrc', -# 'conda activate spice' -# ] -# python: 'srun -n 1 -c 64 python' -#worker_extra_args: [ -# 
"--lifetime", "23h", -# "--lifetime-stagger", "5m", -#] -death_timeout: 1000 -local_directory: $LOCALDIR -silence_logs: 'info' +cluster_class: "dask_jobqueue.SLURMCluster" +cluster_kwargs: + cores: 8 + processes: 8 + name: 'spice-worker' + memory: "64GiB" + account: 'OD-217087' + #queue: 'workq' + walltime: '0-01:00:00' + job_extra_directives: ['--qos express'] + # interface for the workers + interface: "ib0" + log_directory: 'spice_logs' + job_script_prologue: [ + 'module load singularity', + 'unset SINGULARITY_BINDPATH' + ] + local_directory: $LOCALDIR + silence_logs: 'info' +adapt_kwargs: + minimum: 1 + maximum: 36 + wait_count: 20 + target_duration: "300s" + interval: "30s" diff --git a/arrakis/cutout.py b/arrakis/cutout.py index 7c0bc9bf..92b4a231 100644 --- a/arrakis/cutout.py +++ b/arrakis/cutout.py @@ -7,20 +7,20 @@ from glob import glob from pprint import pformat from shutil import copyfile -from typing import Dict, List, Union +from typing import Dict, List +from typing import NamedTuple as Struct +from typing import Optional, TypeVar, Union import astropy.units as u import numpy as np import pymongo -from astropy import units as u from astropy.coordinates import Latitude, Longitude, SkyCoord from astropy.io import fits from astropy.utils import iers from astropy.utils.exceptions import AstropyWarning from astropy.wcs.utils import skycoord_to_pixel -from dask import delayed from dask.distributed import Client, LocalCluster -from distributed import get_client +from prefect import flow, task, unmapped from spectral_cube import SpectralCube from spectral_cube.utils import SpectralCubeWarning @@ -28,7 +28,7 @@ from arrakis.utils.database import get_db, test_db from arrakis.utils.fitsutils import fix_header from arrakis.utils.io import try_mkdir -from arrakis.utils.pipeline import chunk_dask, logo_str, tqdm_dask +from arrakis.utils.pipeline import logo_str iers.conf.auto_download = False warnings.filterwarnings( @@ -41,21 +41,37 @@ 
logger.setLevel(logging.INFO) - -@delayed +T = TypeVar("T") + + +class CutoutArgs(Struct): + """Arguments for cutout function""" + + image: str + """Name of the image file""" + source_id: str + """Name of the source""" + ra_high: float + """Upper RA bound in degrees""" + ra_low: float + """Lower RA bound in degrees""" + dec_high: float + """Upper DEC bound in degrees""" + dec_low: float + """Lower DEC bound in degrees""" + outdir: str + """Output directory""" + beam: int + """Beam number""" + stoke: str + """Stokes parameter""" + + +@task(name="Cutout island") def cutout( - image: str, - src_name: str, - beam: int, - ra_hi: float, - ra_lo: float, - dec_hi: float, - dec_lo: float, - outdir: str, - stoke: str, + cutout_args: CutoutArgs, field: str, pad=3, - verbose=False, dryrun=False, ) -> List[pymongo.UpdateOne]: """Perform a cutout. @@ -64,10 +80,10 @@ def cutout( image (str): Name of the image file src_name (str): Name of the RACS source beam (int): Beam number - ra_hi (float): Upper RA bound - ra_lo (float): Lower RA bound - dec_hi (float): Upper DEC bound - dec_lo (float): Lower DEC bound + ra_high (float): Upper RA bound + ra_low (float): Lower RA bound + dec_high (float): Upper DEC bound + dec_low (float): Lower DEC bound outdir (str): Output directgory stoke (str): Stokes parameter field (str): RACS field name @@ -79,23 +95,19 @@ def cutout( pymongo.UpdateOne: Update query for MongoDB """ logger.setLevel(logging.INFO) - # logger = logging.getLogger('distributed.worker') - # logger = get_run_logger() - - logger.info(f"Timwashere - {image=}") - outdir = os.path.abspath(outdir) + outdir = os.path.abspath(cutout_args.outdir) ret = [] for imtype in ["image", "weight"]: - basename = os.path.basename(image) - outname = f"{src_name}.cutout.{basename}" + basename = os.path.basename(cutout_args.image) + outname = f"{cutout_args.source_id}.cutout.{basename}" outfile = os.path.join(outdir, outname) if imtype == "weight": - image = image.replace("image.restored", 
"weights.restored").replace( - ".fits", ".txt" - ) + image = cutout_args.image.replace( + "image.restored", "weights.restored" + ).replace(".fits", ".txt") outfile = outfile.replace("image.restored", "weights.restored").replace( ".fits", ".txt" ) @@ -103,16 +115,16 @@ def cutout( logger.info(f"Written to {outfile}") if imtype == "image": - logger.info(f"Reading {image}") + logger.info(f"Reading {cutout_args.image}") with warnings.catch_warnings(): warnings.simplefilter("ignore", AstropyWarning) - cube = SpectralCube.read(image) + cube = SpectralCube.read(cutout_args.image) padder = cube.header["BMAJ"] * u.deg * pad - xlo = Longitude(ra_lo * u.deg) - Longitude(padder) - xhi = Longitude(ra_hi * u.deg) + Longitude(padder) - ylo = Latitude(dec_lo * u.deg) - Latitude(padder) - yhi = Latitude(dec_hi * u.deg) + Latitude(padder) + xlo = Longitude(cutout_args.ra_low * u.deg) - Longitude(padder) + xhi = Longitude(cutout_args.ra_high * u.deg) + Longitude(padder) + ylo = Latitude(cutout_args.dec_low * u.deg) - Latitude(padder) + yhi = Latitude(cutout_args.dec_high * u.deg) + Latitude(padder) xp_lo, yp_lo = skycoord_to_pixel(SkyCoord(xlo, ylo), cube.wcs) xp_hi, yp_hi = skycoord_to_pixel(SkyCoord(xhi, yhi), cube.wcs) @@ -128,7 +140,9 @@ def cutout( new_header = cutout_cube.header with warnings.catch_warnings(): warnings.simplefilter("ignore", AstropyWarning) - with fits.open(image, memmap=True, mode="denywrite") as hdulist: + with fits.open( + cutout_args.image, memmap=True, mode="denywrite" + ) as hdulist: data = hdulist[0].data old_header = hdulist[0].header @@ -140,7 +154,7 @@ def cutout( ] fixed_header = fix_header(new_header, old_header) # Add source name to header for CASDA - fixed_header["OBJECT"] = src_name + fixed_header["OBJECT"] = cutout_args.source_id if not dryrun: fits.writeto( outfile, @@ -152,13 +166,15 @@ def cutout( logger.info(f"Written to {outfile}") # Update database - myquery = {"Source_ID": src_name} + myquery = {"Source_ID": cutout_args.source_id} 
filename = os.path.join( os.path.basename(os.path.dirname(outfile)), os.path.basename(outfile) ) newvalues = { - "$set": {f"beams.{field}.{stoke}_beam{beam}_{imtype}_file": filename} + "$set": { + f"beams.{field}.{cutout_args.stoke}_beam{cutout_args.beam}_{imtype}_file": filename + } } ret += [pymongo.UpdateOne(myquery, newvalues, upsert=True)] @@ -166,7 +182,7 @@ def cutout( return ret -@delayed +@task(name="Get cutout arguments") def get_args( island: Dict, comps: List[Dict], @@ -177,7 +193,7 @@ def get_args( datadir: str, stokeslist: List[str], verbose=True, -) -> List[Dict]: +) -> Union[List[CutoutArgs], None]: """Get arguments for cutout function Args: @@ -196,7 +212,7 @@ def get_args( Exception: Problems with coordinates Returns: - List[Dict]: List of cutout arguments for cutout function + List[CutoutArgs]: List of cutout arguments for cutout function """ logger.setLevel(logging.INFO) @@ -204,6 +220,10 @@ def get_args( assert island["Source_ID"] == island_id assert beam["Source_ID"] == island_id + if len(comps) == 0: + logger.warning(f"Skipping island {island_id} -- no components found") + return None + beam_list = list(set(beam["beams"][field]["beam_list"])) outdir = f"{outdir}/{island['Source_ID']}" @@ -226,28 +246,28 @@ def get_args( ra_max = np.max(coords.ra) ra_i_max = np.argmax(coords.ra) ra_off = Longitude(majs[ra_i_max]) - ra_hi = ra_max + ra_off + ra_high = ra_max + ra_off ra_min = np.min(coords.ra) ra_i_min = np.argmin(coords.ra) ra_off = Longitude(majs[ra_i_min]) - ra_lo = ra_min - ra_off + ra_low = ra_min - ra_off dec_max = np.max(coords.dec) dec_i_max = np.argmax(coords.dec) dec_off = Longitude(majs[dec_i_max]) - dec_hi = dec_max + dec_off + dec_high = dec_max + dec_off dec_min = np.min(coords.dec) dec_i_min = np.argmin(coords.dec) dec_off = Longitude(majs[dec_i_min]) - dec_lo = dec_min - dec_off + dec_low = dec_min - dec_off except Exception as e: logger.debug(f"coords are {coords=}") logger.debug(f"comps are {comps=}") raise e - args = [] + 
args: List[CutoutArgs] = [] for beam_num in beam_list: for stoke in stokeslist: wild = f"{datadir}/image.restored.{stoke.lower()}*contcube*beam{beam_num:02}.conv.fits" @@ -260,25 +280,23 @@ def get_args( ) for image in images: - args.extend( - [ - { - "image": image, - "id": island["Source_ID"], - "ra_hi": ra_hi.deg, - "ra_lo": ra_lo.deg, - "dec_hi": dec_hi.deg, - "dec_lo": dec_lo.deg, - "outdir": outdir, - "beam": beam_num, - "stoke": stoke.lower(), - } - ] + args.append( + CutoutArgs( + image=image, + source_id=island["Source_ID"], + ra_high=ra_high.deg, + ra_low=ra_low.deg, + dec_high=dec_high.deg, + dec_low=dec_low.deg, + outdir=outdir, + beam=beam_num, + stoke=stoke.lower(), + ) ) return args -@delayed +@task(name="Find components") def find_comps(island_id: str, comp_col: pymongo.collection.Collection) -> List[Dict]: """Find components for a given island @@ -293,34 +311,44 @@ def find_comps(island_id: str, comp_col: pymongo.collection.Collection) -> List[ return comps -@delayed -def unpack(list_sq: List[List[Dict]]) -> List[Dict]: - """Unpack list of lists +@task(name="Unpack list") +def unpack(list_sq: List[Union[List[T], None, T]]) -> List[T]: + """Unpack list of lists of things into a list of things + Skips None entries Args: - list_sq (List[List[Dict]]): List of lists of dicts + list_sq (List[List[T] | None]): List of lists of things or Nones Returns: - List[Dict]: List of dicts + List[T]: List of things """ - list_fl = [] + logger.debug(f"{list_sq=}") + list_fl: List[T] = [] for i in list_sq: - for j in i: - list_fl.append(j) + if i is None: + continue + elif isinstance(i, list): + list_fl.extend(i) + continue + else: + list_fl.append(i) + logger.debug(f"{list_fl=}") return list_fl +@flow(name="Cutout islands") def cutout_islands( field: str, directory: str, host: str, epoch: int, - username: Union[str, None] = None, - password: Union[str, None] = None, + username: Optional[str] = None, + password: Optional[str] = None, pad: float = 3, - stokeslist: 
Union[List[str], None] = None, + stokeslist: Optional[List[str]] = None, verbose_worker: bool = False, dryrun: bool = True, + limit: Optional[int] = None, ) -> None: """Perform cutouts of RACS islands in parallel. @@ -328,7 +356,6 @@ def cutout_islands( field (str): RACS field name. directory (str): Directory to store cutouts. host (str): MongoDB host. - client (Client): Dask client. username (str, optional): Mongo username. Defaults to None. password (str, optional): Mongo password. Defaults to None. verbose (bool, optional): Verbose output. Defaults to True. @@ -339,8 +366,7 @@ def cutout_islands( """ if stokeslist is None: stokeslist = ["I", "Q", "U", "V"] - client = get_client() - logger.debug(f"Client is {client}") + directory = os.path.abspath(directory) outdir = os.path.join(directory, "cutouts") @@ -377,55 +403,36 @@ def cutout_islands( # Create output dir if it doesn't exist try_mkdir(outdir) - args = [] - for island_id, island, comp, beam in zip(island_ids, islands, comps, beams): - if len(comp) == 0: - warnings.warn(f"Skipping island {island_id} -- no components found") - continue - else: - arg = get_args( - island, - comp, - beam, - island_id, - outdir, - field, - directory, - stokeslist, - verbose=verbose_worker, - ) - args.append(arg) - - flat_args = unpack(args) - flat_args = client.compute(flat_args) - tqdm_dask(flat_args, desc="Getting args", total=len(islands) + 1) - flat_args = flat_args.result() - cuts = [] - for arg in flat_args: - cut = cutout( - image=arg["image"], - src_name=arg["id"], - beam=arg["beam"], - ra_hi=arg["ra_hi"], - ra_lo=arg["ra_lo"], - dec_hi=arg["dec_hi"], - dec_lo=arg["dec_lo"], - outdir=arg["outdir"], - stoke=arg["stoke"], - field=field, - pad=pad, - verbose=verbose_worker, - dryrun=dryrun, - ) - cuts.append(cut) - - futures = chunk_dask( - outputs=cuts, - task_name="cutouts", - progress_text="Cutting out", + if limit is not None: + logger.critical(f"Limiting to {limit} islands") + islands = islands[:limit] + island_ids = 
island_ids[:limit] + comps = comps[:limit] + beams = beams[:limit] + + args = get_args.map( + island=islands, + comps=comps, + beam=beams, + island_id=island_ids, + outdir=unmapped(outdir), + field=unmapped(field), + datadir=unmapped(directory), + stokeslist=unmapped(stokeslist), + verbose=unmapped(verbose_worker), + ) + # args = [a.result() for a in args] + # flat_args = unpack.map(args) + flat_args = unpack.submit(args) + cuts = cutout.map( + cutout_args=flat_args, + field=unmapped(field), + pad=unmapped(pad), + dryrun=unmapped(dryrun), ) + if not dryrun: - _updates = [f.compute() for f in futures] + _updates = [f.result() for f in cuts] updates = [val for sublist in _updates for val in sublist] logger.info("Updating database...") db_res = beams_col.bulk_write(updates, ordered=False) @@ -434,7 +441,7 @@ def cutout_islands( logger.info("Cutouts Done!") -def main(args: argparse.Namespace, verbose=True) -> None: +def main(args: argparse.Namespace) -> None: """Main script Args: @@ -452,6 +459,7 @@ def main(args: argparse.Namespace, verbose=True) -> None: stokeslist=args.stokeslist, verbose_worker=args.verbose_worker, dryrun=args.dryrun, + limit=args.limit, ) logger.info("Done!") @@ -541,6 +549,12 @@ def cutout_parser(parent_parser: bool = False) -> argparse.ArgumentParser: type=str, help="List of Stokes parameters to image [ALL]", ) + parser.add_argument( + "--limit", + type=Optional[int], + default=None, + help="Limit number of islands to process [None]", + ) return cut_parser @@ -555,12 +569,6 @@ def cli() -> None: if verbose: logger.setLevel(logging.INFO) - cluster = LocalCluster( - n_workers=12, threads_per_worker=1, dashboard_address=":9898" - ) - client = Client(cluster) - logger.info(client) - test_db( host=args.host, username=args.username, diff --git a/arrakis/frion.py b/arrakis/frion.py index 5d185f25..392fc08f 100644 --- a/arrakis/frion.py +++ b/arrakis/frion.py @@ -2,34 +2,41 @@ """Correct for the ionosphere in parallel""" import logging import os -import 
time from glob import glob from pathlib import Path from pprint import pformat from shutil import copyfile -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Dict, List +from typing import NamedTuple as Struct +from typing import Optional, Union import astropy.units as u -import dask import numpy as np import pymongo from astropy.time import Time, TimeDelta -from dask import delayed from dask.distributed import Client, LocalCluster from FRion import correct, predict +from prefect import flow, task, unmapped from arrakis.logger import logger from arrakis.utils.database import get_db, get_field_db, test_db from arrakis.utils.fitsutils import getfreq from arrakis.utils.io import try_mkdir -from arrakis.utils.pipeline import logo_str, tqdm_dask +from arrakis.utils.pipeline import logo_str logger.setLevel(logging.INFO) -@delayed +class Prediction(Struct): + """FRion prediction""" + + predict_file: str + update: pymongo.UpdateOne + + +@task(name="FRion correction") def correct_worker( - beam: Dict, outdir: str, field: str, predict_file: str, island_id: str + beam: Dict, outdir: str, field: str, prediction: Prediction, island: dict ) -> pymongo.UpdateOne: """Apply FRion corrections to a single island @@ -43,6 +50,8 @@ def correct_worker( Returns: pymongo.UpdateOne: Pymongo update query """ + predict_file = prediction.predict_file + island_id = island["Source_ID"] qfile = os.path.join(outdir, beam["beams"][field]["q_file"]) ufile = os.path.join(outdir, beam["beams"][field]["u_file"]) @@ -67,7 +76,7 @@ def correct_worker( return pymongo.UpdateOne(myquery, newvalues) -@delayed(nout=2) +@task(name="FRion predction") def predict_worker( island: Dict, field: str, @@ -82,7 +91,7 @@ def predict_worker( formatter: Optional[Union[str, Callable]] = None, proxy_server: Optional[str] = None, pre_download: bool = False, -) -> Tuple[str, pymongo.UpdateOne]: +) -> Prediction: """Make FRion prediction for a single island Args: @@ -135,11 +144,15 @@ 
def predict_worker( predict_file = os.path.join(i_dir, f"{iname}_ion.txt") predict.write_modulation(freq_array=freq, theta=theta, filename=predict_file) - plot_file = os.path.join(i_dir, f"{iname}_ion.pdf") - predict.generate_plots( - times, RMs, theta, freq, position=[ra, dec], savename=plot_file - ) - plot_files = glob(os.path.join(i_dir, "*ion.pdf")) + plot_file = os.path.join(i_dir, f"{iname}_ion.png") + try: + predict.generate_plots( + times, RMs, theta, freq, position=[ra, dec], savename=plot_file + ) + except Exception as e: + logger.error(f"Failed to generate plot: {e}") + + plot_files = glob(os.path.join(i_dir, "*ion.png")) logger.info(f"Plotting files: {plot_files=}") for src in plot_files: base = os.path.basename(src) @@ -163,9 +176,18 @@ def predict_worker( update = pymongo.UpdateOne(myquery, newvalues) - return predict_file, update + return Prediction(predict_file, update) + +@task(name="Index beams") +def index_beams(island: dict, beams: List[dict]) -> dict: + island_id = island["Source_ID"] + beam_idx = [i for i, b in enumerate(beams) if b["Source_ID"] == island_id][0] + beam = beams[beam_idx] + return beam + +@flow(name="FRion") def main( field: str, outdir: Path, @@ -174,12 +196,12 @@ def main( username: Optional[str] = None, password: Optional[str] = None, database=False, - verbose=True, ionex_server: str = "ftp://ftp.aiub.unibe.ch/CODE/", ionex_prefix: str = "codg", ionex_proxy_server: Optional[str] = None, ionex_formatter: Optional[Union[str, Callable]] = "ftp.aiub.unibe.ch", ionex_predownload: bool = False, + limit: Optional[int] = None, ): """Main script @@ -238,61 +260,52 @@ def main( freq = getfreq( os.path.join(cutdir, f"{beams[0]['beams'][f'{field}']['q_file']}"), - ) # Type: u.Quantity + ) - # Loop over islands in parallel - outputs = [] - updates_arrays = [] + if limit is not None: + logger.info(f"Limiting to {limit} islands") + islands = islands[:limit] + + beams_cor = [] for island in islands: island_id = island["Source_ID"] beam_idx = 
[i for i, b in enumerate(beams) if b["Source_ID"] == island_id][0] beam = beams[beam_idx] - # Get FRion predictions - predict_file, update = predict_worker( - island=island, - field=field, - beam=beam, - start_time=start_time, - end_time=end_time, - freq=freq.to(u.Hz).value, - cutdir=cutdir, - plotdir=plotdir, - server=ionex_server, - prefix=ionex_prefix, - proxy_server=ionex_proxy_server, - formatter=ionex_formatter, - pre_download=ionex_predownload, - ) - updates_arrays.append(update) - # Apply FRion predictions - output = correct_worker( - beam=beam, - outdir=cutdir, - field=field, - predict_file=predict_file, - island_id=island_id, - ) - outputs.append(output) - # Wait for IONEX data I guess... - _ = outputs[0].compute() - time.sleep(10) - # Execute - futures, future_arrays = dask.persist(outputs, updates_arrays) - # dumb solution for https://github.com/dask/distributed/issues/4831 - time.sleep(10) - tqdm_dask( - futures, desc="Running FRion", disable=(not verbose), total=len(islands) * 3 + beams_cor.append(beam) + + predictions = predict_worker.map( + island=islands, + field=unmapped(field), + beam=beams_cor, + start_time=unmapped(start_time), + end_time=unmapped(end_time), + freq=unmapped(freq.to(u.Hz).value), + cutdir=unmapped(cutdir), + plotdir=unmapped(plotdir), + server=unmapped(ionex_server), + prefix=unmapped(ionex_prefix), + proxy_server=unmapped(ionex_proxy_server), + formatter=unmapped(ionex_formatter), + pre_download=unmapped(ionex_predownload), + ) + + corrections = correct_worker.map( + beam=beams_cor, + outdir=unmapped(cutdir), + field=unmapped(field), + prediction=predictions, + island=islands, ) + + updates_arrays = [p.result().update for p in predictions] + updates = [c.result() for c in corrections] if database: logger.info("Updating beams database...") - updates = [f.compute() for f in futures] db_res = beams_col.bulk_write(updates, ordered=False) logger.info(pformat(db_res.bulk_api_result)) logger.info("Updating island database...") - 
updates_arrays_cmp = [f.compute() for f in future_arrays] - - db_res = island_col.bulk_write(updates_arrays_cmp, ordered=False) + db_res = island_col.bulk_write(updates_arrays, ordered=False) logger.info(pformat(db_res.bulk_api_result)) @@ -405,12 +418,6 @@ def cli(): if verbose: logger.setLevel(logging.INFO) - cluster = LocalCluster( - n_workers=10, processes=True, threads_per_worker=1, local_directory="/dev/shm" - ) - client = Client(cluster) - logger.info(client) - test_db(host=args.host, username=args.username, password=args.password) main( diff --git a/arrakis/imager.py b/arrakis/imager.py index c7cb1f87..06997354 100644 --- a/arrakis/imager.py +++ b/arrakis/imager.py @@ -914,48 +914,34 @@ def imager_parser(parent_parser: bool = False) -> argparse.ArgumentParser: def cli(): """Command-line interface""" parser = imager_parser() - args = parser.parse_args() - - if args.mpi: - initialize(interface="ipogif0") - cluster = None - - else: - cluster = LocalCluster( - threads_per_worker=1, - ) - - with Client(cluster) as client: - logger.debug(f"{cluster=}") - logger.debug(f"{client=}") - main( - msdir=args.msdir, - out_dir=args.outdir, - cutoff=args.psf_cutoff, - robust=args.robust, - pols=args.pols, - nchan=args.nchan, - size=args.size, - scale=args.scale, - mgain=args.mgain, - niter=args.niter, - auto_mask=args.auto_mask, - force_mask_rounds=args.force_mask_rounds, - auto_threshold=args.auto_threshold, - minuv=args.minuv, - purge=args.purge, - taper=args.taper, - parallel_deconvolution=args.parallel, - gridder=args.gridder, - wsclean_path=Path(args.local_wsclean) - if args.local_wsclean - else args.hosted_wsclean, - multiscale=args.multiscale, - ms_glob_pattern=args.ms_glob_pattern, - data_column=args.data_column, - skip_fix_ms=args.skip_fix_ms, - ) + main( + msdir=args.msdir, + out_dir=args.outdir, + cutoff=args.psf_cutoff, + robust=args.robust, + pols=args.pols, + nchan=args.nchan, + size=args.size, + scale=args.scale, + mgain=args.mgain, + niter=args.niter, + 
auto_mask=args.auto_mask, + force_mask_rounds=args.force_mask_rounds, + auto_threshold=args.auto_threshold, + minuv=args.minuv, + purge=args.purge, + taper=args.taper, + parallel_deconvolution=args.parallel, + gridder=args.gridder, + wsclean_path=Path(args.local_wsclean) + if args.local_wsclean + else args.hosted_wsclean, + multiscale=args.multiscale, + ms_glob_pattern=args.ms_glob_pattern, + data_column=args.data_column, + skip_fix_ms=args.skip_fix_ms, + ) if __name__ == "__main__": diff --git a/arrakis/linmos.py b/arrakis/linmos.py index 1c0ca380..e68cd110 100644 --- a/arrakis/linmos.py +++ b/arrakis/linmos.py @@ -7,19 +7,20 @@ from glob import glob from pathlib import Path from pprint import pformat -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List +from typing import NamedTuple as Struct +from typing import Optional import pymongo from astropy.utils.exceptions import AstropyWarning -from dask import delayed from dask.distributed import Client, LocalCluster +from prefect import flow, task, unmapped from racs_tools import beamcon_3D from spectral_cube.utils import SpectralCubeWarning from spython.main import Client as sclient from arrakis.logger import logger from arrakis.utils.database import get_db, test_db -from arrakis.utils.pipeline import chunk_dask warnings.filterwarnings(action="ignore", category=SpectralCubeWarning, append=True) warnings.simplefilter("ignore", category=AstropyWarning) @@ -29,19 +30,26 @@ logger.setLevel(logging.INFO) -@delayed(nout=2) +class ImagePaths(Struct): + """Class to hold image paths""" + + images: List[Path] + """List of image paths""" + weights: List[Path] + """List of weight paths""" + + +@task(name="Find images") def find_images( field: str, - src_name: str, beams: dict, stoke: str, datadir: Path, -) -> Tuple[List[Path], List[Path]]: +) -> ImagePaths: """Find the images and weights for a given field and stokes parameter Args: field (str): Field name. - src_name (str): Source name. 
beams (dict): Beam information. stoke (str): Stokes parameter. datadir (Path): Data directory. @@ -50,9 +58,11 @@ def find_images( Exception: If no files are found. Returns: - Tuple[List[Path], List[Path]]: List of images and weights. + ImagePaths: List of images and weights. """ logger.setLevel(logging.INFO) + + src_name = beams["Source_ID"] field_beams = beams["beams"][field] # First check that the images exist @@ -82,29 +92,29 @@ def find_images( im.parent.name == wt.parent.name ), "Image and weight are in different areas!" - return image_list, weight_list + return ImagePaths(image_list, weight_list) -@delayed +@task(name="Smooth images") def smooth_images( - image_dict: Dict[str, List[Path]], -) -> Dict[str, List[Path]]: + image_dict: Dict[str, ImagePaths], +) -> Dict[str, ImagePaths]: """Smooth cubelets to a common resolution Args: - image_list (List[Path]): List of cubelets to smooth. + image_list (ImagePaths): List of cubelets to smooth. Returns: - List[Path]: Smoothed cubelets. + ImagePaths: Smoothed cubelets. 
""" - smooth_dict: Dict[str, List[Path]] = {} + smooth_dict: Dict[str, ImagePaths] = {} for stoke, image_list in image_dict.items(): infiles: List[str] = [] - for im in image_list: + for im in image_list.images: if im.suffix == ".fits": infiles.append(im.resolve().as_posix()) datadict = beamcon_3D.main( - infile=[im.resolve().as_posix() for im in image_list], + infile=[im.resolve().as_posix() for im in image_list.images], uselogs=False, mode="total", conv_mode="robust", @@ -113,15 +123,14 @@ def smooth_images( smooth_files: List[Path] = [] for key, val in datadict.items(): smooth_files.append(Path(val["outfile"])) - smooth_dict[stoke] = smooth_files + smooth_dict[stoke] = ImagePaths(smooth_files, image_list.weights) return smooth_dict -@delayed +@task(name="Generate parset") def genparset( - image_list: List[Path], - weight_list: List[Path], + image_paths: ImagePaths, stoke: str, datadir: Path, holofile: Optional[Path] = None, @@ -129,12 +138,10 @@ def genparset( """Generate parset for LINMOS Args: - field (str): RACS field name. - src_name (str): RACE source name. - beams (dict): Mongo entry for RACS beams. + image_paths (ImagePaths): List of images and weights. stoke (str): Stokes parameter. - datadir (str): Data directory. - holofile (str): Full path to holography file. + datadir (Path): Data directory. + holofile (Path, optional): Path to the holography file to include in the bind list. Defaults to None. Raises: Exception: If no files are found. 
@@ -144,17 +151,13 @@ def genparset( """ logger.setLevel(logging.INFO) - image_string = ( - f"[{','.join([im.resolve().with_suffix('').as_posix() for im in image_list])}]" - ) - weight_string = ( - f"[{','.join([im.resolve().with_suffix('').as_posix() for im in weight_list])}]" - ) + image_string = f"[{','.join([im.resolve().with_suffix('').as_posix() for im in image_paths.images])}]" + weight_string = f"[{','.join([im.resolve().with_suffix('').as_posix() for im in image_paths.weights])}]" - parset_dir = datadir.resolve() / image_list[0].parent.name + parset_dir = datadir.resolve() / image_paths.images[0].parent.name - first_image = image_list[0].resolve().with_suffix("").as_posix() - first_weight = weight_list[0].resolve().with_suffix("").as_posix() + first_image = image_paths.images[0].resolve().with_suffix("").as_posix() + first_weight = image_paths.weights[0].resolve().with_suffix("").as_posix() linmos_image_str = f"{first_image[:first_image.find('beam')]}linmos" linmos_weight_str = f"{first_weight[:first_weight.find('beam')]}linmos" @@ -187,17 +190,17 @@ def genparset( return parset_file -@delayed +@task(name="Run linmos") def linmos( - parset: str, fieldname: str, image: str, holofile: Path -) -> pymongo.UpdateOne: + parset: Optional[str], fieldname: str, image: str, holofile: Path +) -> Optional[pymongo.UpdateOne]: """Run linmos Args: parset (str): Path to parset file. fieldname (str): Name of RACS field. image (str): Name of Yandasoft image. - holofile (Union[Path,str]): Path to the holography file to include in the bind list. + holofile (Path): Path to the holography file to include in the bind list. verbose (bool, optional): Verbose output. Defaults to False. 
Raises: @@ -209,9 +212,11 @@ def linmos( """ logger.setLevel(logging.INFO) + if parset is None: + return + workdir = os.path.dirname(parset) rootdir = os.path.split(workdir)[0] - junk = os.path.split(workdir)[-1] parset_name = os.path.basename(parset) source = os.path.basename(workdir) stoke = parset_name[parset_name.find(".in") - 1] @@ -230,7 +235,6 @@ def linmos( outstr = "\n".join(output["message"]) with open(log_file, "w") as f: f.write(outstr) - # f.write(output['message']) if output["return_code"] != 0: raise Exception(f"LINMOS failed! Check '{log_file}'") @@ -267,6 +271,7 @@ def get_yanda(version="1.3.0") -> str: return image +@flow(name="LINMOS") def main( field: str, datadir: Path, @@ -277,8 +282,8 @@ def main( password: Optional[str] = None, yanda: str = "1.3.0", yanda_img: Optional[Path] = None, - stokeslist: Union[List[str], None] = None, - verbose=True, + stokeslist: Optional[List[str]] = None, + limit: Optional[int] = None, ) -> None: """Main script @@ -313,15 +318,14 @@ def main( logger.info(f"The query is {query=}") - island_ids = sorted(beams_col.distinct("Source_ID", query)) - big_beams = list( + island_ids: List[str] = sorted(beams_col.distinct("Source_ID", query)) + big_beams: List[dict] = list( beams_col.find({"Source_ID": {"$in": island_ids}}).sort("Source_ID") ) - # files = sorted([name for name in glob(f"{cutdir}/*") if os.path.isdir(os.path.join(cutdir, name))]) - big_comps = list( + big_comps: List[dict] = list( comp_col.find({"Source_ID": {"$in": island_ids}}).sort("Source_ID") ) - comps = [] + comps: List[List[dict]] = [] for island_id in island_ids: _comps = [] for c in big_comps: @@ -331,49 +335,36 @@ def main( assert len(big_beams) == len(comps) - parfiles = [] - for beams, comp in zip(big_beams, comps): - src = beams["Source_ID"] - if len(comp) == 0: - warnings.warn(f"Skipping island {src} -- no components found") - continue - else: - image_dict: Dict[str, List[Path]] = {} - weight_dict: Dict[str, List[Path]] = {} - for stoke in 
stokeslist: - image_list, weight_list = find_images( - field=field, - src_name=src, - beams=beams, - stoke=stoke.capitalize(), - datadir=cutdir, - ) - image_dict[stoke] = image_list - weight_dict[stoke] = weight_list - - smooth_dict = smooth_images(image_dict) - for stoke in stokeslist: - parfile = genparset( - image_list=smooth_dict[stoke], - weight_list=weight_dict[stoke], - stoke=stoke.capitalize(), - datadir=cutdir, - holofile=holofile, - ) - parfiles.append(parfile) - - results = [] - for parset in parfiles: - results.append(linmos(parset, field, str(image), holofile=holofile)) - - futures = chunk_dask( - outputs=results, - task_name="LINMOS", - progress_text="Runing LINMOS", - verbose=verbose, + if limit is not None: + logger.critical(f"Limiting to {limit} islands") + big_beams = big_beams[:limit] + comps = comps[:limit] + + all_parfiles = [] + for stoke in stokeslist: + image_paths = find_images.map( + field=unmapped(field), + beams=big_beams, + stoke=unmapped(stoke.capitalize()), + datadir=unmapped(cutdir), + ) + parfiles = genparset.map( + image_paths=image_paths, + stoke=unmapped(stoke.capitalize()), + datadir=unmapped(cutdir), + holofile=unmapped(holofile), + ) + all_parfiles.extend(parfiles) + + results = linmos.map( + all_parfiles, + unmapped(field), + unmapped(str(image)), + unmapped(holofile), ) - updates = [f.compute() for f in futures] + updates = [f.result() for f in results] + updates = [u for u in updates if u is not None] logger.info("Updating database...") db_res = beams_col.bulk_write(updates, ordered=False) logger.info(pformat(db_res.bulk_api_result)) @@ -458,12 +449,15 @@ def cli(): parser.add_argument( "--password", type=str, default=None, help="Password of mongodb." 
    )
+    parser.add_argument(
+        "--limit",
+        type=int,
+        default=None,
+        help="Limit the number of islands to process.",
+    )
 
     args = parser.parse_args()
 
-    cluster = LocalCluster(n_workers=1)
-    client = Client(cluster)
-
     verbose = args.verbose
     test_db(
         host=args.host, username=args.username, password=args.password, verbose=verbose
@@ -480,7 +474,7 @@ def cli():
         yanda=args.yanda,
         yanda_img=args.yanda_image,
         stokeslist=args.stokeslist,
-        verbose=verbose,
+        limit=args.limit,
     )
 
 
diff --git a/arrakis/makecat.py b/arrakis/makecat.py
index 9ecae793..1283d627 100644
--- a/arrakis/makecat.py
+++ b/arrakis/makecat.py
@@ -17,6 +17,7 @@
 from astropy.stats import sigma_clip
 from astropy.table import Column, Table
 from dask.diagnostics import ProgressBar
+from prefect import flow, task
 from rmtable import RMTable
 from scipy.stats import lognorm, norm
 from tqdm import tqdm
@@ -218,6 +219,7 @@ def lognorm_from_percentiles(x1, p1, x2, p2):
     return scale, np.exp(mean)
 
 
+@task(name="Fix sigma_add")
 def sigma_add_fix(tab):
     sigma_Q_low = np.array(tab["sigma_add_Q"] - tab["sigma_add_Q_err_minus"])
     sigma_Q_high = np.array(tab["sigma_add_Q"] + tab["sigma_add_Q_err_plus"])
@@ -292,7 +294,7 @@ def get_fit_func(
     degree: int = 2,
     do_plot: bool = False,
     high_snr_cut: float = 30.0,
-):
+) -> Tuple[np.polynomial.Polynomial, plt.Figure]:
     """Fit an envelope to define leakage sources
 
     Args:
@@ -313,6 +315,15 @@ def get_fit_func(
 
     logger.info(f"{np.sum(hi_snr)} sources with Stokes I SNR above {high_snr_cut=}.")
 
+    if len(hi_i_tab) < 100:
+        logger.critical("Not enough high SNR sources to fit leakage envelope.")
+        return (
+            np.polynomial.Polynomial.fit(
+                np.array([0, 1]), np.array([0, 0]), deg=0, full=False
+            ),
+            plt.figure(),
+        )
+
     # Get fractional pol
     frac_P = np.array(hi_i_tab["fracpol"].value)
     # Bin sources by separation from tile centre
@@ -340,7 +351,6 @@ def get_fit_func(
 
     # Plot the fit
     latexify(columns=2)
-    figure = plt.figure(facecolor="w")
     fig = plt.figure(facecolor="w")
     color = 
"tab:green" stoke = { @@ -360,17 +370,6 @@ def get_fit_func( zorder=0, rasterized=True, ) - - # is_finite = np.logical_and( - # np.isfinite(hi_i_tab["beamdist"].to(u.deg).value), np.isfinite(frac_P) - # ) - # hist2d( - # np.array(hi_i_tab["beamdist"].to(u.deg).value)[is_finite, np.newaxis], - # np.array(frac_P)[is_finite, np.newaxis], - # bins=(nbins, nbins), - # range=[[0, 5], [0, 0.05]], - # # color=color, - # ) plt.plot(bins_c, meds, alpha=1, c=color, label="Median", linewidth=2) for s, ls in zip((1, 2), ("--", ":")): for r in ("ups", "los"): @@ -406,6 +405,15 @@ def compute_local_rm_flag(good_cat: Table, big_cat: Table) -> Table: logger.info("Computing voronoi bins and finding bad RMs") logger.info(f"Number of available sources: {len(good_cat)}.") + df = good_cat.to_pandas() + df.reset_index(inplace=True) + df.set_index("cat_id", inplace=True) + + df_out = big_cat.to_pandas() + df_out.reset_index(inplace=True) + df_out.set_index("cat_id", inplace=True) + df_out["local_rm_flag"] = False + def sn_func(index, signal=None, noise=None): try: sn = len(np.array(index)) @@ -415,6 +423,7 @@ def sn_func(index, signal=None, noise=None): target_sn = 30 target_bins = 6 + fail = True while target_sn > 1: logger.debug( f"Trying to find Voroni bins with RMs per bin={target_sn}, Number of bins={target_bins}" @@ -464,32 +473,27 @@ def sn_func(index, signal=None, noise=None): fail_msg = "Failed to converge towards a Voronoi binning solution. 
" logger.error(fail_msg) - raise ValueError(fail_msg) + fail = True - logger.info(f"Found {len(set(bin_number))} bins") - df = good_cat.to_pandas() - df.reset_index(inplace=True) - df.set_index("cat_id", inplace=True) - df["bin_number"] = bin_number - # Use sigma clipping to find outliers + if not fail: + logger.info(f"Found {len(set(bin_number))} bins") + df["bin_number"] = bin_number - def masker(x): - return pd.Series( - sigma_clip(x["rm"], sigma=3, maxiters=None, cenfunc=np.median).mask, - index=x.index, + # Use sigma clipping to find outliers + def masker(x): + return pd.Series( + sigma_clip(x["rm"], sigma=3, maxiters=None, cenfunc=np.median).mask, + index=x.index, + ) + + perc_g = df.groupby("bin_number").apply( + masker, ) + # Put flag into the catalogue + df["local_rm_flag"] = perc_g.reset_index().set_index("cat_id")[0] + df.drop(columns=["bin_number"], inplace=True) + df_out.update(df["local_rm_flag"]) - perc_g = df.groupby("bin_number").apply( - masker, - ) - # Put flag into the catalogue - df["local_rm_flag"] = perc_g.reset_index().set_index("cat_id")[0] - df.drop(columns=["bin_number"], inplace=True) - df_out = big_cat.to_pandas() - df_out.reset_index(inplace=True) - df_out.set_index("cat_id", inplace=True) - df_out["local_rm_flag"] = False - df_out.update(df["local_rm_flag"]) df_out["local_rm_flag"] = df_out["local_rm_flag"].astype(bool) cat_out = RMTable.from_pandas(df_out.reset_index()) cat_out["local_rm_flag"].meta["ucd"] = "meta.code" @@ -507,6 +511,7 @@ def masker(x): return cat_out +@task(name="Add cuts and flags") def cuts_and_flags(cat: RMTable) -> RMTable: """Cut out bad sources, and add flag columns @@ -564,6 +569,7 @@ def cuts_and_flags(cat: RMTable) -> RMTable: return cat_out, fit +@task(name="Get spectral indices") def get_alpha(cat): coefs_str = cat["stokesI_model_coef"] coefs_err_str = cat["stokesI_model_coef_err"] @@ -592,6 +598,7 @@ def get_alpha(cat): ) +@task(name="Get integration times") def get_integration_time(cat, field_col): 
logger.warn("Will be stripping the trailing field character prefix. ") field_names = [ @@ -620,31 +627,6 @@ def get_integration_time(cat, field_col): return np.array(tints) * u.s -# Stolen from GASKAP pipeline -# Credit to J. Dempsey -# https://github.com/GASKAP/GASKAP-HI-Absorption-Pipeline/ -# https://github.com/GASKAP/GASKAP-HI-Absorption-Pipeline/blob/ -# def add_col_metadata(vo_table, col_name, description, units=None, ucd=None, datatype=None): -# """Add metadata to a VO table column. - -# Args: -# vo_table (vot.): VO Table -# col_name (str): Column name -# description (str): Long description of the column -# units (u.Unit, optional): Unit of column. Defaults to None. -# ucd (str, optional): UCD string. Defaults to None. -# datatype (_type_, optional): _description_. Defaults to None. -# """ -# col = vo_table.get_first_table().get_field_by_id(col_name) -# col.description = description -# if units: -# col.unit = units -# if ucd: -# col.ucd = ucd -# if datatype: -# col.datatype = datatype - - def add_metadata(vo_table: vot.tree.Table, filename: str): """Add metadata to VO Table for CASDA @@ -732,6 +714,7 @@ def fix_blank_units(rmtab: TableLike) -> TableLike: return rmtab +@task(name="Write votable") def write_votable(rmtab: TableLike, outfile: str) -> None: # Replace bad column names fix_columns = { @@ -752,6 +735,7 @@ def write_votable(rmtab: TableLike, outfile: str) -> None: replace_nans(outfile) +@flow(name="Make catalogue") def main( field: str, host: str, @@ -862,10 +846,6 @@ def main( alpha_dict = get_alpha(rmtab) rmtab.add_column(Column(data=alpha_dict["alphas"], name="spectral_index")) rmtab.add_column(Column(data=alpha_dict["alphas_err"], name="spectral_index_err")) - # rmtab.add_column(Column(data=alpha_dict["betas"], name="spectral_curvature")) - # rmtab.add_column( - # Column(data=alpha_dict["betas_err"], name="spectral_curvature_err") - # ) # Add integration time field_col = get_field_db( diff --git a/arrakis/merge_fields.py 
b/arrakis/merge_fields.py index 0e4786da..675ca6c4 100644 --- a/arrakis/merge_fields.py +++ b/arrakis/merge_fields.py @@ -3,17 +3,16 @@ import os from pprint import pformat from shutil import copyfile -from typing import Dict, List, Union +from typing import Dict, List, Optional import pymongo -from dask import delayed from dask.distributed import Client, LocalCluster +from prefect import flow, task, unmapped from arrakis.linmos import get_yanda, linmos from arrakis.logger import logger from arrakis.utils.database import get_db, test_db from arrakis.utils.io import try_mkdir -from arrakis.utils.pipeline import chunk_dask def make_short_name(name: str) -> str: @@ -23,49 +22,73 @@ def make_short_name(name: str) -> str: return short -@delayed +@task(name="Copy singleton island") def copy_singleton( - beam: dict, vals: dict, merge_name: str, field_dir: str, data_dir: str -) -> pymongo.UpdateOne: - try: - i_file_old = os.path.join(field_dir, vals["i_file"]) - q_file_old = os.path.join(field_dir, vals["q_file_ion"]) - u_file_old = os.path.join(field_dir, vals["u_file_ion"]) - except KeyError: - raise KeyError("Ion files not found. 
Have you run FRion?") - new_dir = os.path.join(data_dir, beam["Source_ID"]) + beam: dict, field_dict: Dict[str, str], merge_name: str, data_dir: str +) -> List[pymongo.UpdateOne]: + """Copy an island within a single field to the merged field - try_mkdir(new_dir, verbose=False) + Args: + beam (dict): Beam document + field_dict (Dict[str, str]): Field dictionary + merge_name (str): Merged field name + data_dir (str): Output directory - i_file_new = os.path.join(new_dir, os.path.basename(i_file_old)).replace( - ".fits", ".edge.linmos.fits" - ) - q_file_new = os.path.join(new_dir, os.path.basename(q_file_old)).replace( - ".fits", ".edge.linmos.fits" - ) - u_file_new = os.path.join(new_dir, os.path.basename(u_file_old)).replace( - ".fits", ".edge.linmos.fits" - ) + Raises: + KeyError: If ion files not found - for src, dst in zip( - [i_file_old, q_file_old, u_file_old], [i_file_new, q_file_new, u_file_new] - ): - copyfile(src, dst) - src_weight = src.replace(".image.restored.", ".weights.").replace(".ion", "") - dst_weight = dst.replace(".image.restored.", ".weights.").replace(".ion", "") - copyfile(src_weight, dst_weight) - - query = {"Source_ID": beam["Source_ID"]} - newvalues = { - "$set": { - f"beams.{merge_name}.i_file": make_short_name(i_file_new), - f"beams.{merge_name}.q_file": make_short_name(q_file_new), - f"beams.{merge_name}.u_file": make_short_name(u_file_new), - f"beams.{merge_name}.DR1": True, + Returns: + List[pymongo.UpdateOne]: Database updates + """ + updates = [] + for field, vals in beam["beams"].items(): + if field not in field_dict.keys(): + continue + field_dir = field_dict[field] + try: + i_file_old = os.path.join(field_dir, vals["i_file"]) + q_file_old = os.path.join(field_dir, vals["q_file_ion"]) + u_file_old = os.path.join(field_dir, vals["u_file_ion"]) + except KeyError: + raise KeyError("Ion files not found. 
Have you run FRion?") + new_dir = os.path.join(data_dir, beam["Source_ID"]) + + try_mkdir(new_dir, verbose=False) + + i_file_new = os.path.join(new_dir, os.path.basename(i_file_old)).replace( + ".fits", ".edge.linmos.fits" + ) + q_file_new = os.path.join(new_dir, os.path.basename(q_file_old)).replace( + ".fits", ".edge.linmos.fits" + ) + u_file_new = os.path.join(new_dir, os.path.basename(u_file_old)).replace( + ".fits", ".edge.linmos.fits" + ) + + for src, dst in zip( + [i_file_old, q_file_old, u_file_old], [i_file_new, q_file_new, u_file_new] + ): + copyfile(src, dst) + src_weight = src.replace(".image.restored.", ".weights.").replace( + ".ion", "" + ) + dst_weight = dst.replace(".image.restored.", ".weights.").replace( + ".ion", "" + ) + copyfile(src_weight, dst_weight) + + query = {"Source_ID": beam["Source_ID"]} + newvalues = { + "$set": { + f"beams.{merge_name}.i_file": make_short_name(i_file_new), + f"beams.{merge_name}.q_file": make_short_name(q_file_new), + f"beams.{merge_name}.u_file": make_short_name(u_file_new), + f"beams.{merge_name}.DR1": True, + } } - } - return pymongo.UpdateOne(query, newvalues) + updates.append(pymongo.UpdateOne(query, newvalues)) + return updates def copy_singletons( @@ -73,7 +96,18 @@ def copy_singletons( data_dir: str, beams_col: pymongo.collection.Collection, merge_name: str, -) -> list: +) -> List[pymongo.UpdateOne]: + """Copy islands that don't overlap other fields + + Args: + field_dict (Dict[str, str]): Field dictionary + data_dir (str): Data directory + beams_col (pymongo.collection.Collection): Beams collection + merge_name (str): Merged field name + + Returns: + List[pymongo.UpdateOne]: Database updates + """ # Find all islands with the given fields that DON'T overlap another field query = { "$or": [ @@ -92,18 +126,15 @@ def copy_singletons( big_beams = list( beams_col.find({"Source_ID": {"$in": island_ids}}).sort("Source_ID") ) - updates = [] - for beam in big_beams: - for field, vals in beam["beams"].items(): - if 
field not in field_dict.keys(): - continue - field_dir = field_dict[field] - update = copy_singleton(beam, vals, merge_name, field_dir, data_dir) - updates.append(update) + updates = copy_singleton.map( + beam=big_beams, + field_dict=unmapped(field_dict), + merge_name=unmapped(merge_name), + data_dir=unmapped(data_dir), + ) return updates -@delayed def genparset( old_ims: list, stokes: str, @@ -139,10 +170,24 @@ def genparset( return parset_file -# @delayed(nout=3) def merge_multiple_field( beam: dict, field_dict: dict, merge_name: str, data_dir: str, image: str -) -> list: +) -> List[pymongo.UpdateOne]: + """Merge an island that overlaps multiple fields + + Args: + beam (dict): Beam document + field_dict (dict): Field dictionary + merge_name (str): Merged field name + data_dir (str): Data directory + image (str): Yandasoft image + + Raises: + KeyError: If ion files not found + + Returns: + List[pymongo.UpdateOne]: Database updates + """ i_files_old = [] q_files_old = [] u_files_old = [] @@ -167,19 +212,32 @@ def merge_multiple_field( for stokes, imlist in zip(["I", "Q", "U"], [i_files_old, q_files_old, u_files_old]): parset_file = genparset(imlist, stokes, new_dir) - update = linmos(parset_file, merge_name, image) + update = linmos.fn(parset_file, merge_name, image) updates.append(update) return updates +@task(name="Merge multiple islands") def merge_multiple_fields( field_dict: Dict[str, str], data_dir: str, beams_col: pymongo.collection.Collection, merge_name: str, image: str, -) -> list: +) -> List[pymongo.UpdateOne]: + """Merge multiple islands that overlap multiple fields + + Args: + field_dict (Dict[str, str]): Field dictionary + data_dir (str): Data directory + beams_col (pymongo.collection.Collection): Beams collection + merge_name (str): Merged field name + image (str): Yandasoft image + + Returns: + List[pymongo.UpdateOne]: Database updates + """ # Find all islands with the given fields that overlap another field query = { "$or": [ @@ -199,14 +257,18 @@ 
def merge_multiple_fields( beams_col.find({"Source_ID": {"$in": island_ids}}).sort("Source_ID") ) - updates = [] - for beam in big_beams: - update = merge_multiple_field(beam, field_dict, merge_name, data_dir, image) - updates.extend(update) + updates = merge_multiple_field.map( + beam=big_beams, + field_dict=unmapped(field_dict), + merge_name=unmapped(merge_name), + data_dir=unmapped(data_dir), + image=unmapped(image), + ) return updates +@flow(name="Merge fields") def main( fields: List[str], field_dirs: List[str], @@ -214,10 +276,9 @@ def main( output_dir: str, host: str, epoch: int, - username: Union[str, None] = None, - password: Union[str, None] = None, + username: Optional[str] = None, + password: Optional[str] = None, yanda="1.3.0", - verbose: bool = True, ) -> str: logger.debug(f"{fields=}") @@ -257,21 +318,8 @@ def main( image=image, ) - singleton_futures = chunk_dask( - outputs=singleton_updates, - task_name="singleton islands", - progress_text="Copying singleton islands", - verbose=verbose, - ) - singleton_comp = [f.compute() for f in singleton_futures] - - multiple_futures = chunk_dask( - outputs=mutilple_updates, - task_name="overlapping islands", - progress_text="Running LINMOS on overlapping islands", - verbose=verbose, - ) - multiple_comp = [f.compute() for f in multiple_futures] + singleton_comp = [f.result() for f in singleton_updates] + multiple_comp = [f.result() for f in mutilple_updates] for m in multiple_comp: m._doc["$set"].update({f"beams.{merge_name}.DR1": True}) @@ -345,10 +393,6 @@ def cli(): help="Epoch of observation.", ) - parser.add_argument( - "-v", dest="verbose", action="store_true", help="Verbose output [False]." - ) - parser.add_argument( "--username", type=str, default=None, help="Username of mongodb." 
) @@ -358,10 +402,6 @@ def cli(): ) args = parser.parse_args() - - cluster = LocalCluster() - client = Client(cluster) - verbose = args.verbose test_db( host=args.host, username=args.username, password=args.password, verbose=verbose @@ -377,12 +417,8 @@ def cli(): username=args.username, password=args.password, yanda=args.yanda, - verbose=verbose, ) - client.close() - cluster.close() - if __name__ == "__main__": cli() diff --git a/arrakis/process_region.py b/arrakis/process_region.py index aacea0f5..768620ad 100644 --- a/arrakis/process_region.py +++ b/arrakis/process_region.py @@ -6,22 +6,16 @@ import pkg_resources import yaml from astropy.time import Time -from dask.distributed import Client -from dask_jobqueue import SLURMCluster -from dask_mpi import initialize -from prefect import flow, task -from prefect_dask import DaskTaskRunner +from prefect import flow -from arrakis import merge_fields, process_spice +from arrakis import makecat, merge_fields, process_spice, rmclean_oncuts, rmsynth_oncuts from arrakis.logger import logger from arrakis.utils.database import test_db -from arrakis.utils.pipeline import logo_str, port_forward - -merge_task = task(merge_fields.main, name="Merge fields") +from arrakis.utils.pipeline import logo_str @flow -def process_merge(args, host: str, inter_dir: str) -> None: +def process_merge(args, host: str, inter_dir: str, task_runner) -> None: """Workflow to merge spectra from overlapping fields together Args: @@ -31,7 +25,7 @@ def process_merge(args, host: str, inter_dir: str) -> None: """ previous_future = None previous_future = ( - merge_task.submit( + merge_fields.with_options(task_runner=task_runner)( fields=args.fields, field_dirs=args.datadirs, merge_name=args.merge_name, @@ -41,14 +35,13 @@ def process_merge(args, host: str, inter_dir: str) -> None: username=args.username, password=args.password, yanda=args.yanda, - verbose=args.verbose, ) if not args.skip_merge else previous_future ) previous_future = ( - 
process_spice.rmsynth_task.submit( + rmsynth_oncuts.main.with_options(task_runner=task_runner)( field=args.merge_name, outdir=inter_dir, host=host, @@ -58,7 +51,7 @@ def process_merge(args, host: str, inter_dir: str) -> None: dimension=args.dimension, verbose=args.verbose, database=args.database, - validate=args.validate, + do_validate=args.validate, limit=args.limit, savePlots=args.savePlots, weightType=args.weightType, @@ -77,14 +70,13 @@ def process_merge(args, host: str, inter_dir: str) -> None: tt1=args.tt1, ion=False, do_own_fit=args.do_own_fit, - wait_for=[previous_future], ) if not args.skip_rmsynth else previous_future ) previous_future = ( - process_spice.rmclean_task.submit( + rmclean_oncuts.main.with_options(task_runner=task_runner)( field=args.merge_name, outdir=inter_dir, host=host, @@ -92,9 +84,7 @@ def process_merge(args, host: str, inter_dir: str) -> None: username=args.username, password=args.password, dimension=args.dimension, - verbose=args.verbose, database=args.database, - validate=args.validate, limit=args.limit, cutoff=args.cutoff, maxIter=args.maxIter, @@ -102,14 +92,13 @@ def process_merge(args, host: str, inter_dir: str) -> None: window=args.window, showPlots=args.showPlots, rm_verbose=args.rm_verbose, - wait_for=[previous_future], ) if not args.skip_rmclean else previous_future ) previous_future = ( - process_spice.cat_task.submit( + makecat.main.with_options(task_runner=task_runner)( field=args.merge_name, host=host, epoch=args.epoch, @@ -117,7 +106,6 @@ def process_merge(args, host: str, inter_dir: str) -> None: password=args.password, verbose=args.verbose, outfile=args.outfile, - wait_for=[previous_future], ) if not args.skip_cat else previous_future @@ -130,8 +118,6 @@ def main(args: configargparse.Namespace) -> None: Args: args (configargparse.Namespace): Command line arguments. 
""" - host = args.host - if args.dask_config is None: config_dir = pkg_resources.resource_filename("arrakis", "configs") args.dask_config = os.path.join(config_dir, "default.yaml") @@ -139,33 +125,6 @@ def main(args: configargparse.Namespace) -> None: if args.outfile is None: args.outfile = f"{args.merge_name}.pipe.test.fits" - # Following https://github.com/dask/dask-jobqueue/issues/499 - with open(args.dask_config) as f: - config = yaml.safe_load(f) - - config.update( - { - # 'scheduler_options': { - # "dashboard_address": f":{args.port}" - # }, - "log_directory": f"{args.merge_name}_{Time.now().fits}_spice_logs/" - } - ) - if args.use_mpi: - initialize( - interface=config["interface"], - local_directory=config["local_directory"], - nthreads=config["cores"] / config["processes"], - ) - client = Client() - else: - cluster = SLURMCluster( - **config, - ) - logger.debug(f"Submitted scripts will look like: \n {cluster.job_script()}") - - client = Client(cluster) - test_db( host=args.host, username=args.username, @@ -178,25 +137,15 @@ def main(args: configargparse.Namespace) -> None: with open(args_yaml_f, "w") as f: f.write(args_yaml) - port = client.scheduler_info()["services"]["dashboard"] - - # Forward ports - if args.port_forward is not None: - for p in args.port_forward: - port_forward(port, p) - - # Prin out Dask client info - logger.info(client.scheduler_info()["services"]) - - dask_runner = DaskTaskRunner(address=client.scheduler.address) + dask_runner = process_spice.create_dask_runner( + dask_config=args.dask_config, + ) inter_dir = os.path.join(os.path.abspath(args.output_dir), args.merge_name) process_merge.with_options( name=f"SPICE-RACS: {args.merge_name}", task_runner=dask_runner - )(args, host, inter_dir) - - client.close() + )(args, args.host, inter_dir, dask_runner) def cli(): diff --git a/arrakis/process_spice.py b/arrakis/process_spice.py index 81db0567..d9ebe382 100644 --- a/arrakis/process_spice.py +++ b/arrakis/process_spice.py @@ -3,17 +3,14 
@@ import logging import os from pathlib import Path -from typing import Any import configargparse import pkg_resources import yaml from astropy.time import Time -from dask.distributed import Client -from dask_jobqueue import SLURMCluster -from dask_mpi import initialize -from prefect import flow, task -from prefect_dask import DaskTaskRunner, get_dask_client +from prefect import flow +from prefect.task_runners import BaseTaskRunner +from prefect_dask import DaskTaskRunner from arrakis import ( cleanup, @@ -27,173 +24,158 @@ ) from arrakis.logger import logger from arrakis.utils.database import test_db -from arrakis.utils.pipeline import logo_str, performance_report_prefect - -# Defining tasks -cut_task = task(cutout.cutout_islands, name="Cutout") -linmos_task = task(linmos.main, name="LINMOS") -frion_task = task(frion.main, name="FRion") -cleanup_task = task(cleanup.main, name="Clean up") -rmsynth_task = task(rmsynth_oncuts.main, name="RM Synthesis") -rmclean_task = task(rmclean_oncuts.main, name="RM-CLEAN") -cat_task = task(makecat.main, name="Catalogue") +from arrakis.utils.pipeline import logo_str @flow(name="Combining+Synthesis on Arrakis") -def process_spice(args, host: str) -> None: +def process_spice(args, host: str, task_runner: BaseTaskRunner) -> None: """Workflow to process the SPIRCE-RACS data Args: args (configargparse.Namespace): Configuration parameters for this run host (str): Host address of the mongoDB. """ - # TODO: Fix the type assigned to args. The `configargparse.Namespace` was causing issues - # with the pydantic validation used by prefect / flow. 
- - with get_dask_client(): - previous_future = None - previous_future = ( - cut_task.submit( - field=args.field, - directory=args.outdir, - host=host, - epoch=args.epoch, - username=args.username, - password=args.password, - pad=args.pad, - stokeslist=["I", "Q", "U"], - verbose_worker=args.verbose_worker, - dryrun=args.dryrun, - ) - if not args.skip_cutout - else previous_future + outfile = f"{args.field}.pipe.test.fits" if args.outfile is None else args.outfile + + previous_future = None + previous_future = ( + cutout.cutout_islands.with_options( + task_runner=task_runner, + )( + field=args.field, + directory=str(args.outdir), + host=host, + epoch=args.epoch, + username=args.username, + password=args.password, + pad=args.pad, + stokeslist=["I", "Q", "U"], + verbose_worker=args.verbose_worker, + dryrun=args.dryrun, + limit=args.limit, ) - - previous_future = ( - linmos_task.submit( - field=args.field, - datadir=Path(args.outdir), - host=host, - epoch=args.epoch, - holofile=Path(args.holofile), - username=args.username, - password=args.password, - yanda=args.yanda, - yanda_img=args.yanda_image, - stokeslist=["I", "Q", "U"], - verbose=True, - wait_for=[previous_future], - ) - if not args.skip_linmos - else previous_future + if not args.skip_cutout + else previous_future + ) + + previous_future = ( + linmos.main.with_options( + task_runner=task_runner, + )( + field=args.field, + datadir=Path(args.outdir), + host=host, + epoch=args.epoch, + holofile=Path(args.holofile), + username=args.username, + password=args.password, + yanda=args.yanda, + yanda_img=args.yanda_image, + stokeslist=["I", "Q", "U"], + limit=args.limit, ) + if not args.skip_linmos + else previous_future + ) - previous_future = ( - cleanup_task.submit( - datadir=args.outdir, - stokeslist=["I", "Q", "U"], - verbose=True, - wait_for=[previous_future], - ) - if not args.skip_cleanup - else previous_future + previous_future = ( + cleanup.main.with_options(task_runner=task_runner)( + datadir=args.outdir, + 
stokeslist=["I", "Q", "U"], ) - - previous_future = ( - frion_task.submit( - field=args.field, - outdir=args.outdir, - host=host, - epoch=args.epoch, - username=args.username, - password=args.password, - database=args.database, - verbose=args.verbose, - ionex_server=args.ionex_server, - ionex_proxy_server=args.ionex_proxy_server, - ionex_formatter=args.ionex_formatter, - ionex_predownload=args.ionex_predownload, - wait_for=[previous_future], - ) - if not args.skip_frion - else previous_future + if not args.skip_cleanup + else previous_future + ) + + previous_future = ( + frion.main.with_options(task_runner=task_runner)( + field=args.field, + outdir=args.outdir, + host=host, + epoch=args.epoch, + username=args.username, + password=args.password, + database=args.database, + ionex_server=args.ionex_server, + ionex_proxy_server=args.ionex_proxy_server, + ionex_formatter=args.ionex_formatter, + ionex_predownload=args.ionex_predownload, + limit=args.limit, ) - - previous_future = ( - rmsynth_task.submit( - field=args.field, - outdir=args.outdir, - host=host, - epoch=args.epoch, - username=args.username, - password=args.password, - dimension=args.dimension, - verbose=args.verbose, - database=args.database, - validate=args.validate, - limit=args.limit, - savePlots=args.savePlots, - weightType=args.weightType, - fitRMSF=args.fitRMSF, - phiMax_radm2=args.phiMax_radm2, - dPhi_radm2=args.dPhi_radm2, - nSamples=args.nSamples, - polyOrd=args.polyOrd, - noStokesI=args.noStokesI, - showPlots=args.showPlots, - not_RMSF=args.not_RMSF, - rm_verbose=args.rm_verbose, - debug=args.debug, - fit_function=args.fit_function, - tt0=args.tt0, - tt1=args.tt1, - ion=True, - do_own_fit=args.do_own_fit, - wait_for=[previous_future], - ) - if not args.skip_rmsynth - else previous_future + if not args.skip_frion + else previous_future + ) + + previous_future = ( + rmsynth_oncuts.main.with_options(task_runner=task_runner)( + field=args.field, + outdir=args.outdir, + host=host, + epoch=args.epoch, + 
username=args.username, + password=args.password, + dimension=args.dimension, + verbose=args.verbose, + database=args.database, + do_validate=args.validate, + limit=args.limit, + savePlots=args.savePlots, + weightType=args.weightType, + fitRMSF=args.fitRMSF, + phiMax_radm2=args.phiMax_radm2, + dPhi_radm2=args.dPhi_radm2, + nSamples=args.nSamples, + polyOrd=args.polyOrd, + noStokesI=args.noStokesI, + showPlots=args.showPlots, + not_RMSF=args.not_RMSF, + rm_verbose=args.rm_verbose, + debug=args.debug, + fit_function=args.fit_function, + tt0=args.tt0, + tt1=args.tt1, + ion=True, + do_own_fit=args.do_own_fit, ) - - previous_future = ( - rmclean_task.submit( - field=args.field, - outdir=args.outdir, - host=host, - epoch=args.epoch, - username=args.username, - password=args.password, - dimension=args.dimension, - verbose=args.verbose, - database=args.database, - validate=args.validate, - limit=args.limit, - cutoff=args.cutoff, - maxIter=args.maxIter, - gain=args.gain, - window=args.window, - showPlots=args.showPlots, - rm_verbose=args.rm_verbose, - wait_for=[previous_future], - ) - if not args.skip_rmclean - else previous_future + if not args.skip_rmsynth + else previous_future + ) + + previous_future = ( + rmclean_oncuts.main.with_options(task_runner=task_runner)( + field=args.field, + outdir=args.outdir, + host=host, + epoch=args.epoch, + username=args.username, + password=args.password, + dimension=args.dimension, + database=args.database, + limit=args.limit, + cutoff=args.cutoff, + maxIter=args.maxIter, + gain=args.gain, + window=args.window, + showPlots=args.showPlots, + rm_verbose=args.rm_verbose, ) - - previous_future = ( - cat_task.submit( - field=args.field, - host=host, - epoch=args.epoch, - username=args.username, - password=args.password, - verbose=args.verbose, - outfile=args.outfile, - wait_for=[previous_future], - ) - if not args.skip_cat - else previous_future + if not args.skip_rmclean + else previous_future + ) + + previous_future = ( + 
makecat.main.with_options(task_runner=task_runner)( + field=args.field, + host=host, + epoch=args.epoch, + username=args.username, + password=args.password, + verbose=args.verbose, + outfile=outfile, ) + if not args.skip_cat + else previous_future + ) def save_args(args: configargparse.Namespace) -> Path: @@ -215,80 +197,46 @@ def save_args(args: configargparse.Namespace) -> Path: return Path(args_yaml_f) -def create_client( +def create_dask_runner( dask_config: str, - use_mpi: bool, - port_forward: Any, - minimum: int = 1, - maximum: int = 38, - mode: str = "adapt", -) -> Client: - logger.info("Creating a Client") + overload: bool = False, +) -> DaskTaskRunner: + """Create a DaskTaskRunner + + Args: + dask_config (str): Configuraiton file for the DaskTaskRunner + overload (bool, optional): Overload the options for threadded work. Defaults to False. + + Returns: + DaskTaskRunner: The prefect DaskTaskRunner instance + """ + logger.setLevel(logging.INFO) + logger.info("Creating a Dask Task Runner.") if dask_config is None: config_dir = pkg_resources.resource_filename("arrakis", "configs") dask_config = f"{config_dir}/default.yaml" - # Following https://github.com/dask/dask-jobqueue/issues/499 with open(dask_config) as f: logger.info(f"Loading {dask_config}") - config = yaml.safe_load(f) - - logger.info("Overwriting config attributes.") - config["job_cpu"] = config["cores"] - config["cores"] = 1 - config["processes"] = 1 - - # config.update( - # { - # "log_directory": f"{field}_{Time.now().fits}_spice_logs/" - # } - # ) - if use_mpi: - initialize( - interface=config["interface"], - local_directory=config["local_directory"], - nthreads=config["cores"] / config["processes"], - ) - client = Client() - else: - # TODO: load the cluster type and initialise it from field specified in the loaded config - cluster = SLURMCluster( - **config, - ) - logger.debug(f"Submitted scripts will look like: \n {cluster.job_script()}") - - if mode == "adapt": - 
cluster.adapt(minimum=minimum, maximum=maximum) - elif mode == "scale": - cluster.scale(maximum) - # cluster.scale(36) + yaml_config: dict = yaml.safe_load(f) - # cluster = LocalCluster(n_workers=10, processes=True, threads_per_worker=1, local_directory="/dev/shm",dashboard_address=f":{args.port}") - client = Client(cluster) - port = client.scheduler_info()["services"]["dashboard"] + cluster_class_str = yaml_config.get("cluster_class", "distributed.LocalCluster") + cluster_kwargs = yaml_config.get("cluster_kwargs", {}) + adapt_kwargs = yaml_config.get("adapt_kwargs", {}) - # Forward ports - if port_forward is not None: - for p in port_forward: - port_forward(port, p) + if overload: + logger.info("Overwriting config attributes.") + cluster_kwargs["job_cpu"] = cluster_kwargs["cores"] + cluster_kwargs["cores"] = 1 + cluster_kwargs["processes"] = 1 - # Prin out Dask client info - logger.info(client.scheduler_info()["services"]) + config = { + "cluster_class": cluster_class_str, + "cluster_kwargs": cluster_kwargs, + "adapt_kwargs": adapt_kwargs, + } - return client - - -def create_dask_runner(*args, **kwargs) -> DaskTaskRunner: - """Internally creates a Client object via `create_client`, - and then initialises a DaskTaskRunner. - - Returns: - DaskTaskRunner: A Prefect dask based task runner - """ - client = create_client(*args, **kwargs) - - logger.info("Creating DaskTaskRunner") - return DaskTaskRunner(address=client.scheduler.address), client + return DaskTaskRunner(**config) def main(args: configargparse.Namespace) -> None: @@ -310,59 +258,52 @@ def main(args: configargparse.Namespace) -> None: password=args.password, ) - if args.outfile is None: - outfile = f"{args.field}.pipe.test.fits" - if not args.skip_imager: # This is the client for the imager component of the arrakis # pipeline. 
- dask_runner, client = create_dask_runner( + dask_runner = create_dask_runner( dask_config=args.imager_dask_config, - use_mpi=args.use_mpi, - port_forward=args.port_forward, - minimum=1, - maximum=38, + overload=True, ) logger.info("Obtained DaskTaskRunner, executing the imager workflow. ") - with performance_report_prefect( - f"arrakis-imaging-{args.field}-report-{Time.now().fits}.html" - ): - imager.main.with_options( - name=f"Arrakis Imaging -- {args.field}", task_runner=dask_runner - )( - msdir=args.msdir, - out_dir=args.outdir, - cutoff=args.psf_cutoff, - robust=args.robust, - pols=args.pols, - nchan=args.nchan, - local_rms=args.local_rms, - local_rms_window=args.local_rms_window, - size=args.size, - scale=args.scale, - mgain=args.mgain, - niter=args.niter, - nmiter=args.nmiter, - auto_mask=args.auto_mask, - force_mask_rounds=args.force_mask_rounds, - auto_threshold=args.auto_threshold, - minuv=args.minuv, - purge=args.purge, - taper=args.taper, - parallel_deconvolution=args.parallel, - gridder=args.gridder, - wsclean_path=Path(args.local_wsclean) - if args.local_wsclean - else args.hosted_wsclean, - multiscale=args.multiscale, - multiscale_scale_bias=args.multiscale_scale_bias, - absmem=args.absmem, - ms_glob_pattern=args.ms_glob_pattern, - data_column=args.data_column, - ) + imager.main.with_options( + name=f"Arrakis Imaging -- {args.field}", task_runner=dask_runner + )( + msdir=args.msdir, + out_dir=args.outdir, + cutoff=args.psf_cutoff, + robust=args.robust, + pols=args.pols, + nchan=args.nchan, + local_rms=args.local_rms, + local_rms_window=args.local_rms_window, + size=args.size, + scale=args.scale, + mgain=args.mgain, + niter=args.niter, + nmiter=args.nmiter, + auto_mask=args.auto_mask, + force_mask_rounds=args.force_mask_rounds, + auto_threshold=args.auto_threshold, + minuv=args.minuv, + purge=args.purge, + taper=args.taper, + parallel_deconvolution=args.parallel, + gridder=args.gridder, + wsclean_path=Path(args.local_wsclean) + if args.local_wsclean 
+ else args.hosted_wsclean, + multiscale=args.multiscale, + multiscale_scale_bias=args.multiscale_scale_bias, + absmem=args.absmem, + ms_glob_pattern=args.ms_glob_pattern, + data_column=args.data_column, + ) + client = dask_runner._client + if client is not None: client.close() - del dask_runner + del dask_runner else: logger.warn("Skipping the image creation step. ") @@ -371,25 +312,14 @@ def main(args: configargparse.Namespace) -> None: return # This is the client and pipeline for the RM extraction - dask_runner_2, client = create_dask_runner( + dask_runner_2 = create_dask_runner( dask_config=args.dask_config, - use_mpi=args.use_mpi, - port_forward=args.port_forward, - minimum=64, - maximum=64, ) # Define flow - with performance_report_prefect( - f"arrakis-synthesis-{args.field}-report-{Time.now().fits}.html" - ): - process_spice.with_options( - name=f"Arrakis Synthesis -- {args.field}", task_runner=dask_runner_2 - )(args, host) - - # TODO: Access the client via the `dask_runner`. Perhaps a - # way to do this is to extend the DaskTaskRunner's - # destructor and have it create it then. + process_spice.with_options( + name=f"Arrakis Synthesis -- {args.field}", task_runner=dask_runner_2 + )(args, host, dask_runner_2) def cli(): @@ -442,24 +372,6 @@ def cli(): "--password", type=str, default=None, help="Password of mongodb." ) - # parser.add_argument( - # '--port', - # type=int, - # default=9999, - # help="Port to run Dask dashboard on." - # ) - parser.add_argument( - "--use_mpi", - action="store_true", - help="Use Dask-mpi to parallelise -- must use srun/mpirun to assign resources.", - ) - parser.add_argument( - "--port_forward", - default=None, - help="Platform to fowards dask port [None].", - nargs="+", - ) - parser.add_argument( "--dask_config", type=str, @@ -722,8 +634,7 @@ def cli(): "--outfile", default=None, type=str, help="File to save table to [None]." 
) args = parser.parse_args() - if not args.use_mpi: - parser.print_values() + parser.print_values() verbose = args.verbose if verbose: diff --git a/arrakis/rmclean_oncuts.py b/arrakis/rmclean_oncuts.py index f8502177..0f8e5173 100644 --- a/arrakis/rmclean_oncuts.py +++ b/arrakis/rmclean_oncuts.py @@ -7,23 +7,22 @@ from pathlib import Path from pprint import pformat from shutil import copyfile -from typing import Union +from typing import Optional import matplotlib.pyplot as plt import numpy as np import pymongo -from dask import delayed from dask.distributed import Client, LocalCluster +from prefect import flow, task, unmapped from RMtools_1D import do_RMclean_1D from RMtools_3D import do_RMclean_3D -from tqdm import tqdm from arrakis.logger import logger from arrakis.utils.database import get_db, test_db -from arrakis.utils.pipeline import chunk_dask, logo_str +from arrakis.utils.pipeline import logo_str -@delayed +@task(name="1D RM-CLEAN") def rmclean1d( comp: dict, outdir: str, @@ -135,7 +134,7 @@ def rmclean1d( return pymongo.UpdateOne(myquery, newvalues) -@delayed +@task(name="3D RM-CLEAN") def rmclean3d( island: dict, outdir: str, @@ -203,19 +202,18 @@ def rmclean3d( return pymongo.UpdateOne(myquery, newvalues) +@flow(name="RM-CLEAN on cutouts") def main( field: str, outdir: Path, host: str, epoch: int, - username: Union[str, None] = None, - password: Union[str, None] = None, + username: Optional[str] = None, + password: Optional[str] = None, dimension="1d", - verbose=True, database=False, savePlots=True, - validate=False, - limit: Union[int, None] = None, + limit: Optional[int] = None, cutoff: float = -3, maxIter=10000, gain=0.1, @@ -296,52 +294,37 @@ def main( n_island = count # component_ids = component_ids[:count] - outputs = [] if dimension == "1d": logger.info(f"Running RM-CLEAN on {n_comp} components") - for i, comp in enumerate(tqdm(components, total=n_comp)): - if i > n_comp + 1: - break - else: - output = rmclean1d( - comp=comp, - outdir=outdir, - 
cutoff=cutoff, - maxIter=maxIter, - gain=gain, - showPlots=showPlots, - savePlots=savePlots, - rm_verbose=rm_verbose, - window=window, - ) - outputs.append(output) - + outputs = rmclean1d.map( + comp=components, + outdir=unmapped(outdir), + cutoff=unmapped(cutoff), + maxIter=unmapped(maxIter), + gain=unmapped(gain), + showPlots=unmapped(showPlots), + savePlots=unmapped(savePlots), + rm_verbose=unmapped(rm_verbose), + window=unmapped(window), + ) elif dimension == "3d": logger.info(f"Running RM-CLEAN on {n_island} islands") - for i, island in enumerate(islands): - if i > n_island + 1: - break - else: - output = rmclean3d( - island=island, - outdir=outdir, - cutoff=cutoff, - maxIter=maxIter, - gain=gain, - rm_verbose=rm_verbose, - ) - outputs.append(output) - futures = chunk_dask( - outputs=outputs, - task_name="RM-CLEAN", - progress_text="Running RM-CLEAN", - verbose=verbose, - ) + outputs = rmclean3d.map( + island=islands, + outdir=unmapped(outdir), + cutoff=unmapped(cutoff), + maxIter=unmapped(maxIter), + gain=unmapped(gain), + rm_verbose=unmapped(rm_verbose), + ) + + else: + raise ValueError(f"Dimension {dimension} not supported.") if database: logger.info("Updating database...") - updates = [f.compute() for f in futures] + updates = [f.result() for f in outputs] if dimension == "1d": db_res = comp_col.bulk_write(updates, ordered=False) logger.info(pformat(db_res.bulk_api_result)) @@ -425,12 +408,6 @@ def cli(): parser.add_argument( "-sp", "--savePlots", action="store_true", help="save the plots [False]." 
) - parser.add_argument( - "--validate", - dest="validate", - action="store_true", - help="Run on Stokes I [False].", - ) parser.add_argument( "--limit", @@ -474,9 +451,6 @@ def cli(): args = parser.parse_args() - cluster = LocalCluster(n_workers=20) - client = Client(cluster) - verbose = args.verbose rmv = args.rm_verbose host = args.host @@ -499,10 +473,8 @@ def cli(): username=args.username, password=args.password, dimension=args.dimension, - verbose=verbose, database=args.database, savePlots=args.savePlots, - validate=args.validate, limit=args.limit, cutoff=args.cutoff, maxIter=args.maxIter, @@ -512,9 +484,6 @@ def cli(): rm_verbose=args.rm_verbose, ) - client.close() - cluster.close() - if __name__ == "__main__": cli() diff --git a/arrakis/rmsynth_oncuts.py b/arrakis/rmsynth_oncuts.py index 12de66f5..e2023d98 100644 --- a/arrakis/rmsynth_oncuts.py +++ b/arrakis/rmsynth_oncuts.py @@ -23,21 +23,20 @@ from astropy.stats import mad_std, sigma_clip from astropy.wcs import WCS from astropy.wcs.utils import proj_plane_pixel_scales -from dask import delayed from dask.distributed import Client, LocalCluster +from prefect import flow, task, unmapped from radio_beam import Beam from RMtools_1D import do_RMsynth_1D from RMtools_3D import do_RMsynth_3D from RMutils.util_misc import create_frac_spectra from scipy.stats import norm -from tqdm import tqdm from arrakis.logger import logger from arrakis.utils.database import get_db, test_db from arrakis.utils.fitsutils import getfreq from arrakis.utils.fitting import fit_pl from arrakis.utils.io import try_mkdir -from arrakis.utils.pipeline import chunk_dask, logo_str +from arrakis.utils.pipeline import logo_str logger.setLevel(logging.INFO) @@ -85,10 +84,10 @@ class StokesIFitResult(Struct): """The dictionary of the fit results""" -@delayed +@task(name="3D RM-synthesis") def rmsynthoncut3d( island_id: str, - beam: dict, + beams: pd.DataFrame, outdir: str, freq: np.ndarray, field: str, @@ -100,7 +99,7 @@ def rmsynthoncut3d( 
not_RMSF: bool = False, rm_verbose: bool = False, ion: bool = False, -): +) -> pymongo.UpdateOne: """3D RM-synthesis Args: @@ -117,7 +116,7 @@ def rmsynthoncut3d( not_RMSF (bool, optional): Skip calculation of RMSF. Defaults to False. rm_verbose (bool, optional): Verbose RMsynth. Defaults to False. """ - + beam = dict(beams.loc[island_id]) iname = island_id ifile = os.path.join(outdir, beam["beams"][field]["i_file"]) @@ -456,10 +455,10 @@ def update_rmtools_dict( mDict[f"fit_flag_{key}"] = val -@delayed +@task(name="1D RM-synthesis") def rmsynthoncut1d( - comp: dict, - beam: dict, + comp_tuple: Tuple[str, pd.Series], + beams: pd.DataFrame, outdir: str, freq: np.ndarray, field: str, @@ -502,6 +501,8 @@ def rmsynthoncut1d( rm_verbose (bool, optional): Verbose RMsynth. Defaults to False. """ logger.setLevel(logging.INFO) + comp = comp_tuple[1] + beam = dict(beams.loc[comp["Source_ID"]]) iname = comp["Source_ID"] cname = comp["Gaussian_ID"] @@ -693,7 +694,6 @@ def rmsynthoncut1d( return pymongo.UpdateOne(myquery, newvalues) -@delayed def rmsynthoncut_i( comp_id: str, outdir: str, @@ -830,6 +830,7 @@ def rmsynthoncut_i( do_RMsynth_1D.saveOutput(mDict, aDict, prefix, verbose=verbose) +@flow(name="RMsynth on cutouts") def main( field: str, outdir: Path, @@ -840,7 +841,7 @@ def main( dimension: str = "1d", verbose: bool = True, database: bool = False, - validate: bool = False, + do_validate: bool = False, limit: Union[int, None] = None, savePlots: bool = False, weightType: str = "variance", @@ -914,6 +915,7 @@ def main( n_island = limit island_ids = island_ids[:limit] component_ids = component_ids[:limit] + components = components.iloc[:limit] # Make frequency file freq, freqfile = getfreq( @@ -923,13 +925,11 @@ def main( ) freq = np.array(freq) - outputs = [] - - if validate: + if do_validate: logger.info(f"Running RMsynth on {n_comp} components") # We don't run this in parallel! 
for i, comp_id in enumerate(component_ids): - output = rmsynthoncut_i( + _ = rmsynthoncut_i( comp_id=comp_id, outdir=outdir, freq=freq, @@ -943,84 +943,56 @@ def main( verbose=verbose, rm_verbose=rm_verbose, ) - output.compute() elif dimension == "1d": logger.info(f"Running RMsynth on {n_comp} components") - for i, (_, comp) in tqdm( - enumerate(components.iterrows()), - total=n_comp, - disable=(not verbose), - desc="Constructing 1D RMsynth jobs", - ): - if i > n_comp + 1: - break - else: - beam = dict(beams.loc[comp["Source_ID"]]) - output = rmsynthoncut1d( - comp=comp, - beam=beam, - outdir=outdir, - freq=freq, - field=field, - polyOrd=polyOrd, - phiMax_radm2=phiMax_radm2, - dPhi_radm2=dPhi_radm2, - nSamples=nSamples, - weightType=weightType, - fitRMSF=fitRMSF, - noStokesI=noStokesI, - showPlots=showPlots, - savePlots=savePlots, - debug=debug, - rm_verbose=rm_verbose, - fit_function=fit_function, - tt0=tt0, - tt1=tt1, - ion=ion, - do_own_fit=do_own_fit, - ) - outputs.append(output) + outputs = rmsynthoncut1d.map( + comp_tuple=components.iterrows(), + beams=unmapped(beams), + outdir=unmapped(outdir), + freq=unmapped(freq), + field=unmapped(field), + polyOrd=unmapped(polyOrd), + phiMax_radm2=unmapped(phiMax_radm2), + dPhi_radm2=unmapped(dPhi_radm2), + nSamples=unmapped(nSamples), + weightType=unmapped(weightType), + fitRMSF=unmapped(fitRMSF), + noStokesI=unmapped(noStokesI), + showPlots=unmapped(showPlots), + savePlots=unmapped(savePlots), + debug=unmapped(debug), + rm_verbose=unmapped(rm_verbose), + fit_function=unmapped(fit_function), + tt0=unmapped(tt0), + tt1=unmapped(tt1), + ion=unmapped(ion), + do_own_fit=unmapped(do_own_fit), + ) elif dimension == "3d": logger.info(f"Running RMsynth on {n_island} islands") - - for i, island_id in enumerate(island_ids): - if i > n_island + 1: - break - else: - beam = dict(beams.loc[island_id]) - output = rmsynthoncut3d( - island_id=island_id, - beam=beam, - outdir=outdir, - freq=freq, - field=field, - 
phiMax_radm2=phiMax_radm2, - dPhi_radm2=dPhi_radm2, - nSamples=nSamples, - weightType=weightType, - fitRMSF=fitRMSF, - not_RMSF=not_RMSF, - rm_verbose=rm_verbose, - ion=ion, - ) - outputs.append(output) + outputs = rmsynthoncut3d.map( + island_id=island_ids, + beams=unmapped(beams), + outdir=unmapped(outdir), + freq=unmapped(freq), + field=unmapped(field), + phiMax_radm2=unmapped(phiMax_radm2), + dPhi_radm2=unmapped(dPhi_radm2), + nSamples=unmapped(nSamples), + weightType=unmapped(weightType), + fitRMSF=unmapped(fitRMSF), + not_RMSF=unmapped(not_RMSF), + rm_verbose=unmapped(rm_verbose), + ion=unmapped(ion), + ) else: raise ValueError("An incorrect RMSynth mode has been configured. ") - futures = chunk_dask( - outputs=outputs, - task_name="RMsynth", - progress_text="Running RMsynth", - verbose=verbose, - ) - if database: logger.info("Updating database...") - updates = [f.compute() for f in futures] - # Remove None values - updates = [u for u in updates if u is not None] + updates = [u.result() for u in outputs if u.result() is not None] logger.info("Sending updates to database...") if dimension == "1d": db_res = comp_col.bulk_write(updates, ordered=False) @@ -1233,13 +1205,6 @@ def cli(): elif verbose: logger.setLevel(logger.INFO) - cluster = LocalCluster( - # n_workers=12, processes=True, threads_per_worker=1, - local_directory="/dev/shm" - ) - client = Client(cluster) - logger.debug(client) - test_db( host=args.host, username=args.username, password=args.password, verbose=verbose ) @@ -1253,7 +1218,7 @@ def cli(): dimension=args.dimension, verbose=verbose, database=args.database, - validate=args.validate, + do_validate=args.validate, limit=args.limit, savePlots=args.savePlots, weightType=args.weightType, diff --git a/arrakis/utils/database.py b/arrakis/utils/database.py index 2d2d44df..af1ac864 100644 --- a/arrakis/utils/database.py +++ b/arrakis/utils/database.py @@ -57,7 +57,7 @@ def get_db( epoch: int, username: Union[str, None] = None, password: Union[str, 
None] = None, -) -> Tuple[Collection, Collection, Collection,]: +) -> Tuple[Collection, Collection, Collection]: """Get MongoDBs Args: diff --git a/arrakis/utils/meta.py b/arrakis/utils/meta.py index 00b109fe..7c08dd2e 100644 --- a/arrakis/utils/meta.py +++ b/arrakis/utils/meta.py @@ -1,6 +1,7 @@ #!/usr/bin/env python """Generic program utilities""" +import importlib import warnings from itertools import zip_longest @@ -11,6 +12,24 @@ warnings.simplefilter("ignore", category=AstropyWarning) +# From https://stackoverflow.com/questions/1176136/convert-string-to-python-class-object +def class_for_name(module_name: str, class_name: str) -> object: + """Returns a class object given a module name and class name + + Args: + module_name (str): Module name + class_name (str): Class name + + Returns: + object: Class object + """ + # load the module, will raise ImportError if module cannot be loaded + m = importlib.import_module(module_name) + # get the class, will raise AttributeError if class cannot be found + c = getattr(m, class_name) + return c + + # stolen from https://stackoverflow.com/questions/32954486/zip-iterators-asserting-for-equal-length-in-python def zip_equal(*iterables): sentinel = object() diff --git a/environment.yml b/environment.yml index e09e83a9..274a6f16 100644 --- a/environment.yml +++ b/environment.yml @@ -1,17 +1,17 @@ -name: arrakis +name: arrakis310 channels: - astropy - conda-forge - defaults - pkgw-forge dependencies: -- python=3.8 +- python=3.10 - pip - future - numpy - scipy - pandas -- matplotlib +- matplotlib>=3.8 - setuptools - ipython - astropy>=4.3 diff --git a/pyproject.toml b/pyproject.toml index e48aa3ff..2fb7a7dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "arrakis" -version = "2.0.0" +version = "2.1.0" description = "Processing the SPICE." 
homepage = "https://research.csiro.au/racs/" repository = "https://github.com/AlecThomson/arrakis" @@ -37,7 +37,7 @@ dask_mpi = "*" FRion = {git = "https://github.com/CIRADA-Tools/FRion.git" } h5py = "*" ipython = "*" -matplotlib = "*" +matplotlib = "^3.8" numba = "*" numba_progress = "*" #mpi4py = "*" @@ -62,7 +62,7 @@ bokeh = "<3" prefect = "^2" prefect-dask = "*" RMTable = { git = "https://github.com/CIRADA-Tools/RMTable" } -rm-tools = {git = "https://github.com/AlecThomson/RM-Tools.git", branch="spiceracs_dev"} +rm-tools = {git = "https://github.com/CIRADA-Tools/RM-Tools"} PolSpectra = { git = "https://github.com/AlecThomson/PolSpectra.git", branch="spiceracs"} setuptools = "*" fixms = "^0.1.2" diff --git a/scripts/casda_prepare.py b/scripts/casda_prepare.py index 3d0c166f..8933832c 100755 --- a/scripts/casda_prepare.py +++ b/scripts/casda_prepare.py @@ -802,36 +802,17 @@ def cli(): logger.setLevel(logging.INFO) elif args.debug: logger.setLevel(logging.DEBUG) - - if args.mpi: - initialize( - interface=args.interface, - local_directory="/dev/shm", - ) - cluster = None - else: - cluster = LocalCluster( - n_workers=12, - processes=True, - threads_per_worker=1, - local_directory="/dev/shm", - ) - - with Client( - cluster, - ) as client: - logger.debug(f"{client=}") - main( - polcatf=args.polcat, - data_dir=args.data_dir, - prep_type=args.prep_type, - do_update_cubes=args.convert_cubes, - do_convert_spectra=args.convert_spectra, - do_convert_plots=args.convert_plots, - verbose=args.verbose, - batch_size=args.batch_size, - outdir=args.outdir, - ) + main( + polcatf=args.polcat, + data_dir=args.data_dir, + prep_type=args.prep_type, + do_update_cubes=args.convert_cubes, + do_convert_spectra=args.convert_spectra, + do_convert_plots=args.convert_plots, + verbose=args.verbose, + batch_size=args.batch_size, + outdir=args.outdir, + ) if __name__ == "__main__": diff --git a/scripts/compare_leakage.py b/scripts/compare_leakage.py index 1ae86b73..8b7dfa62 100644 --- 
a/scripts/compare_leakage.py +++ b/scripts/compare_leakage.py @@ -265,12 +265,6 @@ def cli(): args = parser.parse_args() - cluster = LocalCluster( - n_workers=10, - threads_per_worker=1, - ) - client = Client(cluster) - if args.verbose: logger.setLevel(logging.INFO) @@ -285,9 +279,6 @@ def cli(): snr_cut=args.snr, ) - client.close() - cluster.close() - if __name__ == "__main__": cli() diff --git a/submit/test_image.py b/submit/test_image.py index 59b0a4bb..974617c8 100755 --- a/submit/test_image.py +++ b/submit/test_image.py @@ -37,9 +37,6 @@ def main(): ) cluster.scale(72) logger.debug(f"Submitted scripts will look like: \n {cluster.job_script()}") - # # exit() - # cluster = LocalCluster(n_workers=10, threads_per_worker=1) - # cluster.adapt(minimum=1, maximum=36) client = Client(cluster)