Skip to content

Commit

Permalink
Use max_time when calculating peaklets properties (#1459)
Browse files Browse the repository at this point in the history
* Use `max_time` when calculating peaklets properties

* Remove redundant code
  • Loading branch information
dachengx authored Oct 24, 2024
1 parent 4b315ae commit 522ac0f
Showing 1 changed file with 12 additions and 50 deletions.
62 changes: 12 additions & 50 deletions straxen/plugins/peaklets/peaklets.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,7 @@ def setup(self):
self.channel_range = self.channel_map["tpc"]

def compute(self, records, start, end):
r = records

hits = strax.find_hits(r, min_amplitude=self.hit_thresholds)
hits = strax.find_hits(records, min_amplitude=self.hit_thresholds)

# Remove hits in zero-gain channels
# they should not affect the clustering!
Expand Down Expand Up @@ -262,7 +260,6 @@ def compute(self, records, start, end):
hitlets["length"] = hitlets["right_integration"] - hitlets["left_integration"]

hitlets = strax.sort_by_time(hitlets)
hitlets_time = np.copy(hitlets["time"])
self.clip_peaklet_times(hitlets, start, end)
rlinks = strax.record_links(records)

Expand All @@ -271,7 +268,7 @@ def compute(self, records, start, end):
strax.sum_waveform(
peaklets,
hitlets,
r,
records,
rlinks,
self.to_pe,
n_top_channels=_n_top_pmts,
Expand All @@ -286,7 +283,7 @@ def compute(self, records, start, end):
peaklets = strax.split_peaks(
peaklets,
hitlets,
r,
records,
rlinks,
self.to_pe,
algorithm="natural_breaks",
Expand All @@ -302,15 +299,10 @@ def compute(self, records, start, end):
# Saturation correction using non-saturated channels
# similar method used in pax
# see https://github.com/XENON1T/pax/pull/712
# Cases when records is not writeable for unclear reason
# only see this when loading 1T test data
# more details on https://numpy.org/doc/stable/reference/generated/numpy.ndarray.flags.html
if not r["data"].flags.writeable:
r = r.copy()

if self.saturation_correction_on:
peak_list = peak_saturation_correction(
r,
records,
rlinks,
peaklets,
hitlets,
Expand All @@ -329,14 +321,8 @@ def compute(self, records, start, end):
# (a) doing hitfinding yet again (or storing hits)
# (b) increase strax memory usage / max_messages,
# possibly due to its currently primitive scheduling.
hitlet_time_shift = (hitlets["left"] - hitlets["left_integration"]) * hitlets["dt"]
hit_max_times = (
hitlets_time + hitlet_time_shift
) # add time shift again to get correct maximum
hit_max_times += hitlets["dt"] * hit_max_sample(records, hitlets)

hit_max_times_argsort = np.argsort(hit_max_times)
sorted_hit_max_times = hit_max_times[hit_max_times_argsort]
hit_max_times_argsort = np.argsort(hitlets["max_time"])
sorted_hit_max_times = hitlets["max_time"][hit_max_times_argsort]
sorted_hit_channels = hitlets["channel"][hit_max_times_argsort]
peaklet_max_times = peaklets["time"] + np.argmax(peaklets["data"], axis=1) * peaklets["dt"]
peaklets["tight_coincidence"] = get_tight_coin(
Expand All @@ -349,10 +335,10 @@ def compute(self, records, start, end):
)

# Add max and min time difference between apexes of hits
self.add_hit_features(hitlets, hit_max_times, peaklets)
self.add_hit_features(hitlets, peaklets)

if self.diagnose_sorting and len(r):
assert np.diff(r["time"]).min(initial=1) >= 0, "Records not sorted"
if self.diagnose_sorting and len(records):
assert np.diff(records["time"]).min(initial=1) >= 0, "Records not sorted"
assert np.diff(hitlets["time"]).min(initial=1) >= 0, "Hits/Hitlets not sorted"
assert np.all(
peaklets["time"][1:] >= strax.endtime(peaklets)[:-1]
Expand Down Expand Up @@ -426,22 +412,9 @@ def create_outside_peaks_region(peaklets, start, end):
return outside_peaks

@staticmethod
def add_hit_features(hitlets, hit_max_times, peaklets):
"""Create hits timing features.
:param hitlets_max: hitlets with only max height time.
:param peaklets: Peaklets for which intervals should be computed.
:return: array of peaklet_timing dtype.
"""
hits_w_max = np.zeros(
len(hitlets),
strax.merged_dtype([np.dtype([("max_time", np.int64)]), np.dtype(strax.time_fields)]),
)
hits_w_max["time"] = hitlets["time"]
hits_w_max["endtime"] = strax.endtime(hitlets)
hits_w_max["max_time"] = hit_max_times
split_hits = strax.split_by_containment(hits_w_max, peaklets)
def add_hit_features(hitlets, peaklets):
"""Create hits timing features."""
split_hits = strax.split_by_containment(hitlets, peaklets)
for peaklet, h_max in zip(peaklets, split_hits):
max_time_diff = np.diff(np.sort(h_max["max_time"]))
if len(max_time_diff) > 0:
Expand Down Expand Up @@ -707,14 +680,3 @@ def get_tight_coin(hit_max_times, hit_channel, peak_max_times, left, right, chan
n_coin_channel[p_i] = np.sum(channels_seen)

return n_coin_channel


@numba.njit(cache=True, nogil=True)
def hit_max_sample(records, hits):
    """Return the index of the maximum sample for hits.

    For each hit, look up its parent record, slice out the hit's
    waveform (``left`` to ``right`` samples), and record the position
    of the largest sample within that slice.

    :param records: array of records; indexed via ``hit["record_i"]``.
    :param hits: array of hits with ``record_i``, ``left`` and ``right``
        fields.
    :return: int16 array, one argmax index per hit.
    """
    n_hits = len(hits)
    result = np.zeros(n_hits, dtype=np.int16)
    for hit_i in range(n_hits):
        hit = hits[hit_i]
        waveform = records[hit["record_i"]]["data"][hit["left"] : hit["right"]]
        result[hit_i] = np.argmax(waveform)
    return result

0 comments on commit 522ac0f

Please sign in to comment.