Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preflagger options exposed and new modes #69

Merged
merged 22 commits into from
Mar 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 105 additions & 6 deletions flint/bptools/preflagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""

from pathlib import Path
from typing import NamedTuple, Optional, Tuple
from typing import List, NamedTuple, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -405,6 +405,9 @@ def flag_mean_residual_amplitude(
bool: Whether the data should be considered bad. True if it is bad, False if otherwise.
"""

if not np.any(np.isfinite(complex_gains)):
return True

amplitudes = np.abs(complex_gains)
idxs = np.arange(amplitudes.shape[0])
mask = np.isfinite(amplitudes)
Expand Down Expand Up @@ -443,7 +446,7 @@ def flag_mean_residual_amplitude(


def flag_mean_xxyy_amplitude_ratio(
xx_complex_gains: np.ndarray, yy_complex_gains, fraction: float = 2.0
xx_complex_gains: np.ndarray, yy_complex_gains, tolerance: float = 0.1
) -> bool:
"""Will robust compute through an iterative sigma-clipping procedure the
mean XX and YY gain amplitudes. The ratio of these means are computed,
Expand All @@ -457,7 +460,7 @@ def flag_mean_xxyy_amplitude_ratio(
Args:
xx_complex_gains (np.ndarray): The XX complex gains to be considered
yy_complex_gains (_type_): The YY complex gains to be considered
fraction (float, optional): The fraction used to distinguish a critical mean ratio threshold. Defaults to 2..
tolerance (float, optional): The tolerance used used to distinguish a critical mean ratio threshold. Defaults to 0.10.

Returns:
bool: Whether data should be flagged (True) or not (False)
Expand All @@ -482,13 +485,109 @@ def flag_mean_xxyy_amplitude_ratio(

result = (
not np.isfinite(mean_gain_ratio)
or mean_gain_ratio < (1.0 / fraction)
or mean_gain_ratio > fraction
or mean_gain_ratio < (1.0 - tolerance)
or mean_gain_ratio > (1.0 + tolerance)
)

if result:
logger.warning(
f"Failed the mean gain ratio test: {xx_mean=} {yy_mean=} {mean_gain_ratio=} "
f"Failed the mean gain ratio test: {xx_mean=} {yy_mean=} {mean_gain_ratio=} {tolerance=}"
)

return result


def construct_mesh_ant_flags(mask: np.ndarray) -> np.ndarray:
"""Construct a mask that will accumulate the flags across
all antennas. The input mask array should be boolean and
of shape (ant, channels, pol), where `True` means flagged.

If an antenna is completely flagged it is ignored as the
statistics are collected

Args:
mask (np.ndarray): Input array denoting which items are flagged.

Returns:
np.ndarray: Output array where antennas have common sets of flags
"""

assert (
len(mask.shape) == 3
), f"Expect array of shape (ant, chnnel, pol), received {mask.shape=}"
accumulate_mask = np.zeros_like(mask[0], dtype=bool)

nant = mask.shape[0]
logger.info(f"Accumulating flagged channels over {nant=} antenna")

empty_ants: List[int] = []

# TODO: This can be replaced with numpy broadcasting

for ant in range(nant):
ant_mask = mask[ant]
if np.all(ant_mask):
empty_ants.append(ant)
continue

accumulate_mask = accumulate_mask | ant_mask

logger.info(f"Flags in accumulated mask: {np.sum(accumulate_mask)}")

result_mask = np.zeros_like(mask, dtype=bool)

for ant in range(nant):
if ant in empty_ants:
result_mask[ant, :, :] = True
else:
result_mask[ant] = accumulate_mask

return result_mask


def construct_jones_over_max_amp_flags(
complex_gains: np.ndarray, max_amplitude: float
) -> np.ndarray:
"""Construct and return a mask that would flag an entire Jones
should there be an element whose amplitude is above a flagging
threshold

Args:
complex_gains (np.ndarray): Complex gains that will have a mask constructed
max_amplitude (float): The flagging threshold, any Jones with a member above this will be flagged

Returns:
np.ndarray: Boolean array of equal shape to `complex_gains`, with `True` indicating a flag
"""

assert (
complex_gains.shape[-1] == 4
), f"Expected last dimension to be length 4, received {complex_gains.shape=}"

logger.info(f"Creating mask for Jones with amplitudes of {max_amplitude=}")
complex_gains = complex_gains.copy()

original_shape = complex_gains.shape

# Calculate tehe amplitudes of each of the complex numbers
# and construct the initial mask
amplitudes = np.abs(complex_gains)
mask = amplitudes > max_amplitude

# Compress all but the last dimension into a single
# rank so we can easily broadcast over
mask = mask.reshape((-1, 4))

# Now broadcast like a pirate
flag_jones = np.any(mask, axis=1)
mask[flag_jones, :] = True

# Convert back to original shape
mask = mask.reshape(original_shape)

no_flagged = np.sum(mask)

if no_flagged > 0:
logger.warning(f"{no_flagged} items flagged with {max_amplitude=}")

return mask
100 changes: 70 additions & 30 deletions flint/calibrate/aocalibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import numpy as np

from flint.bptools.preflagger import (
construct_jones_over_max_amp_flags,
construct_mesh_ant_flags,
flag_mean_residual_amplitude,
flag_mean_xxyy_amplitude_ratio,
flag_outlier_phase,
Expand All @@ -36,6 +38,7 @@
from flint.ms import MS, consistent_ms, get_beam_from_ms
from flint.naming import get_aocalibrate_output_path
from flint.sclient import run_singularity_command
from flint.utils import create_directory


class CalibrateOptions(NamedTuple):
Expand Down Expand Up @@ -855,22 +858,42 @@ class FlaggedAOSolution(NamedTuple):

def flag_aosolutions(
solutions_path: Path,
ref_ant: Optional[int] = -1,
ref_ant: int = -1,
flag_cut: float = 3,
plot_dir: Optional[Path] = None,
out_solutions_path: Optional[Path] = None,
smooth_solutions: bool = False,
plot_solutions_throughout: bool = True,
smooth_window_size: int = 16,
smooth_polynomial_order: int = 4,
mean_ant_tolerance: float = 0.2,
mesh_ant_flags: bool = False,
max_gain_amplitude: Optional[float] = None,
) -> FlaggedAOSolution:
"""Will open a previously solved ao-calibrate solutions file and flag additional channels and antennae.

There are currently two main stages. The first will attempt to search for channels where the the phase of the
There are a number of distinct operations applied to the data, which are
presented in order they are applied.

If `mesh_ant_flags` is `True`, channels flagged from on channel on a single
antenna will be applied to all (unless an antenna is completely flagged).
This happens before any other operation,.

If `max_gain_amplitude` is not `None` than any Jones with an element
whose amplitude is above the set value will be flagged.

Next, an attempt is made to search for channels where the the phase of the
gain solution are outliers. The phase over frequency is first unwrapped (delay solved for) before the flagging
statistics are computed.

The second stage will flag an entire antenna if more then 80 percent of the flags for a polarisation are flagged.
If an antenna is over 80% flagged then it is completely removed.

A low order polynomial (typically order 5) is fit to the amplitudes of the
Gx and Gy, and if the residuals are sufficently high then the antenna will
be flagged.

If the mean ratio of the Gx and Gy amplitudes for an antenna are higher
then `mean_ant_tolerance` then the antenna will be flagged.

Keywords that with the `smooth` prefix are passed to the `smooth_bandpass_complex_gains` function.

Expand All @@ -884,23 +907,24 @@ def flag_aosolutions(
plot_solutions_throughout (bool, Optional): If True, the solutions will be plotted at different stages of processing. Defaults to True.
smooth_window_size (int, optional): The size of the window function of the savgol filter. Passed directly to savgol. Defaults to 16.
smooth_polynomial_order (int, optional): The order of the polynomial of the savgol filter. Passed directly to savgol. Defaults to 4.
mean_ant_tolerance (float, optional): Tolerance of the mean x/y antenna gain ratio test before the antenna is flagged. Defaults to 0.2.
mesh_ant_flags (bool, optional): If True, a channel is flagged across all antenna if it is flagged for any antenna. Performed before other flagging operations. Defaults to False.
max_gain_amplitude (Optional[float], optional): If not None, flag the Jones if an antenna has a amplitude gain above this value. Defaults to 10.

Returns:
FlaggedAOSolution: Path to the updated solutions file, intermediate solution files and plots along the way
"""
# TODO: This should be broken down into separate stages. Way too large of a function.
# TODO: This pirate needs to cull some of this logic out, likely not needed
# and dead

solutions = AOSolutions.load(path=solutions_path)
title = solutions_path.name

pols = {0: "XX", 1: "XY", 2: "YX", 3: "YY"}

if plot_dir is not None and not plot_dir.exists():
logger.info(f"Creating {str(plot_dir)}")
try:
plot_dir.mkdir(parents=True)
except Exception as e:
logger.error(f"Failed to create {str(plot_dir)} {e}.")
if plot_dir:
create_directory(directory=plot_dir)

# Note that although the solutions variable (an instance of AOSolutions) is immutable,
# which includes the reference to the numpy array, the _actual_ numpy array is! So,
Expand All @@ -919,19 +943,34 @@ def flag_aosolutions(
output_plots = plot_solutions(solutions=solutions_path, ref_ant=ref_ant)
plots.extend(output_plots)

if mesh_ant_flags:
logger.info("Combining antenna flags")
mask = np.zeros_like(bandpass, dtype=bool)

for time in range(solutions.nsol):
mask[time] = construct_mesh_ant_flags(mask=~np.isfinite(bandpass[time]))

bandpass[mask] = np.nan

if max_gain_amplitude:
mask = construct_jones_over_max_amp_flags(
complex_gains=bandpass, max_amplitude=max_gain_amplitude
)
bandpass[mask] = np.nan

for time in range(solutions.nsol):
ref_bandpass = divide_bandpass_by_ref_ant_preserve_phase(
complex_gains=bandpass[time], ref_ant=ref_ant
)
for pol in (0, 3):
logger.info(f"Processing {pols[pol]} polarisation")
ref_ant_gains = bandpass[time, ref_ant, :, pol]
if np.sum(np.isfinite(ref_ant_gains)) == 0:
raise ValueError(f"The ref_ant={ref_ant} is completely bad. ")


for ant in range(solutions.nant):
if ant == ref_ant:
logger.info(f"Skipping reference antenna = ant{ref_ant:02}")
continue

ant_gains = bandpass[time, ant, :, pol] / ref_ant_gains
ant_gains = ref_bandpass[ant, :, pol]
plot_title = f"{title} - ant{ant:02d} - {pols[pol]}"
ouput_path = (
plot_dir / f"{title}.ant{ant:02d}.{pols[pol]}.png"
Expand All @@ -950,47 +989,48 @@ def flag_aosolutions(
plot_title=plot_title,
plot_path=ouput_path,
)
bandpass[time, ant, phase_outlier_result.outlier_mask, pol] = np.nan
bandpass[time, ant, phase_outlier_result.outlier_mask, :] = np.nan
except PhaseOutlierFitError:
# This is raised if the fit failed to converge, or some other nasty.
bandpass[time, ant, :, pol] = np.nan
bandpass[time, ant, :, :] = np.nan

for time in range(solutions.nsol):
for pol in (0, 3):
for ant in range(solutions.nant):
# Flag all solutions for this (ant,pol) if more than 80% are flagged
if flags_over_threshold(
flags=~np.isfinite(bandpass[time, ant, :, pol]),
thresh=0.8,
ant_idx=ant,
):
logger.info(
f"Flagging all solutions across {pols[pol]} for ant{ant:02d}, too many flagged channels."
f"Flagging all solutions across ant{ant:02d}, too many flagged channels."
)
bandpass[time, ant, :, pol] = np.nan
bandpass[time, ant, :, :] = np.nan

complex_gains = bandpass[time, ant, :, pol]
if any(np.isfinite(complex_gains)) and flag_mean_residual_amplitude(
complex_gains=complex_gains
):
if flag_mean_residual_amplitude(complex_gains=complex_gains):
logger.info(
f"Flagging all solutions across {pols[pol]} for ant{ant:02d}, mean residual amplitudes high"
f"Flagging all solutions for ant{ant:02d}, mean residual amplitudes high"
)
bandpass[time, ant, :, pol] = np.nan
bandpass[time, ant, :, :] = np.nan

flagged = ~np.isfinite(bandpass[time, ant, :, pol])
logger.info(
f"{ant=:02d}, pol={pols[pol]}, flagged {np.sum(flagged) / ant_gains.shape[0] * 100.:.2f}%"
)

for time in range(solutions.nsol):
ref_ant_gains = bandpass[time, ref_ant]
bandpass_phased_referenced = divide_bandpass_by_ref_ant_preserve_phase(
complex_gains=bandpass[time], ref_ant=ref_ant
)
# This loop will flag based on stats across different polarisations
for ant in range(solutions.nant):
# We need to skip the case of flagging on the reference antenna, I think.
if ref_ant == ant:
continue

ant_gains = bandpass[time, ant] / ref_ant_gains
ant_gains = bandpass_phased_referenced[ant]
if flag_mean_xxyy_amplitude_ratio(
xx_complex_gains=ant_gains[:, 0], yy_complex_gains=ant_gains[:, 3]
xx_complex_gains=ant_gains[:, 0],
yy_complex_gains=ant_gains[:, 3],
tolerance=mean_ant_tolerance,
):
logger.info(f"{ant=} failed mean amplitude gain test. Flagging {ant=}.")
bandpass[time, ant, :, :] = np.nan
Expand Down
8 changes: 8 additions & 0 deletions flint/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class BandpassOptions(NamedTuple):
"""Path to the singularity calibrate container"""
expected_ms: int = 36
"""The expected number of measurement set files to find"""
smooth_solutions: bool = False
"""Will activate the smoothing of the bandpass solutions"""
smooth_window_size: int = 16
"""The width of the smoothing window used to smooth the bandpass solutions"""
smooth_polynomial_order: int = 4
Expand All @@ -33,6 +35,12 @@ class BandpassOptions(NamedTuple):
"""The number of times the bandpass will be calibrated, flagged, then recalibrated"""
minuv: Optional[float] = None
"""The minimum baseline length, in meters, for data to be included in bandpass calibration stage"""
preflagger_ant_mean_tolerance: float = 0.2
"""Tolerance that the mean x/y antenna gain ratio test before the antenna is flagged"""
preflagger_mesh_ant_flags: bool = False
"""Share channel flags from bandpass solutions between all antenna"""
preflagger_jones_max_amplitude: Optional[float] = None
"""Flag Jones matrix if any amplitudes with a Jones are above this value"""


class FieldOptions(NamedTuple):
Expand Down
Loading
Loading