Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial direct dedisp #9

Merged
merged 3 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions include/dmt/ddmt_base.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#pragma once

#include <cstddef>
#include <vector>

using SizeType = std::size_t;

struct DDMTPlan {
std::vector<float> dm_arr;
// ndm x nchan
std::vector<SizeType> delay_table;
size_t nchans;
};

// Abstract base class for the direct dedispersion transform (DDMT).
// Stores the observation parameters and the DM trial grid, and derives a
// per-(DM, channel) delay table (the "plan"). Concrete backends implement
// execute().
class DDMT {
public:
// Construct with a regularly spaced DM grid covering [dm_min, dm_max]
// in steps of dm_step.
// NOTE(review): frequency/time units (MHz, seconds?) are not visible in
// this header — confirm against ddmt::generate_delay_table.
DDMT(float f_min,
float f_max,
SizeType nchans,
float tsamp,
float dm_max,
float dm_step,
float dm_min = 0.0F);

// Construct with an explicit list of dm_count DM trials copied from
// dm_arr.
DDMT(float f_min,
float f_max,
SizeType nchans,
float tsamp,
const float* dm_arr,
SizeType dm_count);

// Non-copyable and non-movable: each instance owns a configured plan.
DDMT(const DDMT&) = delete;
DDMT& operator=(const DDMT&) = delete;
DDMT(DDMT&&) = delete;
DDMT& operator=(DDMT&&) = delete;
virtual ~DDMT() = default;

// Read-only access to the precomputed plan (DM grid + delay table).
const DDMTPlan& get_plan() const;
// Returns a copy of the DM trial values.
std::vector<float> get_dm_grid() const;
// Sets the global spdlog verbosity from an integer level.
static void set_log_level(int level);
// Dedisperses `waterfall` (nchans x nsamps elements) into `dmt`
// (ndm x reduced-nsamps elements); implemented by subclasses.
virtual void execute(const float* __restrict waterfall,
SizeType waterfall_size,
float* __restrict dmt,
SizeType dmt_size) = 0;

private:
float m_f_min;
float m_f_max;
SizeType m_nchans;
float m_tsamp;
// DM trials; also copied into m_ddmt_plan by configure_ddmt_plan().
std::vector<float> m_dm_arr;

DDMTPlan m_ddmt_plan;

// Throws std::invalid_argument on inconsistent parameters.
void validate_inputs() const;
// Fills m_ddmt_plan from the members above.
void configure_ddmt_plan();
// Builds a regularly spaced DM grid.
static std::vector<float>
generate_dm_arr(float dm_max, float dm_step, float dm_min);
// Copies a caller-supplied DM list.
static std::vector<float> generate_dm_arr(const float* dm_arr,
SizeType dm_count);
};
39 changes: 39 additions & 0 deletions include/dmt/ddmt_cpu.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#pragma once

#include <dmt/ddmt_base.hpp>

// CPU backend of the direct dedispersion transform. Parallelizes over DM
// trials with OpenMP when built with USE_OPENMP.
class DDMTCPU : public DDMT {
public:
// See DDMT: regularly spaced DM grid [dm_min, dm_max] with step dm_step.
DDMTCPU(float f_min,
float f_max,
SizeType nchans,
float tsamp,
float dm_max,
float dm_step,
float dm_min = 0.0F);

// See DDMT: explicit list of dm_count DM trials.
DDMTCPU(float f_min,
float f_max,
SizeType nchans,
float tsamp,
const float* dm_arr,
SizeType dm_count);

// Sets the OpenMP thread count; no-op in serial builds.
static void set_num_threads(int nthreads);
// Dedisperses waterfall (nchans x nsamps) into dmt (ndm x reduced
// nsamps); sizes are element counts and are validated against the plan.
void execute(const float* __restrict waterfall,
SizeType waterfall_size,
float* __restrict dmt,
SizeType dmt_size) override;

private:
// Inner dedispersion kernel: for each DM row, sums each output sample
// over channels at the per-channel delayed input position.
static void execute_dedisp(const float* __restrict__ d_in,
size_t in_chan_stride,
size_t in_samp_stride,
float* __restrict__ d_out,
size_t out_dm_stride,
size_t out_samp_stride,
const size_t* __restrict__ delay_table,
size_t dm_count,
size_t nchans,
size_t nsamps_reduced);
};
44 changes: 44 additions & 0 deletions include/dmt/ddmt_gpu.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#pragma once

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

#include <dmt/ddmt_base.hpp>

