-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #9 from pravirkr/ddmt
Initial direct dedisp
- Loading branch information
Showing 10 changed files with 420 additions and 38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
#pragma once | ||
|
||
#include <cstddef> | ||
#include <vector> | ||
|
||
using SizeType = std::size_t;

/// Host-side dedispersion plan: the DM grid and the matching delay table.
struct DDMTPlan {
    /// Trial dispersion measures, one entry per DM trial.
    std::vector<float> dm_arr;
    /// Flattened delay table laid out as ndm x nchan.
    std::vector<SizeType> delay_table;
    /// Number of frequency channels; value-initialised so a default-constructed
    /// plan never carries an indeterminate channel count.
    SizeType nchans{};
};
|
||
class DDMT { | ||
public: | ||
DDMT(float f_min, | ||
float f_max, | ||
SizeType nchans, | ||
float tsamp, | ||
float dm_max, | ||
float dm_step, | ||
float dm_min = 0.0F); | ||
|
||
DDMT(float f_min, | ||
float f_max, | ||
SizeType nchans, | ||
float tsamp, | ||
const float* dm_arr, | ||
SizeType dm_count); | ||
|
||
DDMT(const DDMT&) = delete; | ||
DDMT& operator=(const DDMT&) = delete; | ||
DDMT(DDMT&&) = delete; | ||
DDMT& operator=(DDMT&&) = delete; | ||
virtual ~DDMT() = default; | ||
|
||
const DDMTPlan& get_plan() const; | ||
std::vector<float> get_dm_grid() const; | ||
static void set_log_level(int level); | ||
virtual void execute(const float* __restrict waterfall, | ||
SizeType waterfall_size, | ||
float* __restrict dmt, | ||
SizeType dmt_size) = 0; | ||
|
||
private: | ||
float m_f_min; | ||
float m_f_max; | ||
SizeType m_nchans; | ||
float m_tsamp; | ||
std::vector<float> m_dm_arr; | ||
|
||
DDMTPlan m_ddmt_plan; | ||
|
||
void validate_inputs() const; | ||
void configure_ddmt_plan(); | ||
static std::vector<float> | ||
generate_dm_arr(float dm_max, float dm_step, float dm_min); | ||
static std::vector<float> generate_dm_arr(const float* dm_arr, | ||
SizeType dm_count); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#pragma once | ||
|
||
#include <dmt/ddmt_base.hpp> | ||
|
||
class DDMTCPU : public DDMT { | ||
public: | ||
DDMTCPU(float f_min, | ||
float f_max, | ||
SizeType nchans, | ||
float tsamp, | ||
float dm_max, | ||
float dm_step, | ||
float dm_min = 0.0F); | ||
|
||
DDMTCPU(float f_min, | ||
float f_max, | ||
SizeType nchans, | ||
float tsamp, | ||
const float* dm_arr, | ||
SizeType dm_count); | ||
|
||
static void set_num_threads(int nthreads); | ||
void execute(const float* __restrict waterfall, | ||
SizeType waterfall_size, | ||
float* __restrict dmt, | ||
SizeType dmt_size) override; | ||
|
||
private: | ||
static void execute_dedisp(const float* __restrict__ d_in, | ||
size_t in_chan_stride, | ||
size_t in_samp_stride, | ||
float* __restrict__ d_out, | ||
size_t out_dm_stride, | ||
size_t out_samp_stride, | ||
const size_t* __restrict__ delay_table, | ||
size_t dm_count, | ||
size_t nchans, | ||
size_t nsamps_reduced); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#pragma once | ||
|
||
#include <thrust/device_vector.h> | ||
#include <thrust/host_vector.h> | ||
|
||
#include <dmt/ddmt_base.hpp> | ||
|
||
struct DDMTPlanD { | ||
thrust::device_vector<float> dm_arr_d; | ||
thrust::device_vector<int> delay_arr_d; | ||
thrust::device_vector<int> kill_mask_d; | ||
}; | ||
|
||
class DDMTGPU : public DDMT { | ||
public: | ||
DDMTGPU(float f_min, | ||
float f_max, | ||
SizeType nchans, | ||
SizeType nsamps, | ||
float tsamp, | ||
float dm_max, | ||
float dm_step, | ||
float dm_min = 0.0F); | ||
void execute(const float* __restrict waterfall, | ||
SizeType waterfall_size, | ||
float* __restrict dmt, | ||
SizeType dmt_size) override; | ||
|
||
void execute(const float* __restrict waterfall, | ||
SizeType waterfall_size, | ||
float* __restrict dmt, | ||
SizeType dmt_size, | ||
bool device_flags); | ||
|
||
private: | ||
DDMTPlanD m_ddmt_plan_d; | ||
|
||
static void transfer_plan_to_device(const DDMTPlan& plan, | ||
DDMTPlanD& plan_d); | ||
void execute_device(const float* __restrict waterfall, | ||
SizeType waterfall_size, | ||
float* __restrict dmt, | ||
SizeType dmt_size); | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
#include <stdexcept> | ||
|
||
#include <spdlog/spdlog.h> | ||
|
||
#include <dmt/ddmt_base.hpp> | ||
#include <dmt/fdmt_utils.hpp> | ||
|
||
// Range constructor: the DM grid is generated from (dm_min, dm_max, dm_step)
// in the member initialiser, then the inputs are validated and the plan built.
DDMT::DDMT(float f_min, float f_max, SizeType nchans, float tsamp,
           float dm_max, float dm_step, float dm_min)
    : m_f_min(f_min),
      m_f_max(f_max),
      m_nchans(nchans),
      m_tsamp(tsamp),
      m_dm_arr(generate_dm_arr(dm_max, dm_step, dm_min)) {
    // Validation comes first so a bad configuration never reaches the planner.
    validate_inputs();
    configure_ddmt_plan();
    spdlog::debug("DDMT: dm_max={}, dm_min={}, dm_step={}",
                  dm_max, dm_min, dm_step);
}
|
||
// Array constructor: the DM grid is copied from the caller's buffer in the
// member initialiser, then the inputs are validated and the plan built.
DDMT::DDMT(float f_min, float f_max, SizeType nchans, float tsamp,
           const float* dm_arr, SizeType dm_count)
    : m_f_min(f_min),
      m_f_max(f_max),
      m_nchans(nchans),
      m_tsamp(tsamp),
      m_dm_arr(generate_dm_arr(dm_arr, dm_count)) {
    // Validation comes first so a bad configuration never reaches the planner.
    validate_inputs();
    configure_ddmt_plan();
    spdlog::debug("DDMT: dm_count={}", dm_count);
}
|
||
// Read-only view of the plan owned by this object.
const DDMTPlan& DDMT::get_plan() const {
    return m_ddmt_plan;
}
|
||
// Returns a copy of the DM grid stored in the plan.
std::vector<float> DDMT::get_dm_grid() const {
    const std::vector<float>& grid = m_ddmt_plan.dm_arr;
    return grid;
}
|
||
// Sets the global spdlog level from an integer.
//
// Values outside [spdlog::level::trace, spdlog::level::off] fall back to
// `info`. The early return is what makes the fallback stick: without it the
// out-of-range value was cast to level_enum and applied straight afterwards.
void DDMT::set_log_level(int level) {
    const bool out_of_range =
        level < static_cast<int>(spdlog::level::trace) ||
        level > static_cast<int>(spdlog::level::off);
    if (out_of_range) {
        spdlog::set_level(spdlog::level::info);
        return;
    }
    spdlog::set_level(static_cast<spdlog::level::level_enum>(level));
}
|
||
void DDMT::configure_ddmt_plan() { | ||
m_ddmt_plan.nchans = m_nchans; | ||
m_ddmt_plan.dm_arr = m_dm_arr; | ||
const auto df = (m_f_max - m_f_min) / static_cast<float>(m_nchans); | ||
m_ddmt_plan.delay_table = ddmt::generate_delay_table( | ||
m_dm_arr.data(), m_dm_arr.size(), m_f_min, df, m_nchans, m_tsamp); | ||
} | ||
|
||
void DDMT::validate_inputs() const { | ||
if (m_f_min >= m_f_max) { | ||
throw std::invalid_argument("f_min must be less than f_max"); | ||
} | ||
if (m_tsamp <= 0) { | ||
throw std::invalid_argument("tsamp must be greater than 0"); | ||
} | ||
if (m_nchans <= 0) { | ||
throw std::invalid_argument("nchans must be greater than 0"); | ||
} | ||
if (m_dm_arr.empty()) { | ||
throw std::invalid_argument("dm_arr must not be empty"); | ||
} | ||
} | ||
|
||
std::vector<float> | ||
DDMT::generate_dm_arr(float dm_max, float dm_step, float dm_min) { | ||
std::vector<float> dm_arr; | ||
for (float dm = dm_min; dm <= dm_max; dm += dm_step) { | ||
dm_arr.push_back(dm); | ||
} | ||
return dm_arr; | ||
} | ||
|
||
std::vector<float> DDMT::generate_dm_arr(const float* dm_arr, | ||
SizeType dm_count) { | ||
std::vector<float> dm_vec(dm_arr, dm_arr + dm_count); | ||
return dm_vec; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
#include "dmt/ddmt_base.hpp"
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <spdlog/spdlog.h>
#ifdef USE_OPENMP | ||
#include <omp.h> | ||
#endif | ||
|
||
#include <dmt/ddmt_cpu.hpp> | ||
|
||
// Range constructor: forwards every argument to the DDMT base unchanged.
DDMTCPU::DDMTCPU(float f_min, float f_max, SizeType nchans, float tsamp,
                 float dm_max, float dm_step, float dm_min)
    : DDMT(f_min, f_max, nchans, tsamp, dm_max, dm_step, dm_min) {
}
|
||
// Array constructor: forwards every argument to the DDMT base unchanged.
DDMTCPU::DDMTCPU(float f_min, float f_max, SizeType nchans, float tsamp,
                 const float* dm_arr, SizeType dm_count)
    : DDMT(f_min, f_max, nchans, tsamp, dm_arr, dm_count) {
}
|
||
// Sets the OpenMP thread count when built with USE_OPENMP; otherwise does
// nothing.
//
// Non-positive counts are ignored with a warning: the OpenMP specification
// leaves omp_set_num_threads implementation-defined for such values.
void DDMTCPU::set_num_threads(int nthreads) {
#ifdef USE_OPENMP
    if (nthreads <= 0) {
        spdlog::warn("Ignoring invalid thread count: {}", nthreads);
        return;
    }
    omp_set_num_threads(nthreads);
#else
    // Silences the unused-parameter warning in builds without OpenMP.
    static_cast<void>(nthreads);
#endif
}
|
||
void DDMTCPU::execute(const float* __restrict waterfall, | ||
SizeType waterfall_size, | ||
float* __restrict dmt, | ||
SizeType dmt_size) { | ||
const auto& plan = get_plan(); | ||
const auto nchans = plan.nchans; | ||
const auto nsamps = waterfall_size / nchans; | ||
const auto max_delay = plan.delay_table.back(); | ||
const auto nsamps_reduced = nsamps - max_delay; | ||
const auto out_dm_stride = nsamps_reduced; | ||
const auto out_samp_stride = 1; | ||
const auto in_chan_stride = nsamps; | ||
const auto in_samp_stride = 1; | ||
const auto* delay_table = plan.delay_table.data(); | ||
const auto dm_count = plan.dm_arr.size(); | ||
|
||
if (dmt_size != dm_count * nsamps_reduced) { | ||
spdlog::error("Output buffer size mismatch: expected {}, got {}", | ||
dm_count * nsamps_reduced, dmt_size); | ||
return; | ||
} | ||
|
||
execute_dedisp(waterfall, in_chan_stride, in_samp_stride, dmt, | ||
out_dm_stride, out_samp_stride, delay_table, dm_count, | ||
nchans, nsamps_reduced); | ||
} | ||
|
||
// Strided dedispersion kernel.
//
// For each DM trial d and output sample t:
//   d_out[d * out_dm_stride + t * out_samp_stride] =
//       sum over channels c of
//       d_in[c * in_chan_stride + (t + delay_table[d * nchans + c]) * in_samp_stride]
//
// The outer loop over DM trials carries an OpenMP parallel-for pragma and the
// channel sum carries an OpenMP simd reduction pragma.
void DDMTCPU::execute_dedisp(const float* __restrict__ d_in,
                             size_t in_chan_stride,
                             size_t in_samp_stride,
                             float* __restrict__ d_out,
                             size_t out_dm_stride,
                             size_t out_samp_stride,
                             const size_t* __restrict__ delay_table,
                             size_t dm_count,
                             size_t nchans,
                             size_t nsamps_reduced) {
#pragma omp parallel for default(none)                                         \
    shared(d_in, d_out, delay_table, dm_count, nchans, nsamps_reduced,         \
               in_chan_stride, in_samp_stride, out_dm_stride, out_samp_stride)
    for (size_t dm_idx = 0; dm_idx < dm_count; ++dm_idx) {
        // Row of per-channel delays belonging to this DM trial.
        const size_t* dm_delays = delay_table + dm_idx * nchans;
        const size_t row_base   = dm_idx * out_dm_stride;
        for (size_t t = 0; t < nsamps_reduced; ++t) {
            float acc = 0.0F;
#pragma omp simd reduction(+ : acc)
            for (size_t chan = 0; chan < nchans; ++chan) {
                const size_t shifted = t + dm_delays[chan];
                acc += d_in[chan * in_chan_stride + shifted * in_samp_stride];
            }
            d_out[row_base + t * out_samp_stride] = acc;
        }
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
#pragma once | ||
|
||
#include <cuda_runtime.h> | ||
#include <sstream> | ||
#include <stdexcept> | ||
#include <string> | ||
|
||
namespace error_checker { | ||
|
||
inline void | ||
check_cuda_error(const char* file, int line, const std::string& msg = "") { | ||
cudaError_t error = cudaGetLastError(); | ||
if (error != cudaSuccess) { | ||
std::stringstream error_msg; | ||
error_msg << "CUDA failed with error: " << cudaGetErrorString(error) | ||
<< " (" << file << ":" << line << ")"; | ||
if (!msg.empty()) { | ||
error_msg << " - " << msg; | ||
} | ||
throw std::runtime_error(error_msg.str()); | ||
} | ||
} | ||
|
||
inline void | ||
check_cuda_error_sync(const char* file, int line, const std::string& msg = "") { | ||
cudaDeviceSynchronize(); | ||
check_cuda_error(file, line, msg); | ||
} | ||
|
||
void check_cuda(const std::string& msg = "") { | ||
check_cuda_error(__FILE__, __LINE__, msg); | ||
} | ||
|
||
void check_cuda_sync(const std::string& msg = "") { | ||
check_cuda_error_sync(__FILE__, __LINE__, msg); | ||
} | ||
|
||
} // namespace error_checker |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.