Skip to content

Commit

Permalink
Lock threads for cutouts (#78)
Browse files Browse the repository at this point in the history
* Use threadlock

* Pass dryrun
  • Loading branch information
AlecThomson authored Sep 3, 2024
1 parent 1258f7f commit 11513bf
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions arrakis/cutout.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,15 @@

import argparse
import logging
import time
import warnings
from concurrent.futures import ThreadPoolExecutor
import os
import multiprocessing as mp
from pathlib import Path
from pprint import pformat
from shutil import copyfile
from typing import List
from typing import NamedTuple as Struct
from typing import Optional, Set, TypeVar
from threading import Lock

import astropy.units as u
import numpy as np
Expand Down Expand Up @@ -117,6 +115,7 @@ def cutout_weight(


def cutout_image(
lock: Lock,
image_name: Path,
data_in_mem: np.ndarray,
old_header: fits.Header,
Expand Down Expand Up @@ -198,19 +197,15 @@ def cutout_image(
# Add source name to header for CASDA
fixed_header["OBJECT"] = source_id
if not dryrun:
if outfile.exists():
time.sleep(1)
outfile.unlink(missing_ok=True)
time.sleep(1)

fits.writeto(
outfile,
sub_data,
header=fixed_header,
overwrite=True,
output_verify="fix",
)
logger.info(f"Written to {outfile}")
with lock:
fits.writeto(
outfile,
sub_data,
header=fixed_header,
overwrite=True,
output_verify="fix",
)
logger.info(f"Written to {outfile}")

filename = outfile.parent / outfile.name
newvalues = {
Expand Down Expand Up @@ -334,7 +329,8 @@ def get_args(
)


def worker(
def make_cutout(
lock: Lock,
host: str,
epoch: int,
source: pd.Series,
Expand All @@ -350,6 +346,7 @@ def worker(
pad: float = 3,
username: Optional[str] = None,
password: Optional[str] = None,
dryrun: bool = False,
):
_, _, comp_col = get_db(
host=host, epoch=epoch, username=username, password=password
Expand All @@ -360,6 +357,7 @@ def worker(
outdir=outdir,
)
image_update = cutout_image(
lock=lock,
image_name=image_name,
data_in_mem=data_in_mem,
old_header=old_header,
Expand All @@ -370,7 +368,7 @@ def worker(
beam_num=beam_num,
stoke=stoke,
pad=pad,
dryrun=False,
dryrun=dryrun,
)
weight_update = cutout_weight(
image_name=image_name,
Expand All @@ -379,7 +377,7 @@ def worker(
field=field,
beam_num=beam_num,
stoke=stoke,
dryrun=False,
dryrun=dryrun,
)
return [image_update, weight_update]

Expand All @@ -399,6 +397,7 @@ def big_cutout(
username: Optional[str] = None,
password: Optional[str] = None,
limit: Optional[int] = None,
dryrun: bool = False,
) -> List[pymongo.UpdateOne]:
wild = f"image.restored.{stoke.lower()}*contcube*beam{beam_num:02}.conv.fits"
images = list(datadir.glob(wild))
Expand All @@ -423,14 +422,15 @@ def big_cutout(
sources = sources[:limit]

# Check for slurm cpus
max_workers = int(os.environ.get("SLURM_CPUS_PER_TASK", mp.cpu_count()))
updates: List[pymongo.UpdateOne] = []
with ThreadPoolExecutor(max_workers=max_workers) as executor:
lock = Lock()
with ThreadPoolExecutor() as executor:
futures = []
for _, source in sources.iterrows():
futures.append(
executor.submit(
worker,
make_cutout,
lock=lock,
host=host,
epoch=epoch,
source=source,
Expand All @@ -446,6 +446,7 @@ def big_cutout(
pad=pad,
username=username,
password=password,
dryrun=dryrun,
)
)
for future in tqdm(futures, file=TQDM_OUT, desc=f"Cutting {image_name}"):
Expand Down Expand Up @@ -572,6 +573,7 @@ def cutout_islands(
username=username,
password=password,
limit=limit,
dryrun=dryrun,
)
cuts.append(results)

Expand Down

0 comments on commit 11513bf

Please sign in to comment.