// Device-side mirror of DDMTPlan held in thrust device vectors.
// NOTE(review): element types differ from the host plan (int delays here
// vs SizeType on the host) and kill_mask_d has no host counterpart in
// this diff — confirm against the .cu implementation.
struct DDMTPlanD {
thrust::device_vector<float> dm_arr_d;
thrust::device_vector<int> delay_arr_d;
thrust::device_vector<int> kill_mask_d;
};

// GPU backend of the direct dedispersion transform.
class DDMTGPU : public DDMT {
public:
// Unlike the base class and the CPU backend, this constructor also takes
// nsamps — presumably to size device buffers up front; TODO confirm in
// the .cu implementation.
DDMTGPU(float f_min,
float f_max,
SizeType nchans,
SizeType nsamps,
float tsamp,
float dm_max,
float dm_step,
float dm_min = 0.0F);
// Dedisperses host buffers (copies to/from the device internally).
void execute(const float* __restrict waterfall,
SizeType waterfall_size,
float* __restrict dmt,
SizeType dmt_size) override;

// Overload with device_flags — presumably indicates that the pointers
// are already device memory; verify against the implementation.
void execute(const float* __restrict waterfall,
SizeType waterfall_size,
float* __restrict dmt,
SizeType dmt_size,
bool device_flags);

private:
// Device-resident copy of the plan.
DDMTPlanD m_ddmt_plan_d;

// Copies the host plan into device vectors.
static void transfer_plan_to_device(const DDMTPlan& plan,
DDMTPlanD& plan_d);
// Runs the dedispersion kernel on device-resident buffers.
void execute_device(const float* __restrict waterfall,
SizeType waterfall_size,
float* __restrict dmt,
SizeType dmt_size);
};
90 changes: 90 additions & 0 deletions lib/ddmt_base.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#include <stdexcept>

#include <spdlog/spdlog.h>

#include <dmt/ddmt_base.hpp>
#include <dmt/fdmt_utils.hpp>

// Builds a DDMT with a regularly spaced DM grid in [dm_min, dm_max]
// (step dm_step), validates the inputs, and precomputes the plan.
// Throws std::invalid_argument via validate_inputs() on bad parameters.
DDMT::DDMT(float f_min,
float f_max,
SizeType nchans,
float tsamp,
float dm_max,
float dm_step,
float dm_min)
: m_f_min(f_min),
m_f_max(f_max),
m_nchans(nchans),
m_tsamp(tsamp),
m_dm_arr(generate_dm_arr(dm_max, dm_step, dm_min)) {
validate_inputs();
configure_ddmt_plan();
spdlog::debug("DDMT: dm_max={}, dm_min={}, dm_step={}", dm_max, dm_min,
dm_step);
}

// Builds a DDMT from an explicit list of dm_count DM trials (copied from
// dm_arr), validates the inputs, and precomputes the plan.
// Throws std::invalid_argument via validate_inputs() on bad parameters.
DDMT::DDMT(float f_min,
float f_max,
SizeType nchans,
float tsamp,
const float* dm_arr,
SizeType dm_count)
: m_f_min(f_min),
m_f_max(f_max),
m_nchans(nchans),
m_tsamp(tsamp),
m_dm_arr(generate_dm_arr(dm_arr, dm_count)) {
validate_inputs();
configure_ddmt_plan();
spdlog::debug("DDMT: dm_count={}", dm_count);
}

// Read-only access to the precomputed dedispersion plan.
const DDMTPlan& DDMT::get_plan() const { return m_ddmt_plan; }

// Return the DM trial values as a fresh copy of the plan's grid.
std::vector<float> DDMT::get_dm_grid() const {
    auto dm_grid = m_ddmt_plan.dm_arr;
    return dm_grid;
}

// Set the global spdlog verbosity. Out-of-range values fall back to the
// default (info) level.
void DDMT::set_log_level(int level) {
    if (level < static_cast<int>(spdlog::level::trace) ||
        level > static_cast<int>(spdlog::level::off)) {
        spdlog::set_level(spdlog::level::info);
        // Early return: previously the code fell through and applied the
        // invalid level anyway, overriding the fallback just set above.
        return;
    }
    spdlog::set_level(static_cast<spdlog::level::level_enum>(level));
}

