Enforce stable sorting in np.sort and np.argsort
dachengx committed Nov 14, 2024
1 parent f1bb229 commit 03c6585
Showing 17 changed files with 39 additions and 45 deletions.
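The change itself is mechanical: every sort whose result feeds a later selection is routed through strax helpers that enforce a stable sort, so records with equal keys keep their original relative order and the output is reproducible across runs and platforms. The sketch below is illustrative only; the real strax.stable_sort / strax.stable_argsort live in the strax package and may accept more arguments. Here they are assumed to be thin wrappers that pin NumPy's stable sort kind.

```python
import numpy as np

# Illustrative sketch only: the real strax.stable_sort / strax.stable_argsort are
# defined in the strax package; these assumed versions just pin NumPy's stable kind.
def stable_sort(a, **kwargs):
    return np.sort(a, kind="stable", **kwargs)

def stable_argsort(a, **kwargs):
    return np.argsort(a, kind="stable", **kwargs)

# Why it matters: NumPy's default kind (introsort) is not stable, so equal keys may
# come back in any order and the permutation can differ between runs or platforms.
# A stable sort keeps equal elements in their original order, which makes downstream
# selections ("largest peak", "main vs. alternative S2", ...) deterministic.
areas = np.array([5.0, 3.0, 5.0, 1.0])
print(np.argsort(areas))      # tie between indices 0 and 2: relative order not guaranteed
print(stable_argsort(areas))  # [3 1 0 2]: ties resolved by original position
```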
4 changes: 2 additions & 2 deletions straxen/analyses/daq_waveforms.py
@@ -1,9 +1,9 @@
import typing

import matplotlib.pyplot as plt
import numpy as np
import pandas
import pymongo
import strax
import straxen
import utilix

@@ -116,7 +116,7 @@ def group_by_daq(run_id, group_by: str):
daq_config = _get_daq_config(run_id)
labels = [_board_to_host_link(daq_config, label) for label in labels]
labels = np.array(labels)
-order = np.argsort(labels)
+order = strax.stable_argsort(labels)
return labels[order], idx[order]
else:
return _group_channels_by_index(cable_map, group_by=group_by)
4 changes: 2 additions & 2 deletions straxen/analyses/holoviews_waveform_display.py
@@ -3,7 +3,7 @@

import numpy as np
import pandas as pd

import strax
import straxen

straxen._BOKEH_X_RANGE = None
@@ -338,7 +338,7 @@ def hvdisp_plot_peak_waveforms(
import holoviews as hv

if show_largest is not None and len(peaks) > show_largest:
-show_i = np.argsort(peaks["area"])[-show_largest::]
+show_i = strax.stable_argsort(peaks["area"])[-show_largest::]
peaks = peaks[show_i]

curves = []
2 changes: 1 addition & 1 deletion straxen/analyses/waveform_plot.py
@@ -80,7 +80,7 @@ def plot_peaks(
plt.figure(figsize=figsize)
plt.axhline(0, c="k", alpha=0.2)

-peaks = peaks[np.argsort(-peaks["area"])[:show_largest]]
+peaks = peaks[strax.stable_argsort(-peaks["area"])[:show_largest]]
peaks = strax.sort_by_time(peaks)

for p in peaks:
6 changes: 3 additions & 3 deletions straxen/plugins/events/event_basics.py
@@ -1,6 +1,6 @@
-import strax
import numpy as np
import numba
+import strax
import straxen


@@ -273,7 +273,7 @@ def fill_result_i(self, event, peaks):
largest_s2s, s2_idx = largest_s2s[0:2], s2_idx[0:2]

if self.force_main_before_alt:
-s2_order = np.argsort(largest_s2s["time"])
+s2_order = strax.stable_argsort(largest_s2s["time"])
largest_s2s = largest_s2s[s2_order]
s2_idx = s2_idx[s2_order]

@@ -364,7 +364,7 @@ def get_largest_sx_peaks(

selected_peaks = peaks[s_mask]
s_index = np.arange(len(peaks))[s_mask]
-largest_peaks = np.argsort(selected_peaks["area"])[-number_of_peaks:][::-1]
+largest_peaks = strax.stable_argsort(selected_peaks["area"])[-number_of_peaks:][::-1]
return selected_peaks[largest_peaks], s_index[largest_peaks]

# If only we could numbafy this... Unfortunately we cannot.
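In get_largest_sx_peaks above, stability matters because two peaks can have exactly the same area: an unstable argsort may rank them differently from run to run, which can flip which one becomes the main versus the alternative S1/S2. A stable sort always breaks the tie by the peaks' original order. The array below is made up purely for illustration.

```python
import numpy as np

# Hypothetical data: two peaks share the largest area.
areas = np.array([10.0, 42.0, 42.0, 7.0])

# Same selection pattern as get_largest_sx_peaks: take the last two entries of the
# ascending argsort, then reverse so the largest-area peak comes first.
largest = np.argsort(areas, kind="stable")[-2:][::-1]
print(largest)  # [2 1]: deterministic; an unstable sort could just as well give [1 2]
```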
8 changes: 4 additions & 4 deletions straxen/plugins/gps_syncing/gps_syncing.py
@@ -1,9 +1,9 @@
-import strax
-import straxen
-import utilix
import numpy as np
import pandas as pd
import datetime
+import utilix
+import strax
+import straxen

from straxen.plugins.aqmon_hits.aqmon_hits import AqmonChannels
from scipy.interpolate import interp1d
@@ -65,7 +65,7 @@ def load_gps_array(self):
nanoseconds unix time."""
gps_info = self.gps_times_from_runid(self.run_id)
gps_info["pulse_time"] = np.int64(gps_info["gps_sec"] * 1e9) + np.int64(gps_info["gps_ns"])
-gps_array = np.sort(gps_info["pulse_time"])
+gps_array = strax.stable_sort(gps_info["pulse_time"])
return gps_array

def gps_times_from_runid(self, run_id):
@@ -1,5 +1,5 @@
-import strax
import numpy as np
+import strax
import straxen

export, __all__ = strax.exporter()
@@ -47,7 +47,7 @@ def compute(self, peaks):
# to reduce datasize
new_len = int(len(peaks) / peaks_size * self.online_max_bytes)
idx = np.random.choice(np.arange(len(peaks)), replace=False, size=new_len)
-data = peaks[np.sort(idx)]
+data = peaks[strax.stable_sort(idx)]

else: # peaks_size <= self.max_bytes:
data = peaks
9 changes: 3 additions & 6 deletions straxen/plugins/merged_s2s/merged_s2s.py
@@ -1,12 +1,10 @@
from typing import Tuple

+import numpy as np
+import numba
import strax
import straxen
from straxen.plugins.peaklets.peaklets import drop_data_field

-import numpy as np
-import numba


export, __all__ = strax.exporter()

@@ -171,7 +169,6 @@ def get_merge_instructions(
max_duration,
max_gap,
max_area,
-sort_kind="mergesort",
):
"""
Finding the group of peaklets to merge. To do this start with the
@@ -189,7 +186,7 @@
peaklet_start_index = np.arange(len(peaklet_starts))
peaklet_end_index = np.arange(len(peaklet_starts))

-for gap_i in np.argsort(peaklet_gaps, kind=sort_kind):
+for gap_i in strax.stable_argsort(peaklet_gaps):
start_idx = peaklet_start_index[gap_i]
inclusive_end_idx = peaklet_end_index[gap_i + 1]
sum_area = np.sum(areas[start_idx : inclusive_end_idx + 1])
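Dropping the sort_kind="mergesort" argument from get_merge_instructions should be behaviour-preserving, assuming strax.stable_argsort also uses a stable kind: the permutation produced by a stable sort is uniquely determined (ties are broken by original position), so any two stable kinds agree. A quick NumPy check of that equivalence:

```python
import numpy as np

# "mergesort" and "stable" are both stable sort kinds in NumPy, so the resulting
# permutation is identical and the gap-merging order is unchanged.
gaps = np.array([7, 3, 3, 9, 3])
assert np.array_equal(
    np.argsort(gaps, kind="mergesort"),
    np.argsort(gaps, kind="stable"),
)
```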
7 changes: 3 additions & 4 deletions straxen/plugins/peaklets/peaklets.py
@@ -1,9 +1,8 @@
from typing import Dict, Tuple, Union

import numba
import numpy as np
-import strax
from immutabledict import immutabledict
+import strax
from strax.processing.general import _touching_windows
from strax.dtypes import DIGITAL_SUM_WAVEFORM_CHANNEL
import straxen
@@ -321,7 +320,7 @@ def compute(self, records, start, end):
# (a) doing hitfinding yet again (or storing hits)
# (b) increase strax memory usage / max_messages,
# possibly due to its currently primitive scheduling.
-hit_max_times_argsort = np.argsort(hitlets["max_time"])
+hit_max_times_argsort = strax.stable_argsort(hitlets["max_time"])
sorted_hit_max_times = hitlets["max_time"][hit_max_times_argsort]
sorted_hit_channels = hitlets["channel"][hit_max_times_argsort]
peaklet_max_times = peaklets["time"] + np.argmax(peaklets["data"], axis=1) * peaklets["dt"]
@@ -416,7 +415,7 @@ def add_hit_features(hitlets, peaklets):
"""Create hits timing features."""
split_hits = strax.split_by_containment(hitlets, peaklets)
for peaklet, h_max in zip(peaklets, split_hits):
-max_time_diff = np.diff(np.sort(h_max["max_time"]))
+max_time_diff = np.diff(strax.stable_sort(h_max["max_time"]))
if len(max_time_diff) > 0:
peaklet["max_diff"] = max_time_diff.max()
peaklet["min_diff"] = max_time_diff.min()
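For context, the peaklet_max_times line in the compute hunk above locates each peaklet's waveform maximum in absolute time: the sample index of the largest amplitude, multiplied by the sample width dt, offset from the peaklet's start time. A minimal, self-contained illustration (dtype reduced to the fields needed; values are made up):

```python
import numpy as np

peaklets = np.zeros(1, dtype=[("time", np.int64), ("dt", np.int16), ("data", np.float32, 5)])
peaklets["time"] = 1000                       # start time in ns
peaklets["dt"] = 10                           # sample width in ns
peaklets["data"] = [0.0, 1.0, 4.0, 2.0, 0.0]  # waveform; maximum at sample index 2

# Index of the maximum sample, converted to an absolute time.
peaklet_max_times = peaklets["time"] + np.argmax(peaklets["data"], axis=1) * peaklets["dt"]
print(peaklet_max_times)  # [1020]: 2 samples of 10 ns after the 1000 ns start
```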
4 changes: 2 additions & 2 deletions straxen/plugins/peaks/peak_ambience.py
@@ -76,8 +76,8 @@ def infer_dtype(self):
return dtype

def compute(self, lone_hits, peaks):
-argsort = np.argsort(peaks["center_time"], kind="mergesort")
-_peaks = np.sort(peaks, order="center_time")
+argsort = strax.stable_argsort(peaks["center_time"])
+_peaks = strax.stable_sort(peaks, order="center_time")
result = np.zeros(len(peaks), self.dtype)
_quick_assign(argsort, result, self.compute_ambience(lone_hits, peaks, _peaks))
return result
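Several peak-level plugins (peak_ambience, peak_nearest_triggering, peak_se_sensity, peak_shadow) share the pattern visible in this hunk: sort a copy of the peaks by center_time, compute on the sorted copy, then scatter the results back into the original (time-ordered) positions via the argsort indices. The sketch below uses an assumed, simplified _quick_assign; the real helper in peak_ambience.py may differ in detail.

```python
import numpy as np

# Assumed, simplified stand-in for _quick_assign: write values computed in
# center_time order back into the positions of the original peak array.
def _quick_assign(indices, result, values):
    for i, v in zip(indices, values):
        result[i] = v

peaks = np.zeros(3, dtype=[("center_time", np.int64)])
peaks["center_time"] = [30, 10, 20]

order = np.argsort(peaks["center_time"], kind="stable")      # strax.stable_argsort
_peaks = np.sort(peaks, order="center_time", kind="stable")  # strax.stable_sort
values = _peaks["center_time"] * 2.0                         # stand-in for the real computation

result = np.zeros(len(peaks))
_quick_assign(order, result, values)
print(result)  # [60. 20. 40.]: results aligned with the original peak order
```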
7 changes: 4 additions & 3 deletions straxen/plugins/peaks/peak_nearest_triggering.py
@@ -2,6 +2,7 @@
import numba
import strax
import straxen

from .peak_ambience import _quick_assign
from ..events import Events

@@ -60,16 +61,16 @@ def get_window_size(self):
return 10 * self.shadow_time_window_backward

def compute(self, peaks):
-argsort = np.argsort(peaks["center_time"], kind="mergesort")
-_peaks = np.sort(peaks, order="center_time")
+argsort = strax.stable_argsort(peaks["center_time"])
+_peaks = strax.stable_sort(peaks, order="center_time")
result = np.zeros(len(peaks), self.dtype)
_quick_assign(argsort, result, self.compute_triggering(peaks, _peaks))
return result

def compute_triggering(self, peaks, current_peak):
# sort peaks by center_time,
# because later we will use center_time to find the nearest peak
-_peaks = np.sort(peaks, order="center_time")
+_peaks = strax.stable_sort(peaks, order="center_time")
# only looking at triggering peaks
if self.only_trigger_min_area:
_is_triggering = _peaks["area"] > self.trigger_min_area
5 changes: 2 additions & 3 deletions straxen/plugins/peaks/peak_per_event.py
@@ -1,6 +1,5 @@
-import strax
import numpy as np

+import strax
import straxen

export, __all__ = strax.exporter()
@@ -39,7 +38,7 @@ def compute(self, events, peaks):
sp["center_time"] - event["s1_center_time"]
)
# Start of new part
-sorted_indices_split_peaks_ind = np.argsort(split_peaks_ind)
+sorted_indices_split_peaks_ind = strax.stable_argsort(split_peaks_ind)
mapping = {
val: events["event_number"][i]
for val, i in zip(
5 changes: 2 additions & 3 deletions straxen/plugins/peaks/peak_se_sensity.py
@@ -1,6 +1,5 @@
import numpy as np
from numba import njit

import strax
import straxen

@@ -105,8 +104,8 @@ def compute_se_density(self, peaks, _peaks):

def compute(self, peaks):
# sort peaks by center_time
-argsort = np.argsort(peaks["center_time"], kind="mergesort")
-_peaks = np.sort(peaks, order="center_time")
+argsort = strax.stable_argsort(peaks["center_time"])
+_peaks = strax.stable_sort(peaks, order="center_time")

# prepare output
se_density = np.zeros(len(peaks))
7 changes: 4 additions & 3 deletions straxen/plugins/peaks/peak_shadow.py
@@ -1,10 +1,11 @@
import numpy as np
import numba
from scipy.stats import halfcauchy
-from .peak_ambience import distance_in_xy, _quick_assign
import strax
import straxen

+from .peak_ambience import distance_in_xy, _quick_assign

export, __all__ = strax.exporter()


@@ -144,8 +145,8 @@ def shadowdtype(self):
return dtype

def compute(self, peaks):
-argsort = np.argsort(peaks["center_time"], kind="mergesort")
-_peaks = np.sort(peaks, order="center_time")
+argsort = strax.stable_argsort(peaks["center_time"])
+_peaks = strax.stable_sort(peaks, order="center_time")
result = np.zeros(len(peaks), self.dtype)
_quick_assign(argsort, result, self.compute_shadow(peaks, _peaks))
return result
2 changes: 1 addition & 1 deletion straxen/plugins/veto_intervals/veto_intervals.py
@@ -88,7 +88,7 @@ def compute(self, aqmon_hits, start, end):

result = result[:vetos_seen]
result["veto_interval"] = result["endtime"] - result["time"]
-sort = np.argsort(result["time"])
+sort = strax.stable_argsort(result["time"])
result = result[sort]
return result

3 changes: 1 addition & 2 deletions straxen/scada.py
@@ -1,5 +1,4 @@
import warnings

import urllib
import time
from datetime import datetime
@@ -206,7 +205,7 @@ def _get_and_check_start_end(self, run_id, start, end, time_selection_kwargs):
# User specified a valid context and run_id, so get the start
# and end time for our query:
if isinstance(run_id, (list, tuple)):
-run_id = np.sort(run_id) # Do not trust the user's
+run_id = strax.stable_sort(run_id) # Do not trust the user's
start, _ = self.context.to_absolute_time_range(run_id[0], **time_selection_kwargs)
_, end = self.context.to_absolute_time_range(run_id[-1], **time_selection_kwargs)
else:
5 changes: 2 additions & 3 deletions straxen/test_utils.py
@@ -1,7 +1,6 @@
import io
import os
from warnings import warn
-from os import environ as os_environ
from typing import Tuple
import tarfile
from immutabledict import immutabledict
@@ -65,7 +64,7 @@ def is_installed(module):
@export
def _is_on_pytest():
"""Check if we are on a pytest."""
return "PYTEST_CURRENT_TEST" in os_environ
return "PYTEST_CURRENT_TEST" in os.environ


@export
@@ -146,7 +145,7 @@ def create_unique_intervals(size, time_range=(0, 40), allow_zero_length=True):


def _convert_to_interval(time_stamps, allow_zero_length):
-time_stamps = np.sort(time_stamps)
+time_stamps = strax.stable_sort(time_stamps)
intervals = np.zeros(len(time_stamps) // 2, strax.time_dt_fields)
intervals["dt"] = 1
intervals["time"] = time_stamps[::2]
2 changes: 1 addition & 1 deletion tests/test_peaklet_processing.py
@@ -46,7 +46,7 @@ def get_filled_peaks(peak_length, data_length, n_widths):
strat.lists(strat.integers(min_value=0, max_value=10), min_size=8, max_size=8, unique=True),
)
def test_create_outside_peaks_region(time):
-time = np.sort(time)
+time = strax.stable_sort(time)
time_intervals = np.zeros(len(time) // 2, strax.time_dt_fields)
time_intervals["time"] = time[::2]
time_intervals["length"] = time[1::2] - time[::2]
