diff --git a/src/bootstrap.cpp b/src/bootstrap.cpp index 4155ebd..149b6d9 100644 --- a/src/bootstrap.cpp +++ b/src/bootstrap.cpp @@ -2,8 +2,8 @@ // Get bootstrap samples -arma::Mat get_bootstrap(OtuTable &otu_table, gsl_rng *p_rng) { - arma::Mat bootstrap(otu_table.sample_number, otu_table.otu_number); +arma::Mat get_bootstrap(OtuTable &otu_table, gsl_rng *p_rng) { + arma::Mat bootstrap(otu_table.sample_number, otu_table.otu_number); for (int j = 0; j < otu_table.otu_number; ++j) { // Get a random list of sample indices equal to the number of samples std::vector indices(otu_table.sample_number); @@ -33,7 +33,7 @@ void get_and_write_bootstraps(OtuTable &otu_table, unsigned int bootstrap_number #pragma omp parallel for schedule(static, 1) for (unsigned int i = 0; i < bootstrap_number; ++i) { // Get the bootstrap - arma::Mat bootstrap = get_bootstrap(otu_table, p_rng); + arma::Mat bootstrap = get_bootstrap(otu_table, p_rng); // Transpose matrix in place before writing out printf("\tWriting out bootstrapped %i\n", i); @@ -47,7 +47,7 @@ void get_and_write_bootstraps(OtuTable &otu_table, unsigned int bootstrap_number } -void write_out_bootstrap_table(arma::Mat &bootstrap, std::vector otu_ids, std::string filepath) { +void write_out_bootstrap_table(arma::Mat &bootstrap, std::vector otu_ids, std::string filepath) { // Get file handle FILE *filehandle = fopen(filepath.c_str(), "w"); diff --git a/src/bootstrap.h b/src/bootstrap.h index f5f23ba..e4c9e5b 100644 --- a/src/bootstrap.h +++ b/src/bootstrap.h @@ -12,13 +12,13 @@ // Get single bootstrap for OTU table -arma::Mat get_bootstrap(OtuTable &otu_table, gsl_rng *p_rng); +arma::Mat get_bootstrap(OtuTable &otu_table, gsl_rng *p_rng); // Get n bootstraps for OTU table void get_and_write_bootstraps(OtuTable &otu_table, unsigned int bootstrap_number, std::string prefix, unsigned int threads, unsigned int seed); // Write out a bootstrap count table -void write_out_bootstrap_table(arma::Mat &bootstrap, std::vector otu_ids, std::string filepath); +void write_out_bootstrap_table(arma::Mat &bootstrap, std::vector otu_ids, std::string filepath); #endif diff --git a/src/common.cpp b/src/common.cpp index 1d2f06f..995cbd1 100644 --- a/src/common.cpp +++ b/src/common.cpp @@ -11,7 +11,7 @@ void OtuTable::load_otu_file(std::string filename) { std::string line; std::string ele; std::stringstream line_stream; - std::vector temp_counts_vector; + std::vector temp_counts_vector; bool id; // Open file stream std::ifstream otu_file; @@ -52,25 +52,25 @@ void OtuTable::load_otu_file(std::string filename) { id = false; continue; } - // Add current element to OTU count after converting to double; some OTUs may be corrected and therefore a double - temp_counts_vector.push_back(std::stod(ele)); + // Add current element to OTU count after converting to float; some OTUs may be corrected and therefore a float + temp_counts_vector.push_back(std::stof(ele)); } ++otu_number; } // Finally construct the OTU observation matrix and _move_ to struct - arma::Mat temp_otu_matrix(temp_counts_vector); + arma::Mat temp_otu_matrix(temp_counts_vector); temp_otu_matrix.reshape(sample_number, otu_number); counts = std::move(temp_otu_matrix); } // Load a correlation table from file -arma::Mat load_correlation_file(std::string &filename) { +arma::Mat load_correlation_file(std::string &filename) { // Used to store strings from file prior to matrix construction and other variables std::string line; std::string ele; std::stringstream line_stream; - std::vector correlations_vector; + std::vector correlations_vector; bool id; // Open file stream @@ -99,19 +99,19 @@ arma::Mat load_correlation_file(std::string &filename) { id = false; continue; } - // Add current element to correlation mat after converting to double - correlations_vector.push_back(std::stod(ele)); + // Add current element to correlation mat after converting to float + correlations_vector.push_back(std::stof(ele)); } } // Construct matrix and return it - arma::Mat correlations(correlations_vector); + arma::Mat correlations(correlations_vector); correlations.reshape(otu_number, otu_number); return correlations; } -void write_out_square_otu_matrix(arma::Mat &matrix, OtuTable &otu_table, std::string filename) { +void write_out_square_otu_matrix(arma::Mat &matrix, OtuTable &otu_table, std::string filename) { // Get stream handle std::ofstream outfile; outfile.open(filename); @@ -167,12 +167,12 @@ int int_from_optarg(const char *optarg) { float float_from_optarg(const char *optarg) { // Check at most the first 8 characters are numerical std::string optstring(optarg); - std::string string_double = optstring.substr(0, 8); - for (std::string::iterator it = string_double.begin(); it != string_double.end(); ++it) { + std::string string_float = optstring.substr(0, 8); + for (std::string::iterator it = string_float.begin(); it != string_float.end(); ++it) { if (!isdigit(*it) && (*it) != '.') { - fprintf(stderr, "This doesn't look like a usable double: %s\n", optarg); + fprintf(stderr, "This doesn't look like a usable float: %s\n", optarg); exit(1); } } - return std::atof(string_double.c_str()); + return std::atof(string_float.c_str()); } diff --git a/src/common.h b/src/common.h index f9f728d..1a4d2f1 100644 --- a/src/common.h +++ b/src/common.h @@ -28,7 +28,7 @@ struct OtuTable { std::vector sample_names; std::vector otu_ids; - arma::Mat counts; + arma::Mat counts; int otu_number = 0; int sample_number = 0; @@ -38,10 +38,10 @@ struct OtuTable { // Load a correlation (or covariance table) from file -arma::Mat load_correlation_file(std::string &filename); +arma::Mat load_correlation_file(std::string &filename); // Save an square OTU matrix (e.g. correlation matrix) to file -void write_out_square_otu_matrix(arma::Mat &matrix, OtuTable &otu_table, std::string filename); +void write_out_square_otu_matrix(arma::Mat &matrix, OtuTable &otu_table, std::string filename); // Set up rng environment and return default rng gsl_rng *get_default_rng_handle(unsigned int seed); diff --git a/src/fastspar.cpp b/src/fastspar.cpp index b940a4a..62f02b7 100644 --- a/src/fastspar.cpp +++ b/src/fastspar.cpp @@ -21,7 +21,7 @@ int main(int argc, char **argv) { // Check that OTUs have variance in their counts std::vector invariant_otus; for (int i = 0; i < otu_table.otu_number; ++i) { - arma::Col unique_counts = arma::unique(otu_table.counts.col(i)); + arma::Col unique_counts = arma::unique(otu_table.counts.col(i)); if (unique_counts.n_elem == 1) { invariant_otus.push_back(i); } @@ -121,8 +121,8 @@ FastSparIteration::FastSparIteration(const OtuTable *_otu_table, unsigned int _e components_remaining = otu_table->otu_number; // We also have to setup the mod matrix - std::vector mod_diag(otu_table->otu_number, otu_table->otu_number - 2); - arma::Mat _mod = arma::diagmat((arma::Col) mod_diag); + std::vector mod_diag(otu_table->otu_number, otu_table->otu_number - 2); + arma::Mat _mod = arma::diagmat((arma::Col) mod_diag); mod = _mod + 1; } @@ -216,7 +216,7 @@ void FastSparIteration::estimate_component_fractions(gsl_rng *p_rng) { // Estimate fractions for(int i = 0; i < otu_table->sample_number; ++i) { // Get arma row and add pseudo count (then convert to double vector for rng function) - arma::Row row_pseudocount = otu_table->counts.row(i) + 1; + arma::Row row_pseudocount = arma::conv_to>::from(otu_table->counts.row(i) + 1); // Draw from dirichlet dist, storing results in theta double array double *theta = new double[row_size]; @@ -226,7 +226,7 @@ void FastSparIteration::estimate_component_fractions(gsl_rng *p_rng) { // Create arma::Row from double[] and update fractions row arma::Mat estimated_fractions_row(theta, 1, otu_table->otu_number); - fraction_estimates.row(i) = estimated_fractions_row; + fraction_estimates.row(i) = arma::conv_to>::from(estimated_fractions_row); // Free dynamic memory delete[] theta; @@ -238,14 +238,14 @@ void FastSparIteration::estimate_component_fractions(gsl_rng *p_rng) { void FastSparIteration::calculate_fraction_log_ratio_variance() { // TODO: Test the amount of memory pre-computing the log matrix is consuming // Log fraction matrix and initialise a zero-filled matrix (diagonals must be initialised) - arma::Mat log_fraction_estimates = arma::log(fraction_estimates); - arma::Mat temp_fraction_variance(fraction_estimates.n_cols, fraction_estimates.n_cols, arma::fill::zeros); + arma::Mat log_fraction_estimates = arma::log(fraction_estimates); + arma::Mat temp_fraction_variance(fraction_estimates.n_cols, fraction_estimates.n_cols, arma::fill::zeros); // Calculate log-ratio variance for fraction estimates for (unsigned int i = 0; i < fraction_estimates.n_cols - 1; ++i) { for (unsigned int j = i + 1; j < fraction_estimates.n_cols; ++j) { // Calculate variance of log fractions - double col_variance = arma::var(log_fraction_estimates.col(i) - log_fraction_estimates.col(j)); + float col_variance = arma::var(log_fraction_estimates.col(i) - log_fraction_estimates.col(j)); // Add to matrix temp_fraction_variance(i, j) = col_variance; @@ -260,7 +260,7 @@ void FastSparIteration::calculate_fraction_log_ratio_variance() { // Calculate the basis variance void FastSparIteration::calculate_basis_variance() { // NOTE: We make a copy of the variance matrix here to restrict modifications outside this function - arma::Mat fraction_variance_munge = fraction_variance; + arma::Mat fraction_variance_munge = fraction_variance; // If any pairs have been excluded, set variance to zero if (!excluded_pairs.empty()) { @@ -268,7 +268,7 @@ void FastSparIteration::calculate_basis_variance() { } // Calculate the component variance - arma::Col component_variance= arma::sum(fraction_variance_munge, 1); + arma::Col component_variance= arma::sum(fraction_variance_munge, 1); // Solve Ax = b where A is the mod matrix and b is the component variance basis_variance = arma::solve(mod, component_variance, arma::solve_opts::fast); @@ -281,16 +281,16 @@ void FastSparIteration::calculate_basis_variance() { // Calculate the correlation and covariance void FastSparIteration::calculate_correlation_and_covariance(){ // Initialise matrices and vectors - std::vector basis_cor_diag(otu_table->otu_number, 1); - arma::Mat temp_basis_correlation = arma::diagmat((arma::Col) basis_cor_diag); - arma::Mat temp_basis_covariance = arma::diagmat((arma::Col) basis_variance); + std::vector basis_cor_diag(otu_table->otu_number, 1); + arma::Mat temp_basis_correlation = arma::diagmat((arma::Col) basis_cor_diag); + arma::Mat temp_basis_covariance = arma::diagmat((arma::Col) basis_variance); // Calculate correlation and covariance for each element add set in basis matrices for (int i = 0; i < otu_table->otu_number - 1; ++i) { for (int j = i + 1; j < otu_table->otu_number; ++j) { // Calculate cor and cov - double basis_cov_el = 0.5 * (basis_variance(i) + basis_variance(j) - fraction_variance(i, j)); - double basis_cor_el = basis_cov_el / std::sqrt(basis_variance(i)) / std::sqrt(basis_variance(j)); + float basis_cov_el = 0.5 * (basis_variance(i) + basis_variance(j) - fraction_variance(i, j)); + float basis_cor_el = basis_cov_el / std::sqrt(basis_variance(i)) / std::sqrt(basis_variance(j)); // Check if we got a valid correlation if (abs(basis_cor_el) > 1) { @@ -320,14 +320,14 @@ void FastSparIteration::calculate_correlation_and_covariance(){ // Find the highest correlation and exclude this pair if correlation is above threshold void FastSparIteration::find_and_exclude_pairs(float threshold) { // Set diagonal to zero as we're not interesting in excluding self-pairs and get absolute correlation value - arma::Mat basis_correlation_abs = arma::abs(basis_correlation); + arma::Mat basis_correlation_abs = arma::abs(basis_correlation); basis_correlation_abs.diag().zeros(); // Set previously excluded correlations to zero basis_correlation_abs((arma::Col)excluded_pairs).fill(0.0); // Get all elements with the max value - double max_correlate = basis_correlation_abs.max(); + float max_correlate = basis_correlation_abs.max(); arma::Col max_correlate_idx = arma::find(basis_correlation_abs == max_correlate); // If max correlation is above a threshold, subtract one from the appropriate mod matrix positions @@ -365,48 +365,48 @@ void FastSparIteration::find_and_exclude_pairs(float threshold) { void FastSpar::calculate_median_correlation_and_covariance() { // Get median of all i,j elements across the iterations for correlation // Add correlation matrices to arma Cube so that we can get views of all i, j of each matrix - arma::Cube correlation_cube(otu_table->otu_number, otu_table->otu_number, correlation_vector.size()); + arma::Cube correlation_cube(otu_table->otu_number, otu_table->otu_number, correlation_vector.size()); correlation_cube.fill(0.0); // Fill cube with correlation matrix slices int cube_slice = 0; - for (std::vector>::iterator it = correlation_vector.begin(); + for (std::vector>::iterator it = correlation_vector.begin(); it != correlation_vector.end(); ++it) { correlation_cube.slice(cube_slice) = *it; ++cube_slice; } // Get median value for each i, j element across all n iterations - arma::Mat temp_median_correlation(otu_table->otu_number, otu_table->otu_number); + arma::Mat temp_median_correlation(otu_table->otu_number, otu_table->otu_number); for (int i = 0; i < otu_table->otu_number; ++i) { for (int j = 0; j < otu_table->otu_number; ++j) { - arma::Row r = correlation_cube.subcube(arma::span(i), arma::span(j), arma::span()); + arma::Row r = correlation_cube.subcube(arma::span(i), arma::span(j), arma::span()); temp_median_correlation(i, j) = arma::median(r); } } // Get median for diagonal elements across iterations for covariance // Add covariance diagonals to arma Mat so that we can get row views for all i, j elements - arma::Mat covariance_diagonals(otu_table->otu_number, iterations); + arma::Mat covariance_diagonals(otu_table->otu_number, iterations); // Fill matrix will covariance diagonals int matrix_column = 0; - for (std::vector>::iterator it = covariance_vector.begin(); it != covariance_vector.end(); ++it) { + for (std::vector>::iterator it = covariance_vector.begin(); it != covariance_vector.end(); ++it) { covariance_diagonals.col(matrix_column) = *it; ++matrix_column; } // Get the median of each i, j element - arma::Col median_covariance_diag = arma::median(covariance_diagonals, 1); + arma::Col median_covariance_diag = arma::median(covariance_diagonals, 1); // Split into coordinate meshed grid - arma::Mat median_covariance_y(otu_table->otu_number, otu_table->otu_number); - arma::Mat median_covariance_x(otu_table->otu_number, otu_table->otu_number); + arma::Mat median_covariance_y(otu_table->otu_number, otu_table->otu_number); + arma::Mat median_covariance_x(otu_table->otu_number, otu_table->otu_number); median_covariance_y.each_col() = median_covariance_diag; - median_covariance_x.each_row() = arma::conv_to>::from(median_covariance_diag); + median_covariance_x.each_row() = arma::conv_to>::from(median_covariance_diag); // Calculate final median covariance - arma::Mat temp_median_covariance(otu_table->otu_number, otu_table->otu_number); + arma::Mat temp_median_covariance(otu_table->otu_number, otu_table->otu_number); temp_median_covariance = temp_median_correlation % arma::pow(median_covariance_x, 0.5) % arma::pow(median_covariance_y, 0.5); // Move the temp matrices to FastSpar diff --git a/src/fastspar.h b/src/fastspar.h index 5ae74cc..a6eec00 100644 --- a/src/fastspar.h +++ b/src/fastspar.h @@ -33,12 +33,12 @@ struct FastSpar { const OtuTable *otu_table; // List of each correlation and covariance matrix calculated during iterations - std::vector> correlation_vector; - std::vector> covariance_vector; + std::vector> correlation_vector; + std::vector> covariance_vector; // The median correlation covariance of all iterations - arma::Mat median_correlation; - arma::Mat median_covariance; + arma::Mat median_correlation; + arma::Mat median_covariance; // Construct FastSpar with a given otu_table and other parameters @@ -59,22 +59,22 @@ struct FastSparIteration { unsigned int exclusion_threshold; // Estimated fractions of OTUs - arma::Mat fraction_estimates; + arma::Mat fraction_estimates; // Variance of estimated OTU fractions - arma::Mat fraction_variance; + arma::Mat fraction_variance; // List of highly OTU pairs excluded in this iteration and number of components remaining std::vector excluded_pairs; unsigned int components_remaining; // Modifier matrix (lhs in dgesv) - arma::Mat mod; + arma::Mat mod; // Basis variance vector - arma::Col basis_variance; + arma::Col basis_variance; // Correlation and covariance for this iteration - arma::Mat basis_correlation; - arma::Mat basis_covariance; + arma::Mat basis_correlation; + arma::Mat basis_covariance; // Construct FastSparIterations with a given otu_table and other parameters diff --git a/src/pvalue.cpp b/src/pvalue.cpp index 23d71ec..ca9d6b3 100644 --- a/src/pvalue.cpp +++ b/src/pvalue.cpp @@ -25,8 +25,8 @@ std::vector get_bootstrap_correlation_paths(std::string glob_path) } -void count_values_more_extreme(arma::Mat &abs_observed_correlation, - arma::Mat &abs_bootstrap_correlation, +void count_values_more_extreme(arma::Mat &abs_observed_correlation, + arma::Mat &abs_bootstrap_correlation, arma::Mat &extreme_value_counts) { // Set diagonal to zero to avoid processing self pairs of OTUs abs_bootstrap_correlation.diag().zeros(); @@ -45,13 +45,13 @@ double factorial(double number) { } -double calculate_possbile_otu_permutations(std::unordered_map &count_frequency, int sample_number) { +double calculate_possbile_otu_permutations(std::unordered_map &count_frequency, int sample_number) { // The total permutations for a single OTU can be calculated by factorial division. We try to // simplify it here (ported from R code authored by Scott Ritchie) int max = 0; double numerator = 1; double denominator = 1; - for (std::unordered_map::iterator it = count_frequency.begin(); it != count_frequency.end(); ++it) { + for (std::unordered_map::iterator it = count_frequency.begin(); it != count_frequency.end(); ++it) { // Factorial of 1 is 1 if (it->second == 1){ continue; @@ -82,13 +82,13 @@ double calculate_possbile_otu_permutations(std::unordered_map &coun } -double calculate_exact_pvalue(double otu_pair_possible_permutations, int extreme_value_count, unsigned int permutations) { +float calculate_exact_pvalue(float otu_pair_possible_permutations, int extreme_value_count, unsigned int permutations) { // Function adapted and ported from statmod::permp - // This cast is messy (also can't pass otu_pair as double reference for some reason) - double prob[(int)otu_pair_possible_permutations]; - double prob_binom_sum = 0; + // This cast is messy (also can't pass otu_pair as float reference for some reason) + float prob[(int)otu_pair_possible_permutations]; + float prob_binom_sum = 0; for (int i = 0; i < otu_pair_possible_permutations; ++i) { - prob[i] = (double)(i + 1) / otu_pair_possible_permutations; + prob[i] = (float)(i + 1) / otu_pair_possible_permutations; } for (int i = 0; i < otu_pair_possible_permutations; ++i) { prob_binom_sum += gsl_cdf_binomial_P(extreme_value_count, prob[i], permutations); @@ -98,7 +98,7 @@ double calculate_exact_pvalue(double otu_pair_possible_permutations, int extreme } -double calculate_pvalue_with_integral_estimate(double otu_pair_possible_permutations, int extreme_value_count, unsigned int permutations) { +float calculate_pvalue_with_integral_estimate(float otu_pair_possible_permutations, int extreme_value_count, unsigned int permutations) { // Function adapted and ported from statmod::permp and statmod::gaussquad // TODO: See if there is a better way to init array elements w/o hard coding // Start statmod::gaussquad port @@ -137,20 +137,20 @@ double calculate_pvalue_with_integral_estimate(double otu_pair_possible_permutat // End statmod::gaussquad port // Start statmod::permp port - double weight_prob_product_sum = 0; + float weight_prob_product_sum = 0; for (int i = 0; i < n; ++i) { weight_prob_product_sum += gsl_cdf_binomial_P(extreme_value_count, nodes[i], permutations) * weights[i]; } - double integral = 0.5 / (otu_pair_possible_permutations * weight_prob_product_sum); - // TODO: Check if the double cast correctly done (it is required but maybe adding 1.0 instead of 1 is sufficient) + float integral = 0.5 / (otu_pair_possible_permutations * weight_prob_product_sum); + // TODO: Check if the float cast correctly done (it is required but maybe adding 1.0 instead of 1 is sufficient) // End statmod::permp port with p-value return - return ((double)extreme_value_count + 1) / ((double)permutations + 1) - integral; + return ((float)extreme_value_count + 1) / ((float)permutations + 1) - integral; } -arma::Mat calculate_pvalues(OtuTable &otu_table, arma::Mat &observed_correlation, std::vector &bootstrap_correlation_fps, unsigned int permutations, bool exact, unsigned int threads) { +arma::Mat calculate_pvalues(OtuTable &otu_table, arma::Mat &observed_correlation, std::vector &bootstrap_correlation_fps, unsigned int permutations, bool exact, unsigned int threads) { // Get absolute correlations - arma::Mat abs_observed_correlation = arma::abs(observed_correlation); + arma::Mat abs_observed_correlation = arma::abs(observed_correlation); // OpenMP function from omp.h. This sets the number of threads in a more reliable way but also ignores OMP_NUM_THREADS omp_set_num_threads(threads); @@ -163,25 +163,25 @@ arma::Mat calculate_pvalues(OtuTable &otu_table, arma::Mat &obse for (int unsigned i = 0; i < bootstrap_correlation_fps.size(); ++i) { printf("\tBootstrap correlation %i: %s\n", i, bootstrap_correlation_fps[i].c_str()); // Load the bootstrap correlation and get absolute values - arma::Mat bootstrap_correlation = load_correlation_file(bootstrap_correlation_fps[i]); - arma::Mat abs_bootstrap_correlation = arma::abs(bootstrap_correlation); + arma::Mat bootstrap_correlation = load_correlation_file(bootstrap_correlation_fps[i]); + arma::Mat abs_bootstrap_correlation = arma::abs(bootstrap_correlation); // Count if bootstrap correlation is greater than observed correlation count_values_more_extreme(abs_observed_correlation, abs_bootstrap_correlation, extreme_value_counts); } - arma::Mat pvalues(otu_table.otu_number, otu_table.otu_number); + arma::Mat pvalues(otu_table.otu_number, otu_table.otu_number); if (exact) { // Calculate total possible permutations for each OTU printf("Calculating %i total permutations\n", otu_table.otu_number); - arma::Col possible_permutations(otu_table.otu_number, arma::fill::zeros); + arma::Col possible_permutations(otu_table.otu_number, arma::fill::zeros); for (int i = 0; i < otu_table.otu_number; ++i) { printf("\tTotal permutation %i of %d\n", i, otu_table.otu_number); // First we need to get the frequency of each count for an OTU across all samples - // TODO: Check that after changing from int counts to double (for corrected OTU counts) that this isn't borked + // TODO: Check that after changing from int counts to float (for corrected OTU counts) that this isn't borked // Main concern that equality will not be true in some cases where they would be otherwise due to float error - std::unordered_map count_frequency; - for (arma::Mat::col_iterator it = otu_table.counts.begin_col(i); it != otu_table.counts.end_col(i); ++it) { + std::unordered_map count_frequency; + for (arma::Mat::col_iterator it = otu_table.counts.begin_col(i); it != otu_table.counts.end_col(i); ++it) { ++count_frequency[*it]; } @@ -202,11 +202,11 @@ arma::Mat calculate_pvalues(OtuTable &otu_table, arma::Mat &obse } // Get the total possible permutations between the OTU pair // TODO: Check if this is producing desired results - double otu_pair_possible_permutations = possible_permutations(i) * possible_permutations(j); + float otu_pair_possible_permutations = possible_permutations(i) * possible_permutations(j); // The 'if' code in block below was ported from statmod::permp and statmod::gaussquad if (otu_pair_possible_permutations <= 10000 ) { // Exact p-value calculation - // If fewer than 10000 possible permutations, we can safely cast double to int + // If fewer than 10000 possible permutations, we can safely cast float to int #pragma omp atomic write pvalues(i, j) = calculate_exact_pvalue((int)otu_pair_possible_permutations, extreme_value_counts(i, j), permutations); } else { @@ -217,7 +217,7 @@ arma::Mat calculate_pvalues(OtuTable &otu_table, arma::Mat &obse } } } else { - pvalues = arma::conv_to>::from(extreme_value_counts) / permutations; + pvalues = arma::conv_to>::from(extreme_value_counts) / permutations; } pvalues.diag().ones(); return pvalues; @@ -244,10 +244,10 @@ int main(int argc, char **argv) { // Read in observed correlation printf("Reading in observed correlations\n"); - arma::Mat observed_correlation = load_correlation_file(pval_options.correlation_filename); + arma::Mat observed_correlation = load_correlation_file(pval_options.correlation_filename); // Calulate p-values - arma::Mat pvalues = calculate_pvalues(otu_table, observed_correlation, bs_cor_paths, + arma::Mat pvalues = calculate_pvalues(otu_table, observed_correlation, bs_cor_paths, pval_options.permutations, pval_options.exact, pval_options.threads); // Write out p-values diff --git a/src/pvalue.h b/src/pvalue.h index 9422fe0..d6a6a9d 100644 --- a/src/pvalue.h +++ b/src/pvalue.h @@ -20,22 +20,22 @@ std::vector get_bootstrap_correlation_paths(std::string glob_path); // Calculate p-values for all OTU pairs -arma::Mat calculate_pvalues(OtuTable &otu_table, arma::Mat &observed_correlation, std::vector &bootstrap_correlation_fps, unsigned int permutations, bool exact, unsigned int threads); +arma::Mat calculate_pvalues(OtuTable &otu_table, arma::Mat &observed_correlation, std::vector &bootstrap_correlation_fps, unsigned int permutations, bool exact, unsigned int threads); // Count bootstrap correlations more extreme than the observed correlation -void count_values_more_extreme(arma::Mat &abs_observed_correlation, arma::Mat &abs_bootstrap_correlation, arma::Mat &extreme_value_counts); +void count_values_more_extreme(arma::Mat &abs_observed_correlation, arma::Mat &abs_bootstrap_correlation, arma::Mat &extreme_value_counts); // Calculate factorial double factorial(double number); // Calculate the permutations of count data for an OTU -double calculate_possbile_otu_permutations(std::unordered_map &count_frequency, int sample_number); +double calculate_possbile_otu_permutations(std::unordered_map &count_frequency, int sample_number); // Calculate the exact p-value for an OTU pair (when total nperm <= 10000) -double calculate_exact_pvalue(double otu_pair_possible_permutations, int extreme_value_count, unsigned int permutations); +float calculate_exact_pvalue(float otu_pair_possible_permutations, int extreme_value_count, unsigned int permutations); // Calculate the p-value with an integral estimate for an OTU pair (when total nperm > 10000) -double calculate_pvalue_with_integral_estimate(double otu_pair_possible_permutations, int extreme_value_count, int permutations); +float calculate_pvalue_with_integral_estimate(float otu_pair_possible_permutations, int extreme_value_count, int permutations); #endif diff --git a/src/reduce.cpp b/src/reduce.cpp index b108432..00618a0 100644 --- a/src/reduce.cpp +++ b/src/reduce.cpp @@ -19,7 +19,7 @@ SquareMatrix load_square_matrix(std::string filename) { element_count = std::count(line.begin(), line.end(), '\t'); /* Now read in elements of sequare matrix */ - std::vector element_vector; + std::vector element_vector; std::vector otu_vector; element_vector.reserve(element_count * element_count); otu_vector.reserve(element_count); @@ -39,7 +39,7 @@ SquareMatrix load_square_matrix(std::string filename) { id = false; continue; } - // Add current element to correlation mat after converting to double + // Add current element to correlation mat after converting to float element_vector.push_back(std::stod(ele)); } } @@ -52,7 +52,7 @@ SquareMatrix load_square_matrix(std::string filename) { SparseMatrix filter_matrix(SquareMatrix table, arma::Col filtered_element_indices) { // Extract elements which pass filter - arma::Col filtered_table = table.elements(filtered_element_indices); + arma::Col filtered_table = table.elements(filtered_element_indices); // Record col and row names for OTUs elements in vector std::vector> otus(filtered_element_indices.n_elem, std::vector(2)); diff --git a/src/reduce.h b/src/reduce.h index d6204d2..d7e4ac2 100644 --- a/src/reduce.h +++ b/src/reduce.h @@ -12,17 +12,17 @@ // Structure to hold information about a square matrix struct SquareMatrix { - arma::Mat elements; + arma::Mat elements; std::vector otus; // Using initaliser lists to enable calling on Mat/vector constructors at the time of struct construction - SquareMatrix(std::vector element_vector, std::vector otu_vector) : elements(element_vector), otus(otu_vector) {} + SquareMatrix(std::vector element_vector, std::vector otu_vector) : elements(element_vector), otus(otu_vector) {} }; // Structure to hold information about a sparse matrix struct SparseMatrix { - arma::Col elements; + arma::Col elements; std::vector> otus; }; diff --git a/tests/test_fastspar.cpp b/tests/test_fastspar.cpp index a60c907..b78356e 100644 --- a/tests/test_fastspar.cpp +++ b/tests/test_fastspar.cpp @@ -20,8 +20,8 @@ TEST_CASE("Correlation, covariance statisitc integration test") { fastspar.calculate_median_correlation_and_covariance(); // Compare to previously validated outputs - arma::Mat correlation = load_correlation_file(correlation_fp); - arma::Mat covariance = load_correlation_file(covariance_fp); + arma::Mat correlation = load_correlation_file(correlation_fp); + arma::Mat covariance = load_correlation_file(covariance_fp); // Tolerating difference of 0.0001 as output types are rounded REQUIRE(arma::approx_equal(fastspar.median_correlation, correlation, "absdiff", 0.001)); diff --git a/tests/test_pvalues.cpp b/tests/test_pvalues.cpp index 8253fc9..8c12ecc 100644 --- a/tests/test_pvalues.cpp +++ b/tests/test_pvalues.cpp @@ -22,15 +22,15 @@ TEST_CASE("p-value integration test") { // Load data used in calculation OtuTable otu_table; otu_table.load_otu_file(otu_fp); - arma::Mat observed_correlation = load_correlation_file(observed_correlation_fp); + arma::Mat observed_correlation = load_correlation_file(observed_correlation_fp); // Calculate pvalues - arma::Mat test_pseudo_pvalues = calculate_pvalues(otu_table, observed_correlation, bs_correlation_fps, 3, false, 1); - arma::Mat test_exact_pvalues = calculate_pvalues(otu_table, observed_correlation, bs_correlation_fps, 3, true, 1); + arma::Mat test_pseudo_pvalues = calculate_pvalues(otu_table, observed_correlation, bs_correlation_fps, 3, false, 1); + arma::Mat test_exact_pvalues = calculate_pvalues(otu_table, observed_correlation, bs_correlation_fps, 3, true, 1); // Load data for comparison (pseudo p-values from SparCC, exact previously validated) - arma::Mat pseudo_pvalues = load_correlation_file(pseudo_fp); - arma::Mat exact_pvalues = load_correlation_file(exact_fp); + arma::Mat pseudo_pvalues = load_correlation_file(pseudo_fp); + arma::Mat exact_pvalues = load_correlation_file(exact_fp); // Tolerating difference of 0.0001 as output types are rounded REQUIRE(arma::approx_equal(test_pseudo_pvalues, pseudo_pvalues, "absdiff", 0.001)); @@ -50,7 +50,7 @@ TEST_CASE("factorial division") { // required is cancelation of factorials during division to avoid overflow // Data int sample_number = 32; - std::unordered_map count_frequency = {{10, 1}, {0, 20}, {2, 5}, {3, 6}}; - double result = calculate_possbile_otu_permutations(count_frequency, sample_number); + std::unordered_map count_frequency = {{10, 1}, {0, 20}, {2, 5}, {3, 6}}; + float result = calculate_possbile_otu_permutations(count_frequency, sample_number); REQUIRE(1251795504960 == result); }