Skip to content

Commit

Permalink
roi metric sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
sronilsson committed Sep 10, 2024
1 parent 96c2031 commit 8b2c78e
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 10 deletions.
6 changes: 6 additions & 0 deletions docs/simba.gpu_helpers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,11 @@ Circular statistics
Image
---------------------
.. automodule:: simba.data_processors.cuda.image
:members:
:show-inheritance:

Time-series statistics
---------------------
.. automodule:: simba.data_processors.cuda.timeseries
:members:
:show-inheritance:
117 changes: 117 additions & 0 deletions simba/data_processors/cuda/timeseries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import numpy as np
from typing import Optional
from time import perf_counter
from simba.utils.enums import Formats
from simba.utils.checks import check_valid_array, check_float
from numba import cuda
from simba.data_processors.cuda.utils import _cuda_mean, _cuda_std


THREADS_PER_BLOCK = 1024

@cuda.jit(device=True)
def _count_at_threshold(x: np.ndarray, inverse: int, threshold: float):
results = 0
for i in range(x.shape[0]):
if inverse[0] == 0:
if x[i] > threshold[0]:
results += 1
else:
if x[i] < threshold[0]:
results += 1
return results

@cuda.jit()
def _sliding_crossings_kernal(data, time, threshold, inverse, results):
r = cuda.grid(1)
l = int(r - time[0])
if r > data.shape[0] or r < 0:
return
elif l > data.shape[0] or l < 0:
return
else:
sample = data[l:r]
results[r-1] = _count_at_threshold(sample, inverse, threshold)

def sliding_threshold(data: np.ndarray, time_window: float, sample_rate: float, value: float, inverse: Optional[bool] = False) -> np.ndarray:
"""
Compute the count of observations above or below threshold crossings over a sliding window using GPU acceleration.
:param np.ndarray data: Input data array.
:param float time_window: Size of the sliding window in seconds.
:param float sample_rate: Number of samples per second in the data.
:param float value: Threshold value.
:param Optional[bool] inverse: If False, counts values above the threshold. If True, counts values below.
:return: Array containing count of threshold crossings per window.
:rtype: np.ndarray
"""

check_float(name='sample_rate', value=sample_rate, min_value=10e-6)
check_float(name='sample_rate', value=sample_rate, min_value=10e-6)
check_float(name='time_window', value=time_window, min_value=10e-6)
check_valid_array(data=data, source=sliding_threshold.__name__, accepted_ndims=(1,), accepted_dtypes=Formats.NUMERIC_DTYPES.value)

data_dev = cuda.to_device(data)
time_window_frames = np.array([np.ceil(time_window * sample_rate)])
time_window_frames_dev = cuda.to_device(time_window_frames)
value = np.array([value])
invert = np.array([0])
if inverse: invert[0] = 1
value_dev = cuda.to_device(value)
inverse_dev = cuda.to_device(invert)
bpg = (data.shape[0] + (THREADS_PER_BLOCK - 1)) // THREADS_PER_BLOCK
results = cuda.device_array(shape=(data.shape), dtype=np.int32)
_sliding_crossings_kernal[bpg, THREADS_PER_BLOCK](data_dev, time_window_frames_dev, value_dev, inverse_dev, results)
return results.copy_to_host()

@cuda.jit()
def _sliding_percent_beyond_n_std_kernel(data, time, std_n, results):
r = cuda.grid(1)
l = int(r - time[0])
if r > data.shape[0] or r < 0:
return
elif l > data.shape[0] or l < 0:
return
else:
sample = data[l:r]
m = _cuda_mean(sample)
std_val = _cuda_std(sample, m) * std_n[0]
cnt = 0
for i in range(sample.shape[0]):

if (sample[i] > (m + std_val)) or (sample[i] < (m - std_val)):
print(sample[i], m + std_val)
cnt += 1
results[r-1] = cnt

def sliding_percent_beyond_n_std(data: np.ndarray, time_window: float, sample_rate: float, value: float) -> np.ndarray:
"""
Computes the percentage of points in each sliding window of `data` that fall beyond
`n` standard deviations from the mean of that window.
This function uses GPU acceleration via CUDA to efficiently compute the result over large datasets.
:param np.ndarray data: The input 1D data array for which the sliding window computation is to be performed.
:param float time_window: The length of the time window in seconds.
:param float sample_rate: The sample rate of the data in Hz (samples per second).
:param float value: The number of standard deviations beyond which to count data points.
:return: An array containing the count of data points beyond `n` standard deviations for each window.
:rtype: np.ndarray
:example:
>>> data = np.random.randint(0, 100, (100,))
>>> results = sliding_percent_beyond_n_std(data=data, time_window=1, sample_rate=10, value=2)
"""

data_dev = cuda.to_device(data)
time_window_frames = np.array([np.ceil(time_window * sample_rate)])
time_window_frames_dev = cuda.to_device(time_window_frames)
value = np.array([value])
value_dev = cuda.to_device(value)
bpg = (data.shape[0] + (THREADS_PER_BLOCK - 1)) // THREADS_PER_BLOCK
results = cuda.device_array(shape=(data.shape), dtype=np.int32)
_sliding_percent_beyond_n_std_kernel[bpg, THREADS_PER_BLOCK](data_dev, time_window_frames_dev, value_dev, results)
return results.copy_to_host()



