Skip to content

Commit

Permalink
speedup rolling min and max (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored May 7, 2024
1 parent e83483c commit 77a586c
Show file tree
Hide file tree
Showing 5 changed files with 102 additions and 48 deletions.
4 changes: 2 additions & 2 deletions include/expanding.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ inline void ExpandingStdTransform(const T *data, int n, T *out, T *agg) {

template <typename T> struct ExpandingMinTransform {
void operator()(const T *data, int n, T *out) {
RollingCompTransform(std::less<T>(), data, n, out, n, 1);
RollingMinTransform<T>(data, n, out, n, 1);
}
};

template <typename T> struct ExpandingMaxTransform {
void operator()(const T *data, int n, T *out) {
RollingCompTransform(std::greater<T>(), data, n, out, n, 1);
RollingMaxTransform<T>(data, n, out, n, 1);
}
};

Expand Down
124 changes: 90 additions & 34 deletions include/rolling.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,47 +67,103 @@ inline void RollingStdTransform(const T *data, int n, T *out, int window_size,
min_samples);
}

template <typename Func, typename T>
inline void RollingCompTransform(Func Comp, const T *data, int n, T *out,
int window_size, int min_samples) {
template <typename T, typename Comp> class SortedDeque {
public:
SortedDeque(int window_size, Comp comp = Comp())
: window_size_(window_size), comp_(comp) {
buffer_.reserve(window_size);
}
inline bool empty() const noexcept { return tail_ == -1; }
inline void push_back(int i, T x) noexcept {
if (tail_ == -1) {
head_ = 0;
tail_ = 0;
} else if (tail_ == window_size_ - 1) {
tail_ = 0;
} else {
++tail_;
}
buffer_[tail_] = {i, x};
}
inline void pop_back() noexcept {
if (head_ == tail_) {
head_ = 0;
tail_ = -1;
} else if (tail_ == 0) {
tail_ = window_size_ - 1;
} else {
--tail_;
}
}
inline void pop_front() noexcept {
if (head_ == tail_) {
head_ = 0;
tail_ = -1;
} else if (head_ == window_size_ - 1) {
head_ = 0;
} else {
++head_;
}
}
inline const std::pair<int, T> &front() const noexcept {
return buffer_[head_];
}
inline const std::pair<int, T> &back() const noexcept {
return buffer_[tail_];
}
void update(T x) noexcept {
while (!empty() && comp_(back().second, x)) {
pop_back();
}
if (!empty() && front().first <= i_) {
pop_front();
}
push_back(window_size_ + i_, x);
++i_;
}
T get() const noexcept { return front().second; }

private:
std::vector<std::pair<int, T>> buffer_;
int window_size_;
int head_ = 0;
int tail_ = -1;
int i_ = 0;
Comp comp_;
};

template <typename T, typename Comp>
inline void RollingCompTransform(const T *data, int n, T *out, int window_size,
int min_samples) {
int upper_limit = std::min(window_size, n);
T pivot = data[0];
SortedDeque<T, Comp> sdeque(window_size);
for (int i = 0; i < upper_limit; ++i) {
if (Comp(data[i], pivot)) {
pivot = data[i];
}
sdeque.update(data[i]);
if (i + 1 < min_samples) {
out[i] = std::numeric_limits<T>::quiet_NaN();
} else {
out[i] = pivot;
out[i] = sdeque.get();
}
}
for (int i = window_size; i < n; ++i) {
pivot = data[i];
for (int j = 0; j < window_size; ++j) {
if (Comp(data[i - j], pivot)) {
pivot = data[i - j];
}
}
out[i] = pivot;
for (int i = upper_limit; i < n; ++i) {
sdeque.update(data[i]);
out[i] = sdeque.get();
}
}

template <typename T> struct RollingMinTransform {
void operator()(const T *data, int n, T *out, int window_size,
int min_samples) {
RollingCompTransform(std::less<T>(), data, n, out, window_size,
min_samples);
}
};
template <typename T>
void RollingMinTransform(const T *data, int n, T *out, int window_size,
int min_samples) {
RollingCompTransform<T, std::greater_equal<T>>(data, n, out, window_size,
min_samples);
}

template <typename T> struct RollingMaxTransform {
void operator()(const T *data, int n, T *out, int window_size,
int min_samples) const {
RollingCompTransform(std::greater<T>(), data, n, out, window_size,
min_samples);
}
};
template <typename T>
void RollingMaxTransform(const T *data, int n, T *out, int window_size,
int min_samples) {
RollingCompTransform<T, std::less_equal<T>>(data, n, out, window_size,
min_samples);
}

template <typename T>
inline void RollingQuantileTransform(const T *data, int n, T *out,
Expand Down Expand Up @@ -171,15 +227,15 @@ template <typename T> struct SeasonalRollingStdTransform {
template <typename T> struct SeasonalRollingMinTransform {
void operator()(const T *data, int n, T *out, int season_length,
int window_size, int min_samples) {
SeasonalRollingTransform(RollingMinTransform<T>(), data, n, out,
SeasonalRollingTransform(RollingMinTransform<T>, data, n, out,
season_length, window_size, min_samples);
}
};

template <typename T> struct SeasonalRollingMaxTransform {
void operator()(const T *data, int n, T *out, int season_length,
int window_size, int min_samples) {
SeasonalRollingTransform(RollingMaxTransform<T>(), data, n, out,
SeasonalRollingTransform(RollingMaxTransform<T>, data, n, out,
season_length, window_size, min_samples);
}
};
Expand Down Expand Up @@ -226,15 +282,15 @@ template <typename T> struct RollingStdUpdate {
template <typename T> struct RollingMinUpdate {
void operator()(const T *data, int n, T *out, int window_size,
int min_samples) {
RollingUpdate(RollingMinTransform<T>(), data, n, out, window_size,
RollingUpdate(RollingMinTransform<T>, data, n, out, window_size,
min_samples);
}
};

template <typename T> struct RollingMaxUpdate {
void operator()(const T *data, int n, T *out, int window_size,
int min_samples) {
RollingUpdate(RollingMaxTransform<T>(), data, n, out, window_size,
RollingUpdate(RollingMaxTransform<T>, data, n, out, window_size,
min_samples);
}
};
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "coreforecast"
version = "0.0.8"
version = "0.0.9"
requires-python = ">=3.8"
dependencies = [
"importlib_resources ; python_version < '3.10'",
Expand Down
2 changes: 1 addition & 1 deletion python/coreforecast/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.0.8"
__version__ = "0.0.9"
18 changes: 8 additions & 10 deletions src/c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ int Float32_RollingStdTransform(float *data, int length, int window_size,
}
int Float32_RollingMinTransform(float *data, int length, int window_size,
int min_samples, float *out) {
RollingMinTransform<float>()(data, length, out, window_size, min_samples);
RollingMinTransform<float>(data, length, out, window_size, min_samples);
return 0;
}
int Float32_RollingMaxTransform(float *data, int length, int window_size,
int min_samples, float *out) {
RollingMaxTransform<float>()(data, length, out, window_size, min_samples);
RollingMaxTransform<float>(data, length, out, window_size, min_samples);
return 0;
}
int Float32_RollingQuantileTransform(float *data, int length, float p,
Expand Down Expand Up @@ -215,16 +215,14 @@ int GroupedArrayFloat32_RollingMinTransform(GroupedArrayHandle handle, int lag,
int window_size, int min_samples,
float *out) {
auto ga = reinterpret_cast<GroupedArray<float> *>(handle);
ga->Transform(RollingMinTransform<float>(), lag, out, window_size,
min_samples);
ga->Transform(RollingMinTransform<float>, lag, out, window_size, min_samples);
return 0;
}
int GroupedArrayFloat32_RollingMaxTransform(GroupedArrayHandle handle, int lag,
int window_size, int min_samples,
float *out) {
auto ga = reinterpret_cast<GroupedArray<float> *>(handle);
ga->Transform(RollingMaxTransform<float>(), lag, out, window_size,
min_samples);
ga->Transform(RollingMaxTransform<float>, lag, out, window_size, min_samples);
return 0;
}
int GroupedArrayFloat32_RollingQuantileTransform(GroupedArrayHandle handle,
Expand Down Expand Up @@ -542,12 +540,12 @@ int Float64_RollingStdTransform(double *data, int length, int window_size,
}
int Float64_RollingMinTransform(double *data, int length, int window_size,
int min_samples, double *out) {
RollingMinTransform<double>()(data, length, out, window_size, min_samples);
RollingMinTransform<double>(data, length, out, window_size, min_samples);
return 0;
}
int Float64_RollingMaxTransform(double *data, int length, int window_size,
int min_samples, double *out) {
RollingMaxTransform<double>()(data, length, out, window_size, min_samples);
RollingMaxTransform<double>(data, length, out, window_size, min_samples);
return 0;
}
int Float64_RollingQuantileTransform(double *data, int length, double p,
Expand Down Expand Up @@ -744,15 +742,15 @@ int GroupedArrayFloat64_RollingMinTransform(GroupedArrayHandle handle, int lag,
int window_size, int min_samples,
double *out) {
auto ga = reinterpret_cast<GroupedArray<double> *>(handle);
ga->Transform(RollingMinTransform<double>(), lag, out, window_size,
ga->Transform(RollingMinTransform<double>, lag, out, window_size,
min_samples);
return 0;
}
int GroupedArrayFloat64_RollingMaxTransform(GroupedArrayHandle handle, int lag,
int window_size, int min_samples,
double *out) {
auto ga = reinterpret_cast<GroupedArray<double> *>(handle);
ga->Transform(RollingMaxTransform<double>(), lag, out, window_size,
ga->Transform(RollingMaxTransform<double>, lag, out, window_size,
min_samples);
return 0;
}
Expand Down

0 comments on commit 77a586c

Please sign in to comment.