Skip to content

Commit

Permalink
Merge pull request #71 from tjgalvin/flaguvws
Browse files Browse the repository at this point in the history
Flagging zero'd uvws
  • Loading branch information
tjgalvin authored Mar 16, 2024
2 parents 6ee465c + 7baefb8 commit d20b18f
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 47 deletions.
2 changes: 1 addition & 1 deletion flint/calibrate/aocalibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,7 @@ def flag_aosolutions(
)
for pol in (0, 3):
logger.info(f"Processing {pols[pol]} polarisation")

for ant in range(solutions.nant):
if ant == ref_ant:
logger.info(f"Skipping reference antenna = ant{ref_ant:02}")
Expand Down
52 changes: 26 additions & 26 deletions flint/data/aoflagger/ASKAP.lua
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@ that is packaged in aoflagger. The main differences are:
aoflagger.require_min_version("3.0")

function execute(input)

--
-- Generic settings
--

-- What polarizations to flag? Default: input:get_polarizations() (=all that are in the input data)
-- Other options are e.g.:
-- { 'XY', 'YX' } to flag only XY and YX, or
-- { 'I', 'Q' } to flag only on Stokes I and Q
local flag_polarizations = input:get_polarizations()
local base_threshold = 2.5 -- lower means more sensitive detection

local base_threshold = 1.5 -- lower means more sensitive detection
-- How to flag complex values, options are: phase, amplitude, real, imaginary, complex
-- May have multiple values to perform detection multiple times
local flag_representations = { "amplitude" }
Expand All @@ -31,86 +31,86 @@ that is packaged in aoflagger. The main differences are:
local exclude_original_flags = false
local frequency_resize_factor = 2.0 -- Amount of "extra" smoothing in frequency direction
local transient_threshold_factor = 1.0 -- decreasing this value makes detection of transient RFI more aggressive

--
-- End of generic settings
--

local inpPolarizations = input:get_polarizations()

if(not exclude_original_flags) then
input:clear_mask()
end
-- For collecting statistics. Note that this is done after clear_mask(),
-- so that the statistics ignore any flags in the input data.
local copy_of_input = input:copy()

for ipol,polarization in ipairs(flag_polarizations) do

local pol_data = input:convert_to_polarization(polarization)

for _,representation in ipairs(flag_representations) do

-- 'data' is now in the desired representation e.g. amplitude
local data = pol_data:convert_to_complex(representation)
local original_data = data:copy()

for i=1,iteration_count-1 do
local threshold_factor = math.pow(threshold_factor_step, iteration_count-i)

local sumthr_level = threshold_factor * base_threshold
if(exclude_original_flags) then
aoflagger.sumthreshold_masked(data, original_data, sumthr_level, sumthr_level*transient_threshold_factor, true, true)
else
aoflagger.sumthreshold(data, sumthr_level, sumthr_level*transient_threshold_factor, true, true)
end

-- Do timestep & channel flagging
local chdata = data:copy()
aoflagger.threshold_timestep_rms(data, 3.5)
aoflagger.threshold_channel_rms(chdata, 3.0 * threshold_factor, true)
data:join_mask(chdata)

-- High pass filtering steps
data:set_visibilities(original_data)
if(exclude_original_flags) then
data:join_mask(original_data)
end

local resized_data = aoflagger.downsample(data, 3, frequency_resize_factor, true)
aoflagger.low_pass_filter(resized_data, 21, 31, 2.5, 5.0)
aoflagger.upsample(resized_data, data, 3, frequency_resize_factor)

-- In case this script is run from inside rfigui, calling
-- the following visualize function will add the current result
-- to the list of displayable visualizations.
-- If the script is not running inside rfigui, the call is ignored.
aoflagger.visualize(data, "Fit #"..i, i-1)

local tmp = original_data - data
tmp:set_mask(data)
data = tmp

