Skip to content

Commit

Permalink
generalized index permutation function and NaNs can be given to multi…
Browse files Browse the repository at this point in the history
…linear now
  • Loading branch information
Kyle Carow authored and Kyle Carow committed Nov 27, 2023
1 parent 58945f8 commit 8379e4a
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 57 deletions.
1 change: 1 addition & 0 deletions rust/fastsim-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ lazy_static = "1.4.0"
regex = "1.7.1"
rayon = "1.7.0"
include_dir = "0.7.3"
itertools = "0.12.0"

[package.metadata]
include = [
Expand Down
2 changes: 1 addition & 1 deletion rust/fastsim-core/src/imports.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
pub(crate) use anyhow::{anyhow, bail, ensure, Context};
pub(crate) use bincode;
pub(crate) use log;
pub(crate) use ndarray::{array, concatenate, s, Array, Array1, Axis};
pub(crate) use ndarray::{array, s, Array, Array1, Axis};
pub(crate) use serde::{Deserialize, Serialize};
pub(crate) use std::cmp;
pub(crate) use std::ffi::OsStr;
Expand Down
111 changes: 55 additions & 56 deletions rust/fastsim-core/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
//! Module containing miscellaneous utility functions.

use itertools::Itertools;
use lazy_static::lazy_static;
use ndarray::prelude::*;
use ndarray::*;
use regex::Regex;
use std::collections::HashSet;

Expand Down Expand Up @@ -221,60 +222,56 @@ pub fn interpolate_vectors(
yl + dydx * (x - xl)
}

/// Get all possible indeces of an array of length 2 in `n` dimensions. Result will have shape (2<sup>n</sup>, n).
/// Helper function for [`multilinear`] interpolator.
/// Generate all permutations of indices for a given *N*-dimensional array shape
///
/// # Arguments:
/// * `n` - dimensionality
/// # Arguments
/// * `shape` - Reference to shape of the *N*-dimensional array returned by `ndarray::Array::shape()`
///
/// # Returns
/// A `Vec<Vec<usize>>` where each inner `Vec<usize>` is one permutation of indices
///
/// # Example:
/// # Example
/// ```rust
/// use fastsim_core::utils::get_binary_indeces;
/// assert_eq!(
/// get_binary_indeces(2),
/// vec![
/// vec![0, 0],
/// vec![0, 1],
/// vec![1, 0],
/// vec![1, 1],
/// ]
/// );
/// use fastsim_core::utils::get_index_permutations;
/// let shape = [3, 2, 2];
/// assert_eq!(
/// get_binary_indeces(3),
/// vec![
/// vec![0, 0, 0],
/// vec![0, 0, 1],
/// vec![0, 1, 0],
/// vec![0, 1, 1],
/// vec![1, 0, 0],
/// vec![1, 0, 1],
/// vec![1, 1, 0],
/// vec![1, 1, 1],
/// get_index_permutations(&shape),
/// [
/// [0, 0, 0],
/// [0, 0, 1],
/// [0, 1, 0],
/// [0, 1, 1],
/// [1, 0, 0],
/// [1, 0, 1],
/// [1, 1, 0],
/// [1, 1, 1],
/// [2, 0, 0],
/// [2, 0, 1],
/// [2, 1, 0],
/// [2, 1, 1],
/// ]
/// );
/// ```
pub fn get_binary_indeces(n: usize) -> Vec<Vec<usize>> {
let len = 2_usize.pow(n as u32);
let mut indeces = Vec::with_capacity(len);
for i in 0..len {
let mut index = Vec::with_capacity(n);
for j in (0..n).rev() {
index.push(((i >> j) & 1) as usize);
}
indeces.push(index);
}
indeces
pub fn get_index_permutations(shape: &[usize]) -> Vec<Vec<usize>> {
if shape.is_empty() {
return vec![vec![]];
}
shape
.iter()
.map(|&len| 0..len)
.multi_cartesian_product()
.collect()
}

/// Multilinear interpolation function, accepting any dimensionality *`N`*.
/// Multilinear interpolation function, accepting any dimensionality *N*.
///
/// Arguments
///
/// * `point`: interpolation point - specified by *`N`*-length array `&[x, y, z, ...]`
/// * `point`: interpolation point - specified by *N*-length array `&[x, y, z, ...]`
///
/// * `grid`: rectilinear grid points - *`N`*-length array of x, y, z, ... grid coordinate vectors
/// * `grid`: rectilinear grid points - *N*-length array of x, y, z, ... grid coordinate vectors
///
/// * `values`: *`N`*-dimensional `ndarray::ArrayD` containing values at grid points, can be created by calling `into_dyn()`
/// * `values`: *N*-dimensional [`ndarray::ArrayD`] containing values at grid points, can be created by calling [`Array::into_dyn()`]
///
pub fn multilinear(point: &[f64], grid: &[Vec<f64>], values: &ArrayD<f64>) -> anyhow::Result<f64> {
// Dimensionality
Expand All @@ -287,11 +284,7 @@ pub fn multilinear(point: &[f64], grid: &[Vec<f64>], values: &ArrayD<f64>) -> an
);
anyhow::ensure!(
grid.len() == n,
"Supplied `grid` must have same dimensionality as `values`: {grid:?} is not {n}-dimensional",
);
anyhow::ensure!(
!values.iter().any(|&x| x.is_nan()),
"Supplied `values` array cannot contain NaNs",
"Length of supplied `grid` must be same as `values` dimensionality: {grid:?} is not {n}-dimensional",
);
for i in 0..n {
// TODO: This ensure! could be removed if subsetting got rid of length 1 dimensions in `grid` and `points` as well
Expand Down Expand Up @@ -362,31 +355,37 @@ pub fn multilinear(point: &[f64], grid: &[Vec<f64>], values: &ArrayD<f64>) -> an
let mut interp_vals = values_view
.slice_each_axis(|ax| {
let lower = lower_idxs[ax.axis.0];
ndarray::Slice::from(lower..=lower + 1)
Slice::from(lower..=lower + 1)
})
.to_owned();
// Binary is handy as there are 2 surrounding values to index in each dimension: lower and upper
let mut binary_idxs = get_binary_indeces(n);
let mut index_permutations = get_index_permutations(&interp_vals.shape());
// This loop interpolates in each dimension sequentially
// each outer loop iteration the dimensionality reduces by 1
// `interp_vals` ends up as a 0-dimensional array containing only the final interpolated value
for dim in 0..n {
let diff = interp_diffs[dim];
let next_dim = n - 1 - dim;
let next_shape = vec![2; next_dim];
// Indeces used for saving results of this dimensions interpolation results
// assigned to `binary_idxs` at end of loop to be used for indexing in next iteration
let next_idxs = get_binary_indeces(next_dim);
let mut intermediate_arr = Array::default(vec![2; next_dim]);
// assigned to `index_permutations` at end of loop to be used for indexing in next iteration
let next_idxs = get_index_permutations(&next_shape);
let mut intermediate_arr = Array::default(next_shape);
for i in 0..next_idxs.len() {
// `next_idxs` is always half the length of `binary_idxs`
let l = binary_idxs[i].as_slice();
let u = binary_idxs[next_idxs.len() + i].as_slice();
// `next_idxs` is always half the length of `index_permutations`
let l = index_permutations[i].as_slice();
let u = index_permutations[next_idxs.len() + i].as_slice();
if dim == 0 {
anyhow::ensure!(
!interp_vals[l].is_nan() && !interp_vals[u].is_nan(),
"Surrounding value(s) cannot be NaN:\npoint = {point:?},\ngrid = {grid:?},\nvalues = {values:?}"
);
}
// This calculation happens 2^(n-1) times in the first iteration of the outer loop,
// 2^(n-2) times in the second iteration, etc.
intermediate_arr[next_idxs[i].as_slice()] =
interp_vals[l] * (1.0 - diff) + interp_vals[u] * diff;
}
binary_idxs = next_idxs;
index_permutations = next_idxs;
interp_vals = intermediate_arr;
}

Expand Down

0 comments on commit 8379e4a

Please sign in to comment.