Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create a decorator function to extract options from Strategy #180

Merged
merged 18 commits into from
Oct 12, 2024
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Change log

# dev

- added `wrapper_options_from_strategy` decorator helper function

# 0.2.7

- added in convolving of cubes to common resolution across channels
Expand Down
2 changes: 1 addition & 1 deletion flint/archive.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def get_archive_options_from_yaml(strategy_yaml_path: Path) -> Dict[str, Any]:
Dict[str, Any]: Loaded options for ArchiveOptions
"""
archive_options = get_options_from_strategy(
strategy=strategy_yaml_path, mode="archive", round="initial"
strategy=strategy_yaml_path, mode="archive", round_info="initial"
)

logger.info(f"{archive_options=}")
Expand Down
123 changes: 103 additions & 20 deletions flint/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
throughout the pipeline.
"""

import inspect
import shutil
from argparse import ArgumentParser
from pathlib import Path
from typing import Any, Dict, Optional, Union
from typing import Any, Callable, Dict, ParamSpec, Optional, TypeVar, Union

from click import MissingParameter
import yaml

from flint.imager.wsclean import WSCleanOptions
Expand Down Expand Up @@ -220,7 +222,7 @@ def get_image_options_from_yaml(
def get_options_from_strategy(
strategy: Union[Strategy, None, Path],
mode: str = "wsclean",
round: Union[str, int] = "initial",
round_info: Union[str, int] = "initial",
max_round_override: bool = True,
operation: Optional[str] = None,
) -> Dict[Any, Any]:
Expand All @@ -239,7 +241,7 @@ def get_options_from_strategy(
Args:
strategy (Union[Strategy,None,Path]): A loaded instance of a strategy file. If `None` is provided then an empty dictionary is returned. If `Path` attempt to load the strategy file.
mode (str, optional): Which set of options to load. Typical values are `wsclean`, `gaincal` and `masking`. Defaults to "wsclean".
round (Union[str, int], optional): Which round to load options for. May be `initial` or an `int` (which indicated a self-calibration round). Defaults to "initial".
round_info (Union[str, int], optional): Which round to load options for. May be `initial` or an `int` (which indicated a self-calibration round). Defaults to "initial".
max_round_override (bool, optional): Check whether an integer round number is recorded. If it is higher than the largest self-cal round specified, set it to the last self-cal round. If False this is not performed. Defaults to True.
operation (Optional[str], optional): Get options related to a specific operation. Defaults to None.

Expand All @@ -260,13 +262,17 @@ def get_options_from_strategy(
assert isinstance(
strategy, (Strategy, dict)
), f"Unknown input strategy type {type(strategy)}"
assert round == "initial" or isinstance(
round, int
), f"{round=} not a known value or type. "
assert round_info == "initial" or isinstance(
round_info, int
), f"{round_info=} not a known value or type. "

# Override the round if requested
if isinstance(round, int) and max_round_override and "selfcal" in strategy.keys():
round = min(round, max(strategy["selfcal"].keys()))
if (
isinstance(round_info, int)
and max_round_override
and "selfcal" in strategy.keys()
):
round_info = min(round_info, max(strategy["selfcal"].keys()))

# step one, get the defaults
options = dict(**strategy["defaults"][mode]) if mode in strategy["defaults"] else {}
Expand All @@ -284,16 +290,19 @@ def get_options_from_strategy(
)
if mode in strategy[operation]:
update_options = dict(**strategy[operation][mode])
elif round == "initial":
elif round_info == "initial":
# separate function to avoid a missing mode from raising value error
if mode in strategy["initial"]:
update_options = dict(**strategy["initial"][mode])
elif isinstance(round, int):
elif isinstance(round_info, int):
# separate function to avoid a missing mode from raising value error
if round in strategy["selfcal"] and mode in strategy["selfcal"][round]:
update_options = dict(**strategy["selfcal"][round][mode])
if (
round_info in strategy["selfcal"]
and mode in strategy["selfcal"][round_info]
):
update_options = dict(**strategy["selfcal"][round_info][mode])
else:
raise ValueError(f"{round=} not recognised.")
raise ValueError(f"{round_info=} not recognised.")

if update_options:
logger.debug(f"Updating options with {update_options=}")
Expand All @@ -302,6 +311,80 @@ def get_options_from_strategy(
return options


P = ParamSpec("P")
T = TypeVar("T")


def wrapper_options_from_strategy(update_options_keyword: str):
"""Decorator intended to allow options to be pulled from the
strategy file when function is called. See ``get_options_from_strategy``
for options that this function enables.

``update_options_keyword`` specifies the name of the
keyword argument that the options extracted from the strategy
file will be passed to.

Should `strategy` ne set to ``None`` then the function
will be called without any options being extracted.

Args:
update_options_keyword (str): The keyword option to update from the wrapped function
"""

def _wrapper(fn: Callable[P, T]) -> Callable[P, T]:
"""Decorator intended to allow options to be pulled from the
strategy file when function is called. See ``get_options_from_strategy``
for options that this function enables.

Args:
fn (Callable): The callable function that will be assigned the additional keywords

Returns:
Callable: The updated function
"""
signature = inspect.signature(fn)
if update_options_keyword not in signature.parameters:
raise MissingParameter(
f"{update_options_keyword=} not in {signature.parameters} of {fn.__name__}"
)

# Don't use functools.wraps. It does something to the expected args/kwargs that makes
# prefect confuxed, wherein it throws an error saying the strategy, mode, round options
# are not part of the wrappede fn's function signature.
def wrapper(
strategy: Union[Strategy, None, Path] = None,
mode: str = "wsclean",
round_info: Union[str, int] = "initial",
max_round_override: bool = True,
operation: Optional[str] = None,
*args: P.args,
**kwargs: P.kwargs,
) -> T:
if update_options_keyword in kwargs:
logger.info(
f"{update_options_keyword} explicitly passed to {fn.__name__}. Ignoring attempts to load strategy file. "
)
elif strategy:
update_options = get_options_from_strategy(
strategy=strategy,
mode=mode,
round_info=round_info,
max_round_override=max_round_override,
operation=operation,
)
logger.info(f"Adding extracted options to {update_options_keyword}")
kwargs[update_options_keyword] = update_options

return fn(*args, **kwargs)

# Keep the function name and docs correct
wrapper.__name__ = fn.__name__
wrapper.__doc__ = fn.__doc__
return wrapper # type: ignore

return _wrapper


def verify_configuration(input_strategy: Strategy, raise_on_error: bool = True) -> bool:
"""Perform basic checks on the configuration file

Expand Down Expand Up @@ -336,7 +419,7 @@ def verify_configuration(input_strategy: Strategy, raise_on_error: bool = True)
else:
for key in input_strategy["initial"].keys():
options = get_options_from_strategy(
strategy=input_strategy, mode=key, round="initial"
strategy=input_strategy, mode=key, round_info="initial"
)
try:
_ = MODE_OPTIONS_MAPPING[key](**options)
Expand All @@ -351,16 +434,16 @@ def verify_configuration(input_strategy: Strategy, raise_on_error: bool = True)
if not all([isinstance(i, int) for i in round_keys]):
errors.append("The keys into the self-calibration should be ints. ")

for round in round_keys:
for mode in input_strategy["selfcal"][round]:
for round_info in round_keys:
for mode in input_strategy["selfcal"][round_info]:
options = get_options_from_strategy(
strategy=input_strategy, mode=mode, round=round
strategy=input_strategy, mode=mode, round_info=round_info
)
try:
_ = MODE_OPTIONS_MAPPING[mode](**options)
except TypeError as typeerror:
errors.append(
f"{mode=} mode in {round=} incorrectly formed. {typeerror} "
f"{mode=} mode in {round_info=} incorrectly formed. {typeerror} "
)

for operation in KNOWN_OPERATIONS:
Expand Down Expand Up @@ -462,8 +545,8 @@ def create_default_yaml(
if selfcal_rounds:
logger.info(f"Creating {selfcal_rounds} self-calibration rounds. ")
selfcal: Dict[int, Any] = {}
for round in range(1, selfcal_rounds + 1):
selfcal[round] = {
for selfcal_round in range(1, selfcal_rounds + 1):
selfcal[selfcal_round] = {
"wsclean": {},
"gaincal": {},
"masking": {},
Expand Down
10 changes: 7 additions & 3 deletions flint/prefect/common/imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
select_aosolution_for_ms,
)
from flint.coadd.linmos import LinmosCommand, linmos_images
from flint.configuration import wrapper_options_from_strategy
from flint.convol import (
BeamShape,
convolve_images,
Expand Down Expand Up @@ -224,9 +225,10 @@ def task_zip_ms(in_item: WSCleanCommand) -> Path:


@task
@wrapper_options_from_strategy(update_options_keyword="update_gain_cal_options")
def task_gaincal_applycal_ms(
ms: Union[MS, WSCleanCommand],
round: int,
selfcal_round: int,
casa_container: Path,
update_gain_cal_options: Optional[Dict[str, Any]] = None,
archive_input_ms: bool = False,
Expand All @@ -238,7 +240,7 @@ def task_gaincal_applycal_ms(

Args:
ms (Union[MS, WSCleanCommand]): A resulting wsclean output. This is used purely to extract the ``.ms`` attribute.
round (int): Counter indication which self-calibration round is being performed. A name is included based on this.
selfcal_round (int): Counter indication which self-calibration round is being performed. A name is included based on this.
casa_container (Path): A path to a singularity container with CASA tooling.
update_gain_cal_options (Optional[Dict[str, Any]], optional): Options used to overwrite the default ``gaincal`` options. Defaults to None.
archive_input_ms (bool, optional): If True the input measurement set is zipped. Defaults to False.
Expand All @@ -263,7 +265,7 @@ def task_gaincal_applycal_ms(

return gaincal_applycal_ms(
ms=ms,
round=round,
round=selfcal_round,
casa_container=casa_container,
update_gain_cal_options=update_gain_cal_options,
archive_input_ms=archive_input_ms,
Expand All @@ -274,6 +276,7 @@ def task_gaincal_applycal_ms(


@task
@wrapper_options_from_strategy(update_options_keyword="update_wsclean_options")
def task_wsclean_imager(
in_ms: Union[ApplySolutions, MS],
wsclean_container: Path,
Expand Down Expand Up @@ -813,6 +816,7 @@ def _create_convolve_linmos_cubes(


@task
@wrapper_options_from_strategy(update_options_keyword="update_masking_options")
def task_create_image_mask_model(
image: Union[LinmosCommand, ImageSet, WSCleanCommand],
image_products: AegeanOutputs,
Expand Down
54 changes: 28 additions & 26 deletions flint/prefect/flows/continuum_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,23 +277,27 @@ def process_science_fields(
logger.info("No wsclean container provided. Rerutning. ")
return

wsclean_init = get_options_from_strategy(
strategy=strategy, mode="wsclean", round="initial"
)

if field_options.potato_container:
# The call into potato peel task has two potential update option keywords.
# So for the moment we will not use the task decorated version.
potato_wsclean_init = get_options_from_strategy(
strategy=strategy, mode="wsclean", round_info="initial"
)
preprocess_science_mss = task_potato_peel.map(
ms=preprocess_science_mss,
potato_container=field_options.potato_container,
update_wsclean_options=unmapped(wsclean_init),
update_wsclean_options=unmapped(potato_wsclean_init),
)

stokes_v_mss = preprocess_science_mss
wsclean_cmds = task_wsclean_imager.map(
in_ms=preprocess_science_mss,
wsclean_container=field_options.wsclean_container,
update_wsclean_options=unmapped(wsclean_init),
)
strategy=unmapped(strategy),
mode="wsclean",
round_info="initial",
) # type: ignore

# TODO: This should be waited!
beam_summaries = task_create_beam_summary.map(
ms=preprocess_science_mss, imageset=wsclean_cmds
Expand Down Expand Up @@ -350,27 +354,22 @@ def process_science_fields(
with tags(f"selfcal-{current_round}"):
final_round = current_round == field_options.rounds

gain_cal_options = get_options_from_strategy(
strategy=strategy, mode="gaincal", round=current_round
)
wsclean_options = get_options_from_strategy(
strategy=strategy, mode="wsclean", round=current_round
)

skip_gaincal_current_round = consider_skip_selfcal_on_round(
current_round=current_round,
skip_selfcal_on_rounds=field_options.skip_selfcal_on_rounds,
)

cal_mss = task_gaincal_applycal_ms.map(
ms=wsclean_cmds,
round=current_round,
update_gain_cal_options=unmapped(gain_cal_options),
selfcal_round=current_round,
archive_input_ms=field_options.zip_ms,
skip_selfcal=skip_gaincal_current_round,
rename_ms=field_options.rename_ms,
archive_cal_table=True,
casa_container=field_options.casa_container,
strategy=unmapped(strategy),
mode="gaincal",
round_info=current_round,
wait_for=[
field_summary
], # To make sure field summary is created with unzipped MSs
Expand All @@ -383,12 +382,11 @@ def process_science_fields(
mask_rounds=field_options.use_beam_mask_rounds,
allow_beam_masks=field_options.use_beam_masks,
):
masking_options = get_options_from_strategy(
strategy=strategy, mode="masking", round=current_round
)
# The is intended to only run the beam wise aegean if it has not already
# been done. Immedidatedly after the first round of shallow cleaning
# aegean could be run.
# Early versions of the masking procedure required aegean outputs
# to construct the sginal images. Since aegean is run outside of
# this self-cal loop once already, we can skip their creation on
# the first loop
# TODO: the aegean outputs are only needed should the signal image be needed
beam_aegean_outputs = (
task_run_bane_and_aegean.map(
image=wsclean_cmds,
Expand All @@ -400,15 +398,19 @@ def process_science_fields(
fits_beam_masks = task_create_image_mask_model.map(
image=wsclean_cmds,
image_products=beam_aegean_outputs,
update_masking_options=unmapped(masking_options),
)
strategy=unmapped(strategy),
mode="masking",
round_info=current_round,
) # type: ignore

wsclean_cmds = task_wsclean_imager.map(
in_ms=cal_mss,
wsclean_container=field_options.wsclean_container,
update_wsclean_options=unmapped(wsclean_options),
fits_mask=fits_beam_masks,
)
strategy=unmapped(strategy),
mode="wsclean",
round_info=current_round,
) # type: ignore
archive_wait_for.extend(wsclean_cmds)

# Do source finding on the last round of self-cal'ed images
Expand Down
4 changes: 2 additions & 2 deletions flint/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,8 @@ def get_packaged_resource_path(package: str, filename: str) -> Path:
except ImportWarning:
from importlib import resources as importlib_resources

with importlib_resources.files(package) as p:
full_path = Path(p) / filename
p = importlib_resources.files(package)
full_path = Path(p) / filename # type: ignore

logger.debug(f"Resolved {full_path=}")

Expand Down
Loading
Loading