Skip to content

Commit

Permalink
Use double rather than float
Browse files Browse the repository at this point in the history
  • Loading branch information
scwatts committed Jun 1, 2018
1 parent b0df99c commit 7da0e76
Show file tree
Hide file tree
Showing 12 changed files with 109 additions and 109 deletions.
8 changes: 4 additions & 4 deletions src/bootstrap.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@


// Get bootstrap samples
arma::Mat<double> get_bootstrap(OtuTable &otu_table, gsl_rng *p_rng) {
arma::Mat<double> bootstrap(otu_table.sample_number, otu_table.otu_number);
arma::Mat<float> get_bootstrap(OtuTable &otu_table, gsl_rng *p_rng) {
arma::Mat<float> bootstrap(otu_table.sample_number, otu_table.otu_number);
for (int j = 0; j < otu_table.otu_number; ++j) {
// Get a random list of sample indices equal to the number of samples
std::vector<arma::uword> indices(otu_table.sample_number);
Expand Down Expand Up @@ -33,7 +33,7 @@ void get_and_write_bootstraps(OtuTable &otu_table, unsigned int bootstrap_number
#pragma omp parallel for schedule(static, 1)
for (unsigned int i = 0; i < bootstrap_number; ++i) {
// Get the bootstrap
arma::Mat<double> bootstrap = get_bootstrap(otu_table, p_rng);
arma::Mat<float> bootstrap = get_bootstrap(otu_table, p_rng);

// Transpose matrix in place before writing out
printf("\tWriting out bootstrapped %i\n", i);
Expand All @@ -47,7 +47,7 @@ void get_and_write_bootstraps(OtuTable &otu_table, unsigned int bootstrap_number
}


void write_out_bootstrap_table(arma::Mat<double> &bootstrap, std::vector<std::string> otu_ids, std::string filepath) {
void write_out_bootstrap_table(arma::Mat<float> &bootstrap, std::vector<std::string> otu_ids, std::string filepath) {
// Get file handle
FILE *filehandle = fopen(filepath.c_str(), "w");

Expand Down
4 changes: 2 additions & 2 deletions src/bootstrap.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@


// Get single bootstrap for OTU table
arma::Mat<double> get_bootstrap(OtuTable &otu_table, gsl_rng *p_rng);
arma::Mat<float> get_bootstrap(OtuTable &otu_table, gsl_rng *p_rng);

// Get n bootstraps for OTU table
void get_and_write_bootstraps(OtuTable &otu_table, unsigned int bootstrap_number, std::string prefix, unsigned int threads, unsigned int seed);

// Write out a bootstrap count table
void write_out_bootstrap_table(arma::Mat<double> &bootstrap, std::vector<std::string> otu_ids, std::string filepath);
void write_out_bootstrap_table(arma::Mat<float> &bootstrap, std::vector<std::string> otu_ids, std::string filepath);


#endif
28 changes: 14 additions & 14 deletions src/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ void OtuTable::load_otu_file(std::string filename) {
std::string line;
std::string ele;
std::stringstream line_stream;
std::vector<double> temp_counts_vector;
std::vector<float> temp_counts_vector;
bool id;
// Open file stream
std::ifstream otu_file;
Expand Down Expand Up @@ -52,25 +52,25 @@ void OtuTable::load_otu_file(std::string filename) {
id = false;
continue;
}
// Add current element to OTU count after converting to double; some OTUs may be corrected and therefore a double
temp_counts_vector.push_back(std::stod(ele));
// Add current element to OTU count after converting to float; some OTUs may be corrected and therefore a float
temp_counts_vector.push_back(std::stof(ele));
}
++otu_number;
}
// Finally construct the OTU observation matrix and _move_ to struct
arma::Mat<double> temp_otu_matrix(temp_counts_vector);
arma::Mat<float> temp_otu_matrix(temp_counts_vector);
temp_otu_matrix.reshape(sample_number, otu_number);
counts = std::move(temp_otu_matrix);
}


// Load a correlation table from file
arma::Mat<double> load_correlation_file(std::string &filename) {
arma::Mat<float> load_correlation_file(std::string &filename) {
// Used to store strings from file prior to matrix construction and other variables
std::string line;
std::string ele;
std::stringstream line_stream;
std::vector<double> correlations_vector;
std::vector<float> correlations_vector;
bool id;

// Open file stream
Expand Down Expand Up @@ -99,19 +99,19 @@ arma::Mat<double> load_correlation_file(std::string &filename) {
id = false;
continue;
}
// Add current element to correlation mat after converting to double
correlations_vector.push_back(std::stod(ele));
// Add current element to correlation mat after converting to float
correlations_vector.push_back(std::stof(ele));
}
}

// Construct matrix and return it
arma::Mat<double> correlations(correlations_vector);
arma::Mat<float> correlations(correlations_vector);
correlations.reshape(otu_number, otu_number);
return correlations;
}


void write_out_square_otu_matrix(arma::Mat<double> &matrix, OtuTable &otu_table, std::string filename) {
void write_out_square_otu_matrix(arma::Mat<float> &matrix, OtuTable &otu_table, std::string filename) {
// Get stream handle
std::ofstream outfile;
outfile.open(filename);
Expand Down Expand Up @@ -167,12 +167,12 @@ int int_from_optarg(const char *optarg) {
float float_from_optarg(const char *optarg) {
// Check that the first 8 characters (at most) are digits or '.'
std::string optstring(optarg);
std::string string_double = optstring.substr(0, 8);
for (std::string::iterator it = string_double.begin(); it != string_double.end(); ++it) {
std::string string_float = optstring.substr(0, 8);
for (std::string::iterator it = string_float.begin(); it != string_float.end(); ++it) {
if (!isdigit(*it) && (*it) != '.') {
fprintf(stderr, "This doesn't look like a usable double: %s\n", optarg);
fprintf(stderr, "This doesn't look like a usable float: %s\n", optarg);
exit(1);
}
}
return std::atof(string_double.c_str());
return std::atof(string_float.c_str());
}
6 changes: 3 additions & 3 deletions src/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
struct OtuTable {
std::vector<std::string> sample_names;
std::vector<std::string> otu_ids;
arma::Mat<double> counts;
arma::Mat<float> counts;
int otu_number = 0;
int sample_number = 0;

Expand All @@ -38,10 +38,10 @@ struct OtuTable {


// Load a correlation (or covariance) table from file
arma::Mat<double> load_correlation_file(std::string &filename);
arma::Mat<float> load_correlation_file(std::string &filename);

// Save a square OTU matrix (e.g. correlation matrix) to file
void write_out_square_otu_matrix(arma::Mat<double> &matrix, OtuTable &otu_table, std::string filename);
void write_out_square_otu_matrix(arma::Mat<float> &matrix, OtuTable &otu_table, std::string filename);

// Set up rng environment and return default rng
gsl_rng *get_default_rng_handle(unsigned int seed);
Expand Down
56 changes: 28 additions & 28 deletions src/fastspar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ int main(int argc, char **argv) {
// Check that OTUs have variance in their counts
std::vector<int> invariant_otus;
for (int i = 0; i < otu_table.otu_number; ++i) {
arma::Col<double> unique_counts = arma::unique(otu_table.counts.col(i));
arma::Col<float> unique_counts = arma::unique(otu_table.counts.col(i));
if (unique_counts.n_elem == 1) {
invariant_otus.push_back(i);
}
Expand Down Expand Up @@ -121,8 +121,8 @@ FastSparIteration::FastSparIteration(const OtuTable *_otu_table, unsigned int _e
components_remaining = otu_table->otu_number;

// We also have to setup the mod matrix
std::vector<double> mod_diag(otu_table->otu_number, otu_table->otu_number - 2);
arma::Mat<double> _mod = arma::diagmat((arma::Col<double>) mod_diag);
std::vector<float> mod_diag(otu_table->otu_number, otu_table->otu_number - 2);
arma::Mat<float> _mod = arma::diagmat((arma::Col<float>) mod_diag);
mod = _mod + 1;
}

Expand Down Expand Up @@ -216,7 +216,7 @@ void FastSparIteration::estimate_component_fractions(gsl_rng *p_rng) {
// Estimate fractions
for(int i = 0; i < otu_table->sample_number; ++i) {
// Get arma row and add pseudo count (then convert to double vector for rng function)
arma::Row<double> row_pseudocount = otu_table->counts.row(i) + 1;
arma::Row<double> row_pseudocount = arma::conv_to<arma::Row<double>>::from(otu_table->counts.row(i) + 1);

// Draw from dirichlet dist, storing results in theta double array
double *theta = new double[row_size];
Expand All @@ -226,7 +226,7 @@ void FastSparIteration::estimate_component_fractions(gsl_rng *p_rng) {

// Create arma::Row from double[] and update fractions row
arma::Mat<double> estimated_fractions_row(theta, 1, otu_table->otu_number);
fraction_estimates.row(i) = estimated_fractions_row;
fraction_estimates.row(i) = arma::conv_to<arma::Row<float>>::from(estimated_fractions_row);

// Free dynamic memory
delete[] theta;
Expand All @@ -238,14 +238,14 @@ void FastSparIteration::estimate_component_fractions(gsl_rng *p_rng) {
void FastSparIteration::calculate_fraction_log_ratio_variance() {
// TODO: Test the amount of memory pre-computing the log matrix is consuming
// Log fraction matrix and initialise a zero-filled matrix (diagonals must be initialised)
arma::Mat<double> log_fraction_estimates = arma::log(fraction_estimates);
arma::Mat<double> temp_fraction_variance(fraction_estimates.n_cols, fraction_estimates.n_cols, arma::fill::zeros);
arma::Mat<float> log_fraction_estimates = arma::log(fraction_estimates);
arma::Mat<float> temp_fraction_variance(fraction_estimates.n_cols, fraction_estimates.n_cols, arma::fill::zeros);

// Calculate log-ratio variance for fraction estimates
for (unsigned int i = 0; i < fraction_estimates.n_cols - 1; ++i) {
for (unsigned int j = i + 1; j < fraction_estimates.n_cols; ++j) {
// Calculate variance of log fractions
double col_variance = arma::var(log_fraction_estimates.col(i) - log_fraction_estimates.col(j));
float col_variance = arma::var(log_fraction_estimates.col(i) - log_fraction_estimates.col(j));

// Add to matrix
temp_fraction_variance(i, j) = col_variance;
Expand All @@ -260,15 +260,15 @@ void FastSparIteration::calculate_fraction_log_ratio_variance() {
// Calculate the basis variance
void FastSparIteration::calculate_basis_variance() {
// NOTE: We make a copy of the variance matrix here to restrict modifications outside this function
arma::Mat<double> fraction_variance_munge = fraction_variance;
arma::Mat<float> fraction_variance_munge = fraction_variance;

// If any pairs have been excluded, set variance to zero
if (!excluded_pairs.empty()) {
fraction_variance_munge((arma::Col<arma::uword>)excluded_pairs).fill(0.0);
}

// Calculate the component variance
arma::Col<double> component_variance= arma::sum(fraction_variance_munge, 1);
arma::Col<float> component_variance= arma::sum(fraction_variance_munge, 1);

// Solve Ax = b where A is the mod matrix and b is the component variance
basis_variance = arma::solve(mod, component_variance, arma::solve_opts::fast);
Expand All @@ -281,16 +281,16 @@ void FastSparIteration::calculate_basis_variance() {
// Calculate the correlation and covariance
void FastSparIteration::calculate_correlation_and_covariance(){
// Initialise matrices and vectors
std::vector<double> basis_cor_diag(otu_table->otu_number, 1);
arma::Mat<double> temp_basis_correlation = arma::diagmat((arma::Col<double>) basis_cor_diag);
arma::Mat<double> temp_basis_covariance = arma::diagmat((arma::Col<double>) basis_variance);
std::vector<float> basis_cor_diag(otu_table->otu_number, 1);
arma::Mat<float> temp_basis_correlation = arma::diagmat((arma::Col<float>) basis_cor_diag);
arma::Mat<float> temp_basis_covariance = arma::diagmat((arma::Col<float>) basis_variance);

// Calculate correlation and covariance for each element and set in basis matrices
for (int i = 0; i < otu_table->otu_number - 1; ++i) {
for (int j = i + 1; j < otu_table->otu_number; ++j) {
// Calculate cor and cov
double basis_cov_el = 0.5 * (basis_variance(i) + basis_variance(j) - fraction_variance(i, j));
double basis_cor_el = basis_cov_el / std::sqrt(basis_variance(i)) / std::sqrt(basis_variance(j));
float basis_cov_el = 0.5 * (basis_variance(i) + basis_variance(j) - fraction_variance(i, j));
float basis_cor_el = basis_cov_el / std::sqrt(basis_variance(i)) / std::sqrt(basis_variance(j));

// Check if we got a valid correlation
if (abs(basis_cor_el) > 1) {
Expand Down Expand Up @@ -320,14 +320,14 @@ void FastSparIteration::calculate_correlation_and_covariance(){
// Find the highest correlation and exclude this pair if correlation is above threshold
void FastSparIteration::find_and_exclude_pairs(float threshold) {
// Get absolute correlation values and set the diagonal to zero, as we're not interested in excluding self-pairs
arma::Mat<double> basis_correlation_abs = arma::abs(basis_correlation);
arma::Mat<float> basis_correlation_abs = arma::abs(basis_correlation);
basis_correlation_abs.diag().zeros();

// Set previously excluded correlations to zero
basis_correlation_abs((arma::Col<arma::uword>)excluded_pairs).fill(0.0);

// Get all elements with the max value
double max_correlate = basis_correlation_abs.max();
float max_correlate = basis_correlation_abs.max();
arma::Col<arma::uword> max_correlate_idx = arma::find(basis_correlation_abs == max_correlate);

// If max correlation is above a threshold, subtract one from the appropriate mod matrix positions
Expand Down Expand Up @@ -365,48 +365,48 @@ void FastSparIteration::find_and_exclude_pairs(float threshold) {
void FastSpar::calculate_median_correlation_and_covariance() {
// Get median of all i,j elements across the iterations for correlation
// Add correlation matrices to arma Cube so that we can get views of all i, j of each matrix
arma::Cube<double> correlation_cube(otu_table->otu_number, otu_table->otu_number, correlation_vector.size());
arma::Cube<float> correlation_cube(otu_table->otu_number, otu_table->otu_number, correlation_vector.size());
correlation_cube.fill(0.0);

// Fill cube with correlation matrix slices
int cube_slice = 0;
for (std::vector<arma::Mat<double>>::iterator it = correlation_vector.begin();
for (std::vector<arma::Mat<float>>::iterator it = correlation_vector.begin();
it != correlation_vector.end(); ++it) {
correlation_cube.slice(cube_slice) = *it;
++cube_slice;
}

// Get median value for each i, j element across all n iterations
arma::Mat<double> temp_median_correlation(otu_table->otu_number, otu_table->otu_number);
arma::Mat<float> temp_median_correlation(otu_table->otu_number, otu_table->otu_number);
for (int i = 0; i < otu_table->otu_number; ++i) {
for (int j = 0; j < otu_table->otu_number; ++j) {
arma::Row<double> r = correlation_cube.subcube(arma::span(i), arma::span(j), arma::span());
arma::Row<float> r = correlation_cube.subcube(arma::span(i), arma::span(j), arma::span());
temp_median_correlation(i, j) = arma::median(r);
}
}

// Get median for diagonal elements across iterations for covariance
// Add covariance diagonals to arma Mat so that we can get row views for all i, j elements
arma::Mat<double> covariance_diagonals(otu_table->otu_number, iterations);
arma::Mat<float> covariance_diagonals(otu_table->otu_number, iterations);

// Fill matrix with covariance diagonals
int matrix_column = 0;
for (std::vector<arma::Col<double>>::iterator it = covariance_vector.begin(); it != covariance_vector.end(); ++it) {
for (std::vector<arma::Col<float>>::iterator it = covariance_vector.begin(); it != covariance_vector.end(); ++it) {
covariance_diagonals.col(matrix_column) = *it;
++matrix_column;
}

// Get the median of each i, j element
arma::Col<double> median_covariance_diag = arma::median(covariance_diagonals, 1);
arma::Col<float> median_covariance_diag = arma::median(covariance_diagonals, 1);

// Split into coordinate meshed grid
arma::Mat<double> median_covariance_y(otu_table->otu_number, otu_table->otu_number);
arma::Mat<double> median_covariance_x(otu_table->otu_number, otu_table->otu_number);
arma::Mat<float> median_covariance_y(otu_table->otu_number, otu_table->otu_number);
arma::Mat<float> median_covariance_x(otu_table->otu_number, otu_table->otu_number);
median_covariance_y.each_col() = median_covariance_diag;
median_covariance_x.each_row() = arma::conv_to<arma::Row<double>>::from(median_covariance_diag);
median_covariance_x.each_row() = arma::conv_to<arma::Row<float>>::from(median_covariance_diag);

// Calculate final median covariance
arma::Mat<double> temp_median_covariance(otu_table->otu_number, otu_table->otu_number);
arma::Mat<float> temp_median_covariance(otu_table->otu_number, otu_table->otu_number);
temp_median_covariance = temp_median_correlation % arma::pow(median_covariance_x, 0.5) % arma::pow(median_covariance_y, 0.5);

// Move the temp matrices to FastSpar
Expand Down
20 changes: 10 additions & 10 deletions src/fastspar.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@ struct FastSpar {
const OtuTable *otu_table;

// List of each correlation and covariance matrix calculated during iterations
std::vector<arma::Mat<double>> correlation_vector;
std::vector<arma::Col<double>> covariance_vector;
std::vector<arma::Mat<float>> correlation_vector;
std::vector<arma::Col<float>> covariance_vector;

// The median correlation and covariance of all iterations
arma::Mat<double> median_correlation;
arma::Mat<double> median_covariance;
arma::Mat<float> median_correlation;
arma::Mat<float> median_covariance;


// Construct FastSpar with a given otu_table and other parameters
Expand All @@ -59,22 +59,22 @@ struct FastSparIteration {
unsigned int exclusion_threshold;

// Estimated fractions of OTUs
arma::Mat<double> fraction_estimates;
arma::Mat<float> fraction_estimates;
// Variance of estimated OTU fractions
arma::Mat<double> fraction_variance;
arma::Mat<float> fraction_variance;

// List of highly correlated OTU pairs excluded in this iteration and number of components remaining
std::vector<arma::uword> excluded_pairs;
unsigned int components_remaining;

// Modifier matrix (lhs in dgesv)
arma::Mat<double> mod;
arma::Mat<float> mod;
// Basis variance vector
arma::Col<double> basis_variance;
arma::Col<float> basis_variance;

// Correlation and covariance for this iteration
arma::Mat<double> basis_correlation;
arma::Mat<double> basis_covariance;
arma::Mat<float> basis_correlation;
arma::Mat<float> basis_covariance;


// Construct FastSparIterations with a given otu_table and other parameters
Expand Down
Loading

0 comments on commit 7da0e76

Please sign in to comment.