// Populates m_ddmt_plan: copies the DM grid and derives the flattened
// (ndm x nchan) integer delay table from the channelization.
void DDMT::configure_ddmt_plan() {
m_ddmt_plan.nchans = m_nchans;
m_ddmt_plan.dm_arr = m_dm_arr;
// Channel width assuming m_nchans equal channels spanning f_min..f_max.
const auto df = (m_f_max - m_f_min) / static_cast<float>(m_nchans);
m_ddmt_plan.delay_table = ddmt::generate_delay_table(
m_dm_arr.data(), m_dm_arr.size(), m_f_min, df, m_nchans, m_tsamp);
}

// Fail fast on inconsistent construction parameters.
// Throws std::invalid_argument describing the first violated constraint.
void DDMT::validate_inputs() const {
    if (m_f_min >= m_f_max) {
        throw std::invalid_argument("f_min must be less than f_max");
    }
    if (m_tsamp <= 0) {
        throw std::invalid_argument("tsamp must be greater than 0");
    }
    // m_nchans is unsigned (SizeType), so "<= 0" could only ever mean
    // "== 0"; spell it that way to avoid a tautological comparison.
    if (m_nchans == 0) {
        throw std::invalid_argument("nchans must be greater than 0");
    }
    if (m_dm_arr.empty()) {
        throw std::invalid_argument("dm_arr must not be empty");
    }
}

std::vector<float>
DDMT::generate_dm_arr(float dm_max, float dm_step, float dm_min) {
std::vector<float> dm_arr;
for (float dm = dm_min; dm <= dm_max; dm += dm_step) {
dm_arr.push_back(dm);
}
return dm_arr;
}

// Copy a caller-supplied DM trial list into an owned vector.
std::vector<float> DDMT::generate_dm_arr(const float* dm_arr,
                                         SizeType dm_count) {
    return {dm_arr, dm_arr + dm_count};
}
88 changes: 88 additions & 0 deletions lib/ddmt_cpu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#include "dmt/ddmt_base.hpp"
#include <cmath>
#include <cstddef>
#include <spdlog/spdlog.h>
#ifdef USE_OPENMP
#include <omp.h>
#endif

#include <dmt/ddmt_cpu.hpp>

// CPU backend, regularly spaced DM grid: forwards everything to the
// DDMT base, which validates inputs and builds the plan.
DDMTCPU::DDMTCPU(float f_min,
float f_max,
SizeType nchans,
float tsamp,
float dm_max,
float dm_step,
float dm_min)
: DDMT(f_min, f_max, nchans, tsamp, dm_max, dm_step, dm_min) {}

// CPU backend, explicit DM trial list: forwards everything to the DDMT
// base, which copies the list, validates inputs, and builds the plan.
DDMTCPU::DDMTCPU(float f_min,
float f_max,
SizeType nchans,
float tsamp,
const float* dm_arr,
SizeType dm_count)
: DDMT(f_min, f_max, nchans, tsamp, dm_arr, dm_count) {}

// Set the number of OpenMP worker threads used by execute().
// No-op when built without OpenMP support.
void DDMTCPU::set_num_threads(int nthreads) {
#ifdef USE_OPENMP
    omp_set_num_threads(nthreads);
#else
    // Suppress the unused-parameter warning in serial builds.
    (void)nthreads;
#endif
}

void DDMTCPU::execute(const float* __restrict waterfall,
SizeType waterfall_size,
float* __restrict dmt,
SizeType dmt_size) {
const auto& plan = get_plan();
const auto nchans = plan.nchans;
const auto nsamps = waterfall_size / nchans;
const auto max_delay = plan.delay_table.back();
const auto nsamps_reduced = nsamps - max_delay;
const auto out_dm_stride = nsamps_reduced;
const auto out_samp_stride = 1;
const auto in_chan_stride = nsamps;
const auto in_samp_stride = 1;
const auto* delay_table = plan.delay_table.data();
const auto dm_count = plan.dm_arr.size();

if (dmt_size != dm_count * nsamps_reduced) {
spdlog::error("Output buffer size mismatch: expected {}, got {}",
dm_count * nsamps_reduced, dmt_size);
return;
}

execute_dedisp(waterfall, in_chan_stride, in_samp_stride, dmt,
out_dm_stride, out_samp_stride, delay_table, dm_count,
nchans, nsamps_reduced);
}

