From 739ad30dcc92ec47c883cee3603de94224fa4432 Mon Sep 17 00:00:00 2001 From: Dacheng Xu Date: Sun, 17 Nov 2024 17:04:06 -0500 Subject: [PATCH] Raise error when peaks overlapping in `merge_peaks` (#927) * Simplify the required buffer length * Add a check of peaks overlapping --- strax/context.py | 2 +- strax/processing/peak_merging.py | 33 +++++++------------------------- strax/run_selection.py | 2 +- 3 files changed, 9 insertions(+), 28 deletions(-) diff --git a/strax/context.py b/strax/context.py index 08add1d1..4e46403c 100644 --- a/strax/context.py +++ b/strax/context.py @@ -25,7 +25,7 @@ RUN_DEFAULTS_KEY = "strax_defaults" TEMP_DATA_TYPE_PREFIX = "_temp_" -# use tqdm as loaded in utils (from tqdm.notebook when in a juypyter env) +# use tqdm as loaded in utils (from tqdm.notebook when in a jupyter env) tqdm = strax.utils.tqdm diff --git a/strax/processing/peak_merging.py b/strax/processing/peak_merging.py index 867814d1..c06c884f 100644 --- a/strax/processing/peak_merging.py +++ b/strax/processing/peak_merging.py @@ -25,6 +25,8 @@ def merge_peaks( """ assert len(start_merge_at) == len(end_merge_at) + if np.min(peaks["time"][1:] - strax.endtime(peaks)[:-1]) < 0: + raise ValueError("Peaks not disjoint! You have to rewrite this function to handle this.") new_peaks = np.zeros(len(start_merge_at), dtype=peaks.dtype) # Do the merging. Could numbafy this to optimize, probably... @@ -45,32 +47,11 @@ def merge_peaks( # re-zero relevant part of buffers (overkill? not sure if # this saves much time) - buffer[ - : min( - int( - ( - last_peak["time"] - - first_peak["time"] - + (last_peak["length"] * old_peaks["dt"].max()) - ) - / common_dt - ), - len(buffer), - ) - ] = 0 - buffer_top[ - : min( - int( - ( - last_peak["time"] - - first_peak["time"] - + (last_peak["length"] * old_peaks["dt"].max()) - ) - / common_dt - ), - len(buffer_top), - ) - ] = 0 + bl = last_peak["time"] - first_peak["time"] + bl += last_peak["length"] * old_peaks["dt"].max() + bl = min(int(bl / common_dt), max_buffer) + buffer[:bl] = 0 + buffer_top[:bl] = 0 for p in old_peaks: # Upsample the sum and top/bottom array waveforms into their buffers diff --git a/strax/run_selection.py b/strax/run_selection.py index 638bf517..810ff87e 100644 --- a/strax/run_selection.py +++ b/strax/run_selection.py @@ -12,7 +12,7 @@ import strax from strax import stable_argsort -# use tqdm as loaded in utils (from tqdm.notebook when in a juypyter env) +# use tqdm as loaded in utils (from tqdm.notebook when in a jupyter env) tqdm = strax.utils.tqdm export, __all__ = strax.exporter()