Commit 8e0e653

Merge pull request #70 from AlecThomson/ddos
Reduce number of tasks
AlecThomson authored May 10, 2024
2 parents 560510f + 250b9bc commit 8e0e653
Showing 4 changed files with 125 additions and 48 deletions.
4 changes: 2 additions & 2 deletions arrakis/configs/rm_petrichor.yaml
@@ -20,8 +20,8 @@ cluster_kwargs:
   silence_logs: 'info'
   worker_extra_args: ["--memory-limit", "128GiB"]
 adapt_kwargs:
-  minimum: 108
-  maximum: 108
+  minimum: 1
+  maximum: 512
   wait_count: 20
   target_duration: "5s"
   interval: "10s"
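
This swaps a fixed 108-worker allocation for adaptive scaling over a 1-512 range, so workers are only held while the queue actually has work. A minimal sketch of how such `adapt_kwargs` are typically forwarded to `cluster.adapt()`, assuming a `dask_jobqueue` `SLURMCluster` (the cluster class and worker spec here are illustrative, not taken from the repo):

```python
from dask_jobqueue import SLURMCluster

cluster = SLURMCluster(cores=16, memory="128GiB")  # illustrative worker spec
cluster.adapt(
    minimum=1,             # an idle cluster shrinks to a single worker
    maximum=512,           # scale out up to 512 workers under load
    wait_count=20,         # consecutive scale-down votes before a worker is retired
    target_duration="5s",  # how quickly we aim to clear the queue; affects scale-up
    interval="10s",        # re-evaluate every 10 seconds
)
```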
86 changes: 66 additions & 20 deletions arrakis/frion.py
@@ -39,6 +39,11 @@ class Prediction(Struct):
     update: pymongo.UpdateOne


+class FrionResults(Struct):
+    prediction: Prediction
+    correction: pymongo.UpdateOne
+
+
 @task(name="FRion correction")
 def correct_worker(
     beam: Dict, outdir: str, field: str, prediction: Prediction, island: dict
@@ -179,6 +184,50 @@ def index_beams(island: dict, beams: List[dict]) -> dict:
     return beam


+# We reduce the inner loop to a serial call
+# This is to avoid overwhelming the Prefect server
+@task(name="FRion loop")
+def serial_loop(
+    island: dict,
+    field: str,
+    beam: dict,
+    start_time: Time,
+    end_time: Time,
+    freq_hz_array: np.ndarray,
+    cutdir: Path,
+    plotdir: Path,
+    ionex_server: str,
+    ionex_prefix: str,
+    ionex_proxy_server: Optional[str],
+    ionex_formatter: Optional[Union[str, Callable]],
+    ionex_predownload: bool,
+) -> FrionResults:
+    prediction = predict_worker.fn(
+        island=island,
+        field=field,
+        beam=beam,
+        start_time=start_time,
+        end_time=end_time,
+        freq=freq_hz_array,
+        cutdir=cutdir,
+        plotdir=plotdir,
+        server=ionex_server,
+        prefix=ionex_prefix,
+        proxy_server=ionex_proxy_server,
+        formatter=ionex_formatter,
+        pre_download=ionex_predownload,
+    )
+    correction = correct_worker.fn(
+        beam=beam,
+        outdir=cutdir,
+        field=field,
+        prediction=prediction,
+        island=island,
+    )
+
+    return FrionResults(prediction=prediction, correction=correction)
+
+
 @flow(name="FRion")
 def main(
     field: str,
@@ -312,42 +361,39 @@ def main(
         pre_download=ionex_predownload,
     )

-    predictions = []
-    corrections = []
+    frion_results = []
+    assert len(islands) == len(beams_cor), "Islands and beams must be the same length"
     for island, beam in tqdm(
         zip(islands, beams_cor),
         desc="Submitting tasks",
         file=TQDM_OUT,
         total=len(islands),
     ):
-        prediction = predict_worker.submit(
+        frion_result = serial_loop.submit(
             island=island,
             field=field,
             beam=beam,
             start_time=start_time,
             end_time=end_time,
-            freq=freq.to(u.Hz).value,
+            freq_hz_array=freq.to(u.Hz).value,
             cutdir=cutdir,
             plotdir=plotdir,
-            server=ionex_server,
-            prefix=ionex_prefix,
-            proxy_server=ionex_proxy_server,
-            formatter=ionex_formatter,
-            pre_download=ionex_predownload,
-        )
-        predictions.append(prediction)
-        correction = correct_worker.submit(
-            beam=beam,
-            outdir=cutdir,
-            field=field,
-            prediction=prediction,
-            island=island,
+            ionex_server=ionex_server,
+            ionex_prefix=ionex_prefix,
+            ionex_proxy_server=ionex_proxy_server,
+            ionex_formatter=ionex_formatter,
+            ionex_predownload=ionex_predownload,
         )
-        corrections.append(correction)
+        frion_results.append(frion_result)
+
+    predictions = []
+    corrections = []
+    for result in frion_results:
+        predictions.append(result.result().prediction)
+        corrections.append(result.result().correction)

-    updates_arrays = [p.result().update for p in predictions]
-    updates = [c.result() for c in corrections]
+    updates_arrays = [p.update for p in predictions]
+    updates = corrections
     if database:
         logger.info("Updating beams database...")
         db_res = beams_col.bulk_write(updates, ordered=False)
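
The key trick in this refactor is Prefect's `Task.fn`, which calls a task's underlying function directly instead of registering a separate task run. Wrapping the predict/correct pair inside one submitted task halves the number of task runs the server must track (from 2N to N per field) while keeping the per-island work parallel. A standalone sketch of the pattern, assuming Prefect 2 (the task names and arithmetic here are illustrative):

```python
from typing import List

from prefect import flow, task


@task
def predict(x: int) -> int:
    return x + 1


@task
def correct(x: int) -> int:
    return x * 2


@task(name="combined loop")
def combined(x: int) -> int:
    # .fn runs the decorated function in-process, so predict/correct
    # never appear as separate task runs on the Prefect server.
    return correct.fn(predict.fn(x))


@flow
def main(items: List[int]) -> List[int]:
    futures = [combined.submit(x) for x in items]  # one task run per item
    return [f.result() for f in futures]


if __name__ == "__main__":
    print(main([1, 2, 3]))  # [4, 6, 8]
```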
71 changes: 49 additions & 22 deletions arrakis/linmos.py
@@ -302,6 +302,42 @@ def get_yanda(version="1.3.0") -> str:
     return image


+# We reduce the inner loop to a serial call
+# This is to avoid overwhelming the Prefect server
+@task(name="LINMOS loop")
+def serial_loop(
+    field: str,
+    beams_row: Tuple[int, pd.Series],
+    stokeslist: List[str],
+    cutdir: Path,
+    holofile: Path,
+    image: Path,
+) -> List[Optional[pymongo.UpdateOne]]:
+    results = []
+    for stoke in stokeslist:
+        image_path = find_images.fn(
+            field=field,
+            beams_row=beams_row,
+            stoke=stoke.capitalize(),
+            datadir=cutdir,
+        )
+        parset = genparset.fn(
+            image_paths=image_path,
+            stoke=stoke.capitalize(),
+            datadir=cutdir,
+            holofile=holofile,
+        )
+        result = linmos.fn(
+            parset=parset,
+            fieldname=field,
+            image=str(image),
+            holofile=holofile,
+        )
+        results.append(result)
+
+    return results
+
+
 @flow(name="LINMOS")
 def main(
     field: str,
@@ -370,28 +406,19 @@ def main(
         desc="Submitting tasks for LINMOS",
         file=TQDM_OUT,
     ):
-        for stoke in stokeslist:
-            image_path = find_images.submit(
-                field=field,
-                beams_row=beams_row,
-                stoke=stoke.capitalize(),
-                datadir=cutdir,
-            )
-            parset = genparset.submit(
-                image_paths=image_path,
-                stoke=stoke.capitalize(),
-                datadir=cutdir,
-                holofile=holofile,
-            )
-            result = linmos.submit(
-                parset=parset,
-                fieldname=field,
-                image=str(image),
-                holofile=holofile,
-            )
-            results.append(result)
-
-    updates = [f.result() for f in results]
+        sub_results = serial_loop.submit(
+            field=field,
+            beams_row=beams_row,
+            stokeslist=stokeslist,
+            cutdir=cutdir,
+            holofile=holofile,
+            image=image,
+        )
+        results.append(sub_results)
+
+    updates_lists: List[list] = [f.result() for f in results]
+    # Flatten
+    updates = [u for ul in updates_lists for u in ul]
     updates = [u for u in updates if u is not None]
     logger.info("Updating database...")
     db_res = beams_col.bulk_write(updates, ordered=False)
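
Because each `serial_loop` future now resolves to a list of per-Stokes updates (with `None` marking skipped entries) rather than a single update, the flow flattens before the bulk write. A minimal standalone sketch of the flatten-then-bulk-write step, assuming a local MongoDB (the database, collection, and documents are illustrative):

```python
from typing import List, Optional

import pymongo

client = pymongo.MongoClient()
beams_col = client["illustrative_db"]["beams"]

# One list of UpdateOne ops per beam row; None marks a skipped Stokes.
updates_lists: List[List[Optional[pymongo.UpdateOne]]] = [
    [pymongo.UpdateOne({"Beam": 0}, {"$set": {"linmos": True}}), None],
    [pymongo.UpdateOne({"Beam": 1}, {"$set": {"linmos": True}})],
]
updates = [u for ul in updates_lists for u in ul if u is not None]
db_res = beams_col.bulk_write(updates, ordered=False)  # unordered: one batched round trip
```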
12 changes: 8 additions & 4 deletions arrakis/rmsynth_oncuts.py
@@ -890,17 +890,21 @@ def main(
     # exit()
     query_1d = {
         "$and": [
-            {"Source_ID": {"$in": island_ids}},
+            {"Gaussian_ID": {"$in": component_ids}},
             {"rm_outputs_1d": {"$exists": True}},
         ]
     }
     test_count = comp_col.count_documents(query_1d)
-    if test_count == 0:
+    if test_count < n_comp:
         # Initialize the field
-        comp_col.update_many(
-            {"Source_ID": {"$in": island_ids}},
+        result = comp_col.update_many(
+            {
+                "Gaussian_ID": {"$in": component_ids},
+                "rm_outputs_1d": {"$exists": False},
+            },
             {"$set": {"rm_outputs_1d": [{"field": save_name, "rmsynth1d": False}]}},
         )
+        logger.info(pformat(result.raw_result))

     update_1d = {
         "field": save_name,
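
This guard change makes the initialization idempotent: instead of initializing only when no component has `rm_outputs_1d`, the field is now created for exactly those components that still lack it, so a partially initialized collection gets repaired rather than skipped, and existing entries are never clobbered. A standalone sketch of the pattern, assuming a local MongoDB (the database, collection, and IDs are illustrative):

```python
from pprint import pformat

import pymongo

client = pymongo.MongoClient()
comp_col = client["illustrative_db"]["components"]
component_ids = ["G001", "G002", "G003"]
save_name = "FIELD_1234"
n_comp = len(component_ids)

query_1d = {
    "$and": [
        {"Gaussian_ID": {"$in": component_ids}},
        {"rm_outputs_1d": {"$exists": True}},
    ]
}
if comp_col.count_documents(query_1d) < n_comp:
    # Initialize only the documents that still lack the field, so a
    # re-run never overwrites existing rm_outputs_1d entries.
    result = comp_col.update_many(
        {"Gaussian_ID": {"$in": component_ids}, "rm_outputs_1d": {"$exists": False}},
        {"$set": {"rm_outputs_1d": [{"field": save_name, "rmsynth1d": False}]}},
    )
    print(pformat(result.raw_result))  # e.g. counts of matched/modified docs
```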