// Brute-force dedispersion kernel. For each DM trial (parallelized with
// OpenMP) and each output sample, sums the input across channels at that
// channel's delayed position.
// d_in:        input, indexed i_chan * in_chan_stride + t * in_samp_stride.
// d_out:       output, indexed i_dm * out_dm_stride + t * out_samp_stride.
// delay_table: dm_count x nchans sample delays, row-major.
// Caller guarantees i_samp + delay stays within the input (see execute()).
void DDMTCPU::execute_dedisp(const float* __restrict__ d_in,
                             size_t in_chan_stride,
                             size_t in_samp_stride,
                             float* __restrict__ d_out,
                             size_t out_dm_stride,
                             size_t out_samp_stride,
                             const size_t* __restrict__ delay_table,
                             size_t dm_count,
                             size_t nchans,
                             size_t nsamps_reduced) {
#pragma omp parallel for default(none)                                         \
    shared(d_in, d_out, delay_table, dm_count, nchans, nsamps_reduced,         \
               in_chan_stride, in_samp_stride, out_dm_stride, out_samp_stride)
    for (size_t i_dm = 0; i_dm < dm_count; ++i_dm) {
        // Row of per-channel delays for this DM trial. A plain pointer,
        // not the previous `const auto& delays = &...`, which bound a
        // reference to a pointer temporary — legal only via lifetime
        // extension and needlessly obscure.
        const size_t* delays = &delay_table[i_dm * nchans];
        const size_t out_idx = i_dm * out_dm_stride;
        for (size_t i_samp = 0; i_samp < nsamps_reduced; ++i_samp) {
            float sum = 0.0F;
#pragma omp simd reduction(+ : sum)
            for (size_t i_chan = 0; i_chan < nchans; ++i_chan) {
                const size_t delay = delays[i_chan];
                sum += d_in[i_chan * in_chan_stride +
                            (i_samp + delay) * in_samp_stride];
            }
            d_out[out_idx + i_samp * out_samp_stride] = sum;
        }
    }
}
38 changes: 38 additions & 0 deletions lib/dmt/cuda_utils.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#pragma once

#include <cuda_runtime.h>
#include <sstream>
#include <stdexcept>
#include <string>

namespace error_checker {

// Throw std::runtime_error if the CUDA runtime has a pending error.
// `file`/`line` identify the call site in the message; `msg` adds optional
// context. Uses cudaGetLastError(), which also resets the error state.
inline void
check_cuda_error(const char* file, int line, const std::string& msg = "") {
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) {
std::stringstream error_msg;
error_msg << "CUDA failed with error: " << cudaGetErrorString(error)
<< " (" << file << ":" << line << ")";
if (!msg.empty()) {
error_msg << " - " << msg;
}
throw std::runtime_error(error_msg.str());
}
}

// Synchronize the device, then check for errors — surfaces failures from
// asynchronous work launched earlier. The sync's own return value is
// ignored; presumably cudaGetLastError() in check_cuda_error reports the
// same failure — confirm against the CUDA runtime error semantics.
inline void
check_cuda_error_sync(const char* file, int line, const std::string& msg = "") {
cudaDeviceSynchronize();
check_cuda_error(file, line, msg);
}

// Convenience wrapper around check_cuda_error.
// `inline` added: a non-inline function defined in a header violates the
// ODR as soon as two translation units include it.
// NOTE(review): __FILE__/__LINE__ expand to THIS header's location, not
// the caller's — a macro would be needed for call-site reporting.
inline void check_cuda(const std::string& msg = "") {
    check_cuda_error(__FILE__, __LINE__, msg);
}

// Convenience wrapper around check_cuda_error_sync.
// `inline` added: a non-inline function defined in a header violates the
// ODR as soon as two translation units include it.
// NOTE(review): __FILE__/__LINE__ expand to THIS header's location, not
// the caller's — a macro would be needed for call-site reporting.
inline void check_cuda_sync(const std::string& msg = "") {
    check_cuda_error_sync(__FILE__, __LINE__, msg);
}

} // namespace error_checker
9 changes: 9 additions & 0 deletions lib/dmt/fdmt_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,15 @@ constexpr float kDispConst = kDispConstMT;

using SizeType = std::size_t;

namespace ddmt {
// Build the flattened (dm_count x nchans) table of integer sample delays
// used by the direct dedispersion transform.
// dm_arr/dm_count: DM trial values; f0: frequency of the first channel;
// df: channel width; tsamp: sampling time.
// NOTE(review): the exact delay formula and units live in the definition
// (not shown in this diff) — confirm there.
std::vector<SizeType> generate_delay_table(const float* dm_arr,
SizeType dm_count,
float f0,
float df,
SizeType nchans,
float tsamp);
}

namespace fdmt {

float cff(float f1_start, float f1_end, float f2_start, float f2_end);
Expand Down
Loading
Loading