Skip to content

Commit

Permalink
Avoid checking NaNs for Value_ types that don't have them.
Browse files (browse the repository at this point in the history)
  • Loading branch information
LTLA committed Jun 27, 2024
1 parent 5b682e1 commit fb5467e
Show file tree
Hide file tree
Showing 7 changed files with 362 additions and 258 deletions.
48 changes: 28 additions & 20 deletions include/tatami_stats/grouped_sums.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,22 @@ void apply(bool row, const tatami::Matrix<Value_, Index_>* p, const Group_* grou
auto range = ext->fetch(xbuffer.data(), ibuffer.data());
std::fill(tmp.begin(), tmp.end(), static_cast<Output_>(0));

if (sopt.skip_nan) {
for (int j = 0; j < range.number; ++j) {
auto val = range.value[j];
if (!std::isnan(val)) {
tmp[group[range.index[j]]] += val;
internal::nanable_ifelse<Value_>(
sopt.skip_nan,
[&]() {
for (int j = 0; j < range.number; ++j) {
auto val = range.value[j];
if (!std::isnan(val)) {
tmp[group[range.index[j]]] += val;
}
}
},
[&]() {
for (int j = 0; j < range.number; ++j) {
tmp[group[range.index[j]]] += range.value[j];
}
}
} else {
for (int j = 0; j < range.number; ++j) {
tmp[group[range.index[j]]] += range.value[j];
}
}
);

for (size_t g = 0; g < num_groups; ++g) {
output[g][i + start] = tmp[g];
Expand Down Expand Up @@ -139,18 +143,22 @@ void apply(bool row, const tatami::Matrix<Value_, Index_>* p, const Group_* grou
auto ptr = ext->fetch(xbuffer.data());
std::fill(tmp.begin(), tmp.end(), static_cast<Output_>(0));

if (sopt.skip_nan) {
for (Index_ j = 0; j < otherdim; ++j) {
auto val = ptr[j];
if (!std::isnan(val)) {
tmp[group[j]] += val;
internal::nanable_ifelse<Value_>(
sopt.skip_nan,
[&]() {
for (Index_ j = 0; j < otherdim; ++j) {
auto val = ptr[j];
if (!std::isnan(val)) {
tmp[group[j]] += val;
}
}
},
[&]() {
for (Index_ j = 0; j < otherdim; ++j) {
tmp[group[j]] += ptr[j];
}
}
} else {
for (Index_ j = 0; j < otherdim; ++j) {
tmp[group[j]] += ptr[j];
}
}
);

for (size_t g = 0; g < num_groups; ++g) {
output[g][i + start] = tmp[g];
Expand Down
146 changes: 76 additions & 70 deletions include/tatami_stats/grouped_variances.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,42 +112,45 @@ void direct(
std::fill_n(output_means, num_groups, 0);
std::fill_n(output_variances, num_groups, 0);

if (skip_nan) {
std::fill_n(valid_group_size, num_groups, 0);

for (Index_ j = 0; j < num; ++j) {
auto x = ptr[j];
if (!std::isnan(x)) {
auto b = group[j];
output_means[b] += x;
++valid_group_size[b];
::tatami_stats::internal::nanable_ifelse<Value_>(
skip_nan,
[&]() {
std::fill_n(valid_group_size, num_groups, 0);

for (Index_ j = 0; j < num; ++j) {
auto x = ptr[j];
if (!std::isnan(x)) {
auto b = group[j];
output_means[b] += x;
++valid_group_size[b];
}
}
}
internal::finish_means(num_groups, valid_group_size, output_means);
internal::finish_means(num_groups, valid_group_size, output_means);

for (Index_ j = 0; j < num; ++j) {
auto x = ptr[j];
if (!std::isnan(x)) {
auto b = group[j];
auto delta = x - output_means[b];
output_variances[b] += delta * delta;
}
}
internal::finish_variances(num_groups, valid_group_size, output_variances);
},
[&]() {
for (Index_ j = 0; j < num; ++j) {
output_means[group[j]] += ptr[j];
}
internal::finish_means(num_groups, group_size, output_means);

for (Index_ j = 0; j < num; ++j) {
auto x = ptr[j];
if (!std::isnan(x)) {
for (Index_ j = 0; j < num; ++j) {
auto b = group[j];
auto delta = x - output_means[b];
auto delta = ptr[j] - output_means[b];
output_variances[b] += delta * delta;
}
internal::finish_variances(num_groups, group_size, output_variances);
}
internal::finish_variances(num_groups, valid_group_size, output_variances);

} else {
for (Index_ j = 0; j < num; ++j) {
output_means[group[j]] += ptr[j];
}
internal::finish_means(num_groups, group_size, output_means);

for (Index_ j = 0; j < num; ++j) {
auto b = group[j];
auto delta = ptr[j] - output_means[b];
output_variances[b] += delta * delta;
}
internal::finish_variances(num_groups, group_size, output_variances);
}
);
}

/**
Expand Down Expand Up @@ -200,52 +203,55 @@ void direct(
std::fill_n(output_nonzero, num_groups, 0);
std::fill_n(output_variances, num_groups, 0);

if (skip_nan) {
std::copy_n(group_size, num_groups, valid_group_size);

for (Index_ j = 0; j < num_nonzero; ++j) {
auto x = value[j];
auto b = group[index[j]];
if (!std::isnan(x)) {
output_means[b] += x;
++(output_nonzero[b]);
} else {
--(valid_group_size[b]);
::tatami_stats::internal::nanable_ifelse<Value_>(
skip_nan,
[&]() {
std::copy_n(group_size, num_groups, valid_group_size);

for (Index_ j = 0; j < num_nonzero; ++j) {
auto x = value[j];
auto b = group[index[j]];
if (!std::isnan(x)) {
output_means[b] += x;
++(output_nonzero[b]);
} else {
--(valid_group_size[b]);
}
}
}
internal::finish_means(num_groups, valid_group_size, output_means);
internal::finish_means(num_groups, valid_group_size, output_means);

for (Index_ j = 0; j < num_nonzero; ++j) {
auto x = value[j];
if (!std::isnan(x)) {
auto b = group[index[j]];
auto delta = x - output_means[b];
output_variances[b] += delta * delta;
}
}
for (size_t b = 0; b < num_groups; ++b) {
output_variances[b] += output_means[b] * output_means[b] * (valid_group_size[b] - output_nonzero[b]);
}
internal::finish_variances(num_groups, valid_group_size, output_variances);
},
[&]() {
for (Index_ j = 0; j < num_nonzero; ++j) {
auto b = group[index[j]];
output_means[b] += value[j];
++output_nonzero[b];
}
internal::finish_means(num_groups, group_size, output_means);

for (Index_ j = 0; j < num_nonzero; ++j) {
auto x = value[j];
if (!std::isnan(x)) {
for (Index_ j = 0; j < num_nonzero; ++j) {
auto b = group[index[j]];
auto delta = x - output_means[b];
auto delta = value[j] - output_means[b];
output_variances[b] += delta * delta;
}
for (size_t b = 0; b < num_groups; ++b) {
output_variances[b] += output_means[b] * output_means[b] * (group_size[b] - output_nonzero[b]);
}
internal::finish_variances(num_groups, group_size, output_variances);
}
for (size_t b = 0; b < num_groups; ++b) {
output_variances[b] += output_means[b] * output_means[b] * (valid_group_size[b] - output_nonzero[b]);
}
internal::finish_variances(num_groups, valid_group_size, output_variances);

} else {
for (Index_ j = 0; j < num_nonzero; ++j) {
auto b = group[index[j]];
output_means[b] += value[j];
++output_nonzero[b];
}
internal::finish_means(num_groups, group_size, output_means);

for (Index_ j = 0; j < num_nonzero; ++j) {
auto b = group[index[j]];
auto delta = value[j] - output_means[b];
output_variances[b] += delta * delta;
}
for (size_t b = 0; b < num_groups; ++b) {
output_variances[b] += output_means[b] * output_means[b] * (group_size[b] - output_nonzero[b]);
}
internal::finish_variances(num_groups, group_size, output_variances);
}
);
}

/**
Expand Down
31 changes: 20 additions & 11 deletions include/tatami_stats/medians.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#define TATAMI_STATS__MEDIANS_HPP

#include "tatami/tatami.hpp"
#include "utils.hpp"

#include <cmath>
#include <vector>
Expand Down Expand Up @@ -78,11 +79,15 @@ Index_ translocate_nans(Value_* ptr, Index_& num) {
*/
template<typename Output_ = double, typename Value_, typename Index_>
Output_ direct(Value_* ptr, Index_ num, bool skip_nan) {
if (skip_nan) {
auto lost = internal::translocate_nans(ptr, num);
ptr += lost;
num -= lost;
}
::tatami_stats::internal::nanable_ifelse<Value_>(
skip_nan,
[&]() {
auto lost = internal::translocate_nans(ptr, num);
ptr += lost;
num -= lost;
},
[&]() {}
);

if (num == 0) {
return std::numeric_limits<Output_>::quiet_NaN();
Expand Down Expand Up @@ -125,12 +130,16 @@ Output_ direct(Value_* value, Index_ num_nonzero, Index_ num_all, bool skip_nan)
return direct<Output_>(value, num_all, skip_nan);
}

if (skip_nan) {
auto lost = internal::translocate_nans(value, num_nonzero);
value += lost;
num_nonzero -= lost;
num_all -= lost;
}
::tatami_stats::internal::nanable_ifelse<Value_>(
skip_nan,
[&]() {
auto lost = internal::translocate_nans(value, num_nonzero);
value += lost;
num_nonzero -= lost;
num_all -= lost;
},
[&]() {}
);

// Is the number of non-zeros less than the number of zeros?
// If so, the median must be zero. Note that we calculate it
Expand Down
64 changes: 36 additions & 28 deletions include/tatami_stats/ranges.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,26 +94,30 @@ bool is_better(Output_ best, Value_ alt) {
*/
template<bool minimum_, typename Value_, typename Index_>
Value_ direct(const Value_* ptr, Index_ num, bool skip_nan) {
if (skip_nan) {
auto current = internal::choose_placeholder<minimum_, Value_>();
for (Index_ i = 0; i < num; ++i) {
auto val = ptr[i];
if (internal::is_better<minimum_>(current, val)) { // no need to explicitly handle NaNs, as any comparison with NaNs is always false.
current = val;
return ::tatami_stats::internal::nanable_ifelse_with_value<Value_>(
skip_nan,
[&]() -> Value_ {
auto current = internal::choose_placeholder<minimum_, Value_>();
for (Index_ i = 0; i < num; ++i) {
auto val = ptr[i];
if (internal::is_better<minimum_>(current, val)) { // no need to explicitly handle NaNs, as any comparison with NaNs is always false.
current = val;
}
}
return current;
},
[&]() -> Value_ {
if (num) {
if constexpr(minimum_) {
return *std::min_element(ptr, ptr + num);
} else {
return *std::max_element(ptr, ptr + num);
}
} else {
return internal::choose_placeholder<minimum_, Value_>();
}
}
return current;

} else if (num) {
if constexpr(minimum_) {
return *std::min_element(ptr, ptr + num);
} else {
return *std::max_element(ptr, ptr + num);
}

} else {
return internal::choose_placeholder<minimum_, Value_>();
}
);
}

/**
Expand Down Expand Up @@ -182,18 +186,22 @@ class RunningDense {
void add(const Value_* ptr) {
if (my_init) {
my_init = false;
if (my_skip_nan) {
for (Index_ i = 0; i < my_num; ++i, ++ptr) {
auto val = *ptr;
if (std::isnan(val)) {
my_store[i] = internal::choose_placeholder<minimum_, Value_>();
} else {
my_store[i] = val;
::tatami_stats::internal::nanable_ifelse<Value_>(
my_skip_nan,
[&]() {
for (Index_ i = 0; i < my_num; ++i, ++ptr) {
auto val = *ptr;
if (std::isnan(val)) {
my_store[i] = internal::choose_placeholder<minimum_, Value_>();
} else {
my_store[i] = val;
}
}
},
[&]() {
std::copy_n(ptr, my_num, my_store);
}
} else {
std::copy_n(ptr, my_num, my_store);
}
);

} else {
for (Index_ i = 0; i < my_num; ++i, ++ptr) {
Expand Down
Loading

0 comments on commit fb5467e

Please sign in to comment.