aoflagger.visualize(data, "Residual #"..i, i+iteration_count)
aoflagger.set_progress((ipol-1)*iteration_count+i, #flag_polarizations*iteration_count )
end -- end of iterations

if(exclude_original_flags) then
aoflagger.sumthreshold_masked(data, original_data, base_threshold, base_threshold*transient_threshold_factor, true, true)
else
aoflagger.sumthreshold(data, base_threshold, base_threshold*transient_threshold_factor, true, true)
end
input:join_mask(data)
end -- end of complex representation iteration

-- Helper function used in the strategy
function contains(arr, val)
for _,v in ipairs(arr) do
if v == val then return true end
end
return false
end

if contains(inpPolarizations, polarization) then
if input:is_complex() then
local pol_data = input:convert_to_polarization(polarization)
Expand All @@ -120,19 +120,19 @@ that is packaged in aoflagger. The main differences are:
else
input:join_mask(data)
end

aoflagger.visualize(data, "Residual #"..iteration_count, 2*iteration_count)
aoflagger.set_progress(ipol, #flag_polarizations )
end -- end of polarization iterations

if(exclude_original_flags) then
aoflagger.scale_invariant_rank_operator_masked(input, copy_of_input, 0.2, 0.2)
else
aoflagger.scale_invariant_rank_operator(input, 0.2, 0.2)
end

aoflagger.threshold_timestep_rms(input, 4.0)

if input:is_complex() and input:has_metadata() then
-- This command will calculate a few statistics like flag% and stddev over
-- time, frequency and baseline and write those to the MS. These can be
Expand Down
60 changes: 54 additions & 6 deletions flint/flagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from flint.exceptions import MSError
from flint.logging import logger
from flint.ms import MS, check_column_in_ms, describe_ms
from flint.ms import MS, check_column_in_ms, critical_ms_interaction, describe_ms
from flint.sclient import run_singularity_command
from flint.utils import get_packaged_resource_path

Expand All @@ -28,6 +28,52 @@ class AOFlaggerCommand(NamedTuple):
"""The path to the aoflagging stategy file to use"""


def flag_ms_zero_uvws(ms: MS, chunk_size: int = 10000) -> MS:
"""Flag out the UVWs in a measurement set that have values of zero.
This happens when some data are flagged before it reaches the TOS.
A critical MS interaction scope is created to ensure if things fail
they are known.
Args:
ms (MS): Measurement set to flag
chunk_size (int, optional): The number of rows to flag at a tim. Defaults to 10000.
Returns:
MS: The flagged measurement set
"""

ms = MS.cast(ms)
logger.info(f"Flagging zero uvw's for {ms.path}")
row_idx = 0

# Rename the measurement set while it is being operated on
with critical_ms_interaction(input_ms=ms.path) as critical_ms_path:
with table(str(critical_ms_path), readonly=False, ack=False) as tab:
table_size = len(tab)

# so long as the row index is less than the table size there
# is another chunk to flag
while row_idx < (table_size - 1):
uvws = tab.getcol("UVW", startrow=row_idx, nrow=chunk_size)
flags = tab.getcol("FLAG", startrow=row_idx, nrow=chunk_size)

# Select records what the (u,v,w) are (0,0,0)
# Data in the shape (record, 3)
zero_uvws = np.all(uvws == 0, axis=1)
flags[zero_uvws, :] = True

# Put it back into place, update the counter for the next insertion
size = len(flags)
tab.putcol("FLAG", flags, startrow=row_idx, nrow=size)
row_idx += size

# Ensure changes written back to the MS
tab.flush()

return ms


def nan_zero_extreme_flag_ms(
ms: Union[Path, MS],
data_column: Optional[str] = None,
Expand Down Expand Up @@ -167,13 +213,12 @@ def run_aoflagger_cmd(aoflagger_cmd: AOFlaggerCommand, container: Path) -> None:
)


def flag_ms_aoflagger(ms: MS, container: Path, rounds: int = 1) -> MS:
def flag_ms_aoflagger(ms: MS, container: Path) -> MS:
"""Create and run an aoflagger command in a container
Args:
ms (MS): The measurement set with nominated column to flag
container (Path): The container with the aoflagger program
rounds (int, optional): Number of times to run the flagging. Defaults to 1.
Returns:
MS: Measurement set flagged with the appropriate column
Expand All @@ -182,9 +227,12 @@ def flag_ms_aoflagger(ms: MS, container: Path, rounds: int = 1) -> MS:
logger.info(f"Will flag column {ms.column} in {str(ms.path)}.")
aoflagger_cmd = create_aoflagger_cmd(ms=ms)

for i in range(rounds):
logger.info("Flagging command constructed. ")
run_aoflagger_cmd(aoflagger_cmd=aoflagger_cmd, container=container)
logger.info("Flagging command constructed. ")
run_aoflagger_cmd(aoflagger_cmd=aoflagger_cmd, container=container)

# TODO: This should be moved to the aoflagger lua file once it has
# been implemented
ms = flag_ms_zero_uvws(ms=ms)

return ms

Expand Down
6 changes: 2 additions & 4 deletions flint/prefect/common/imaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,10 @@


@task
def task_flag_ms_aoflagger(ms: FlagMS, container: Path, rounds: int = 1) -> FlagMS:
def task_flag_ms_aoflagger(ms: FlagMS, container: Path) -> FlagMS:
extracted_ms = ms.ms if isinstance(ms, ApplySolutions) else ms

extracted_ms = flag_ms_aoflagger(
ms=extracted_ms, container=container, rounds=rounds
)
extracted_ms = flag_ms_aoflagger(ms=extracted_ms, container=container)

return ms

Expand Down
3 changes: 1 addition & 2 deletions flint/prefect/flows/bandpass_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ def run_bandpass_stage(
flag_bandpass_mss = task_flag_ms_aoflagger.map(
ms=preprocess_bandpass_mss,
container=bandpass_options.flagger_container,
rounds=1,
)
calibrate_cmds = task_create_calibrate_cmd.map(
ms=flag_bandpass_mss,
Expand All @@ -179,7 +178,7 @@ def run_bandpass_stage(
container=bandpass_options.calibrate_container,
)
flag_bandpass_mss = task_flag_ms_aoflagger.map(
ms=apply_cmds, container=bandpass_options.flagger_container, rounds=1
ms=apply_cmds, container=bandpass_options.flagger_container
)
calibrate_cmds = task_create_calibrate_cmd.map(
ms=flag_bandpass_mss,
Expand Down
21 changes: 13 additions & 8 deletions flint/prefect/flows/continuum_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def process_science_fields(
container=field_options.calibrate_container,
)
flagged_mss = task_flag_ms_aoflagger.map(
ms=apply_solutions_cmds, container=field_options.flagger_container, rounds=1
ms=apply_solutions_cmds, container=field_options.flagger_container
)
column_rename_mss = task_rename_column_in_ms.map(
ms=flagged_mss,
Expand Down Expand Up @@ -241,7 +241,7 @@ def process_science_fields(
}
wsclean_rounds = {
1: {
"size": 8144,
"size": 7144,
"weight": "briggs -1.5",
"scale": "2.5arcsec",
"nmiter": 20,
Expand All @@ -255,7 +255,7 @@ def process_science_fields(
"multiscale_scales": (0, 15, 30, 40, 50, 60, 70, 120, 240, 480),
},
2: {
"size": 8144,
"size": 7144,
"weight": "briggs -1.5",
"scale": "2.5arcsec",
"multiscale": True,
Expand All @@ -266,12 +266,12 @@ def process_science_fields(
"channels_out": 36,
"deconvolution_channels": 6,
"fit_spectral_pol": 3,
"auto_mask": 5.0,
"auto_mask": 7.0,
"local_rms_window": 55,
"multiscale_scales": (0, 15, 30, 40, 50, 60, 70, 120, 240, 480),
},
3: {
"size": 8144,
"size": 7144,
"weight": "briggs -1.0",
"scale": "2.5arcsec",
"multiscale": True,
Expand All @@ -282,7 +282,7 @@ def process_science_fields(
"channels_out": 36,
"deconvolution_channels": 6,
"fit_spectral_pol": 3,
"auto_mask": 3.0,
"auto_mask": 6.0,
"local_rms_window": 55,
"multiscale_scales": (0, 15, 30, 40, 50, 60, 70, 120, 240, 480),
},
Expand All @@ -291,8 +291,13 @@ def process_science_fields(
for round in range(1, field_options.rounds + 1):
final_round = round == field_options.rounds

gain_cal_options = gain_cal_rounds.get(round, None)
wsclean_options = wsclean_rounds.get(round, None)
gain_cal_options = gain_cal_rounds.get(min((round, 3)), None)
wsclean_options = wsclean_rounds.get(min((round, 3)), None)

if round > 3:
wsclean_options["auto_mask"] = 5
wsclean_options["force_mask_rounds"] = 17
wsclean_options["local_rms_window"] = 30

cal_mss = task_gaincal_applycal_ms.map(
wsclean_cmd=wsclean_cmds,
Expand Down
Loading

0 comments on commit d20b18f

Please sign in to comment.