49 changes: 49 additions & 0 deletions simba/data_processors/cuda/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import numpy as np
from numba import cuda
import math

@cuda.jit(device=True)
def _cuda_sum(x: np.ndarray):
s = 0
for i in range(x.shape[0]):
s += x[i]
return s


@cuda.jit(device=True)
def _cuda_sin(x, t):
for i in range(x.shape[0]):
v = math.sin(x[i])
t[i] = v
return t

@cuda.jit(device=True)
def _cuda_cos(x, t):
for i in range(x.shape[0]):
v = math.cos(x[i])
t[i] = v
return t

@cuda.jit(device=True)
def _cuda_std(x: np.ndarray, x_hat: float):
std = 0
for i in range(x.shape[0]):
std += (x[0] - x_hat) ** 2
return std

@cuda.jit(device=True)
def _rad2deg(x):
return x * (180/math.pi)

@cuda.jit(device=True)
def _cross_test(x, y, x1, y1, x2, y2):
cross = (x - x1) * (y2 - y1) - (y - y1) * (x2 - x1)
return cross < 0


@cuda.jit(device=True)
def _cuda_mean(x):
s = 0
for i in range(x.shape[0]):
s += x[i]
return s / x.shape[0]
6 changes: 3 additions & 3 deletions simba/feature_extractors/feature_extractor_user_defined.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@

from simba.mixins.config_reader import ConfigReader
from simba.mixins.feature_extraction_mixin import FeatureExtractionMixin
from simba.utils.checks import check_str
from simba.utils.errors import MissingColumnsError
from simba.utils.checks import check_str, check_float
from simba.utils.errors import MissingColumnsError, ParametersFileError
from simba.utils.printing import SimbaTimer, stdout_success
from simba.utils.read_write import get_fn_ext, read_df, write_df

Expand Down Expand Up @@ -192,5 +192,5 @@ def run(self):
stdout_success(f"Feature extraction complete for {str(len(self.files_found))} video(s). Results are saved inside the {self.features_dir} directory", elapsed_time=self.timer.elapsed_time_str,)


# test = UserDefinedFeatureExtractor(config_path=r"C:\troubleshooting\open_field_below\project_folder\project_config.ini")
# test = UserDefinedFeatureExtractor(config_path=r"C:\troubleshooting\two_black_animals_14bp\project_folder\project_config.ini")
# test.run()
12 changes: 8 additions & 4 deletions simba/mixins/config_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import json
import logging
import logging.config
import math
import os
import shutil
from ast import literal_eval
Expand Down Expand Up @@ -760,8 +761,7 @@ def read_video_info_csv(self, file_path: Union[str, os.PathLike]) -> pd.DataFram
)
return info_df

def read_video_info(
self, video_name: str, raise_error: Optional[bool] = True) -> Tuple[pd.DataFrame, float, float]:
def read_video_info(self, video_name: str, raise_error: Optional[bool] = True) -> Tuple[pd.DataFrame, float, float]:
"""
Helper to read the meta-data (pixels per mm, resolution, fps) from the video_info.csv for a single input file.
Expand All @@ -778,13 +778,13 @@ def read_video_info(
]
if len(video_settings) > 1:
raise DuplicationError(
msg=f"SimBA found multiple rows in the project_folder/logs/video_info.csv named {str(video_name)}. Please make sure that each video name is represented ONCE in the video_info.csv",
msg=f"SimBA found multiple rows in the project_folder/logs/video_info.csv named {video_name}. Please make sure that each video name is represented ONCE in the {self.video_info_path} file",
source=self.__class__.__name__,
)
elif len(video_settings) < 1:
if raise_error:
raise ParametersFileError(
msg=f"SimBA could not find {str(video_name)} in the video_info.csv file. Make sure all videos analyzed are represented in the project_folder/logs/video_info.csv file.",
msg=f"SimBA could not find {str(video_name)} in the video_info.csv file. Make sure all videos analyzed are represented in the {self.video_info_path} file.",
source=self.__class__.__name__,
)
else:
Expand All @@ -793,6 +793,10 @@ def read_video_info(
try:
px_per_mm = float(video_settings["pixels/mm"])
fps = float(video_settings["fps"])
if math.isnan(px_per_mm):
raise ParametersFileError(msg=f'Pixels per millimeter for video {video_name} in the {self.video_info_path} file is not a valid number.')
if math.isnan(fps):
raise ParametersFileError(msg=f'The FPS for video {video_name} in the {self.video_info_path} file is not a valid number.')
return video_settings, px_per_mm, fps
except TypeError:
raise ParametersFileError(
Expand Down
4 changes: 1 addition & 3 deletions simba/mixins/timeseries_features_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,9 +388,7 @@ def percent_beyond_n_std(data: np.ndarray, n: float) -> float:

@staticmethod
@njit("(float64[:], float64, float64[:], int64,)", cache=True, fastmath=True)
def sliding_percent_beyond_n_std(
data: np.ndarray, n: float, window_sizes: np.ndarray, sample_rate: int
) -> np.ndarray:
def sliding_percent_beyond_n_std(data: np.ndarray, n: float, window_sizes: np.ndarray, sample_rate: int) -> np.ndarray:
"""
Computed the percentage of data points that exceed 'n' standard deviations from the mean for each position in
the time series using various window sizes. It returns a 2D array where each row corresponds to a position in the time series,
Expand Down

0 comments on commit 8b2c78e

Please sign in to comment.