Skip to content

Commit

Permalink
Initial edits; code compiles but the solver is not yet working properly
Browse files Browse the repository at this point in the history
  • Loading branch information
arknyazev committed Jan 17, 2025
1 parent b99da6e commit 0f923ca
Showing 1 changed file with 247 additions and 28 deletions.
275 changes: 247 additions & 28 deletions src/simsoptpp/tracing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,158 @@ using std::function;
#include "xtensor-python/pytensor.hpp" // Numpy bindings
typedef xt::pyarray<double> Array;

#include <boost/math/tools/roots.hpp>
#include <boost/numeric/odeint.hpp>
using boost::math::tools::toms748_solve;
using namespace boost::numeric::odeint;

#include <array>
#include <cmath>
#include <algorithm>

#include <gsl/gsl_vector.h>
#include <gsl/gsl_multiroots.h>
#include <iomanip>

double bisection_method(std::function<double(double)> f, double a, double b, double tol, int max_iter = 100) {
if (f(a) * f(b) >= 0) {
throw std::invalid_argument("f(a) and f(b) must have opposite signs");
}

double c = a;
for (int i = 0; i < max_iter; ++i) {
c = (a + b) / 2;
if (std::abs(f(c)) < tol || (b - a) / 2 < tol) {
return c;
}
if (f(c) * f(a) < 0) {
b = c;
} else {
a = c;
}
}
return c;
}

template <std::size_t Size>
class DormandPrinceIntegrator {
public:
using State = std::array<double, Size>;

DormandPrinceIntegrator(double abstol, double reltol, double dtmax)
: abstol_(abstol), reltol_(reltol), dtmax_(dtmax) {}

void initialize(const State& y, double t, double dt) {
y_ = y;
t_ = t;
dt_ = dt;
}

template <typename RHS>
std::tuple<double, double> do_step(RHS& rhs) {
// Butcher tableau coefficients for Dormand-Prince method
const double a21 = 1.0 / 5.0;
const double a31 = 3.0 / 40.0, a32 = 9.0 / 40.0;
const double a41 = 44.0 / 45.0, a42 = -56.0 / 15.0, a43 = 32.0 / 9.0;
const double a51 = 19372.0 / 6561.0, a52 = -25360.0 / 2187.0;
const double a53 = 64448.0 / 6561.0, a54 = -212.0 / 729.0;
const double a61 = 9017.0 / 3168.0, a62 = -355.0 / 33.0;
const double a63 = 46732.0 / 5247.0, a64 = 49.0 / 176.0;
const double a65 = -5103.0 / 18656.0;
const double b1 = 35.0 / 384.0, b3 = 500.0 / 1113.0;
const double b4 = 125.0 / 192.0, b5 = -2187.0 / 6784.0;
const double b6 = 11.0 / 84.0;
const double bhat1 = 5179.0 / 57600.0, bhat3 = 7571.0 / 16695.0;
const double bhat4 = 393.0 / 640.0, bhat5 = -92097.0 / 339200.0;
const double bhat6 = 187.0 / 2100.0, bhat7 = 1.0 / 40.0;

State k1, k2, k3, k4, k5, k6, k7;
State x_temp, x_new, x_err;

y_last_ = y_;
rhs(y_, k1, t_);
dy_last_ = k1;

for (int i = 0; i < Size; i++) {
x_temp[i] = y_[i] + dt_ * a21 * k1[i];
}
rhs(x_temp, k2, t_ + a21 * dt_);

for (int i = 0; i < Size; i++) {
x_temp[i] = y_[i] + dt_ * (a31 * k1[i] + a32 * k2[i]);
}
rhs(x_temp, k3, t_ + (a31 + a32) * dt_);

for (int i = 0; i < Size; i++) {
x_temp[i] = y_[i] + dt_ * (a41 * k1[i] + a42 * k2[i] + a43 * k3[i]);
}
rhs(x_temp, k4, t_ + (a41 + a42 + a43) * dt_);

for (int i = 0; i < Size; i++) {
x_temp[i] = y_[i] + dt_ * (a51 * k1[i] + a52 * k2[i] + a53 * k3[i] + a54 * k4[i]);
}
rhs(x_temp, k5, t_ + (a51 + a52 + a53 + a54) * dt_);

for (int i = 0; i < Size; i++) {
x_temp[i] = y_[i] + dt_ * (a61 * k1[i] + a62 * k2[i] + a63 * k3[i] + a64 * k4[i] + a65 * k5[i]);
}
rhs(x_temp, k6, t_ + (a61 + a62 + a63 + a64 + a65) * dt_);

for (int i = 0; i < Size; i++) {
x_new[i] = y_[i] + dt_ * (b1 * k1[i] + b3 * k3[i] + b4 * k4[i] + b5 * k5[i] + b6 * k6[i]);
}
rhs(x_new, k7, t_ + dt_);

double err = 0;
for (int i = 0; i < Size; i++) {
x_err[i] = dt_ * (bhat1 * k1[i] + bhat3 * k3[i] + bhat4 * k4[i] + bhat5 * k5[i] + bhat6 * k6[i] + bhat7 * k7[i]);
err = std::max(err, std::abs(x_err[i]));
}

double dt_new = 0.9 * dt_ * std::pow((abstol_ / err), 0.2);
dt_new = std::max(dt_new, 0.1 * dt_);
dt_new = std::min(dt_new, 5.0 * dt_);

if (err <= abstol_) {
t_ += dt_;
y_ = x_new;
dt_ = std::min(dt_new, dtmax_ - t_);
} else {
dt_ = dt_new;
}

y_current_ = y_;
dy_current_ = k1;
return std::make_tuple(t_ - dt_, t_);
}

double current_time() const {
return t_;
}

const State& current_state() const {
return y_;
}

void calc_state(double t, State& temp) {
double h = t_ - (t_ - dt_);
double s = (t - (t_ - dt_)) / h;
for (int i = 0; i < Size; i++) {
temp[i] = (1 - s) * y_last_[i] + s * y_current_[i] + s * (1 - s) * (
(1 - 2 * s) * (y_current_[i] - y_last_[i])
+ (s - 1) * h * dy_last_[i] + s * h * dy_current_[i]);
}
}

private:
double abstol_;
double reltol_;
double dtmax_;
double t_;
double dt_;
State y_;
State y_last_;
State y_current_;
State dy_last_;
State dy_current_;
};

template<template<class, std::size_t, xt::layout_type> class T>
class GuidingCenterVacuumBoozerRHS {
/*
Expand Down Expand Up @@ -147,7 +290,6 @@ class GuidingCenterNoKBoozerPerturbedRHS {
stzt(0, 1) = theta;
stzt(0, 2) = ys[2];
stzt(0, 3) = ys[4];

auto field = perturbed_field->get_B0();
perturbed_field->set_points(stzt);
auto psi0 = field->psi0;
Expand Down Expand Up @@ -175,6 +317,18 @@ class GuidingCenterNoKBoozerPerturbedRHS {
double dalphadzeta = perturbed_field->dalphadzeta_ref()(0);
double denom = (q*(G + I*(-alpha*dGdpsi + iota) + alpha*G*dIdpsi)
+ m*v_par/modB * (-dGdpsi*I + G*dIdpsi)); // q*G in vacuum

/* Debug begin */
// Debug print statements
std::cout << "s: " << s << ", theta: " << theta << ", ys[2]: " << ys[2] << ", ys[4]: " << ys[4] << std::endl;
std::cout << "psi0: " << psi0 << ", modB: " << modB << ", G: " << G << ", I: " << I << std::endl;
std::cout << "dGdpsi: " << dGdpsi << ", dIdpsi: " << dIdpsi << ", iota: " << iota << ", diotadpsi: " << diotadpsi << std::endl;
std::cout << "dmodBdpsi: " << dmodBdpsi << ", dmodBdtheta: " << dmodBdtheta << ", dmodBdzeta: " << dmodBdzeta << std::endl;
std::cout << "v_perp2: " << v_perp2 << ", fak1: " << fak1 << std::endl;
std::cout << "Phi: " << Phi << ", Phidot: " << Phidot << ", dPhidpsi: " << dPhidpsi << ", dPhidtheta: " << dPhidtheta << ", dPhidzeta: " << dPhidzeta << std::endl;
std::cout << "alpha: " << alpha << ", alphadot: " << alphadot << ", dalphadpsi: " << dalphadpsi << ", dalphadtheta: " << dalphadtheta << ", dalphadzeta: " << dalphadzeta << std::endl;
std::cout << "denom: " << denom << std::endl;
/* Debug end */

double sdot = (-G*dPhidtheta*q + I*dPhidzeta*q + modB*q*v_par*(dalphadtheta*G-dalphadzeta*I) + (-dmodBdtheta*G + dmodBdzeta*I)*fak1)/(denom*psi0);
double tdot = (G*q*dPhidpsi + modB*q*v_par*(-dalphadpsi*G - alpha*dGdpsi + iota) - dGdpsi*m*v_par*v_par \
Expand Down Expand Up @@ -211,6 +365,12 @@ class GuidingCenterNoKBoozerPerturbedRHS {
+ v_par/modB * (dmodBdtheta*dPhidpsi - dmodBdpsi*dPhidtheta);
*/
dydt[4] = 1;
/* Debug begin */
// Print dydt values
std::cout << "dydt[0]: " << dydt[0] << ", dydt[1]: " << dydt[1] << ", dydt[2]: " << dydt[2] << ", dydt[3]: " << dydt[3] << ", dydt[4]: " << dydt[4] << std::endl;
std::cout << "Press ENTER (hi) to continue..." << std::endl;

/* debug end */
}
};

Expand Down Expand Up @@ -397,24 +557,39 @@ std::array<double, m+n> join(const std::array<double, m>& a, const std::array<do

template<class RHS>
tuple<vector<array<double, RHS::Size+1>>, vector<array<double, RHS::Size+2>>>
solve(RHS rhs, typename RHS::State y, double tmax, double dt, double dtmax, double abstol, double reltol, vector<double> zetas, vector<double> omegas, vector<shared_ptr<StoppingCriterion>> stopping_criteria, double dt_save, vector<double> vpars, bool zetas_stop=false, bool vpars_stop=false, bool forget_exact_path=false) {
solve(
RHS rhs,
typename RHS::State y,
double tmax,
double dt,
double dtmax,
double abstol,
double reltol,
vector<double> zetas,
vector<double> omegas,
vector<shared_ptr<StoppingCriterion>> stopping_criteria,
double dt_save,
vector<double> vpars,
bool zetas_stop=false,
bool vpars_stop=false,
bool forget_exact_path=false) {

if (zetas.size() > 0 && omegas.size() == 0) {
omegas.insert(omegas.end(), zetas.size(), 0.);
} else if (zetas.size() != omegas.size()) {
throw std::invalid_argument("zetas and omegas need to have matching length.");
throw std::invalid_argument(
"zetas and omegas need to have matching length."
);
}

vector<array<double, RHS::Size+1>> res = {};
vector<array<double, RHS::Size+2>> res_hits = {};
// array<double, RHS::Size> ykeep = {};
typedef typename RHS::State State;
State temp;
State ykeep;
typedef typename boost::numeric::odeint::result_of::make_dense_output<runge_kutta_dopri5<State>>::type dense_stepper_type;
dense_stepper_type dense = make_dense_output(abstol, reltol, dtmax, runge_kutta_dopri5<State>());
DormandPrinceIntegrator<RHS::Size> integrator(abstol, reltol, dtmax);
double t = 0;
dense.initialize(y, t, dt);
integrator.initialize(y, 0, dt);
int iter = 0;
bool stop = false;
double zeta_last;
Expand All @@ -437,17 +612,49 @@ solve(RHS rhs, typename RHS::State y, double tmax, double dt, double dtmax, doub

double zeta_current, vpar_current, t_current;
do {
tuple<double, double> step = dense.do_step(rhs);
auto step = integrator.do_step(rhs);
iter++;
t = dense.current_time();
y = dense.current_state();
t = integrator.current_time();
y = integrator.current_state();
zeta_current = y[2];
vpar_current = y[3];
double t_last = std::get<0>(step);
double t_current = std::get<1>(step);
dt = t_current - t_last;
stop = check_stopping_criteria(rhs, y, iter, res, res_hits, dense, t_last, t_current, zeta_last, zeta_current, vpar_last,
vpar_current, abstol, zetas, omegas, stopping_criteria, vpars, zetas_stop, vpars_stop, forget_exact_path, dt_save);

// Debug print statements
std::cout << "Iteration: " << iter << ", t: " << t << ", dt: " << dt << std::endl;
std::cout << "State: ";
for (const auto& val : y) {
std::cout << val << " ";
}
std::cout << std::endl;
std::cin.ignore();
// end debug print statements

stop = check_stopping_criteria(
rhs,
y,
iter,
res,
res_hits,
integrator,
t_last,
t_current,
zeta_last,
zeta_current,
vpar_last,
vpar_current,
abstol,
zetas,
omegas,
stopping_criteria,
vpars,
zetas_stop,
vpars_stop,
forget_exact_path,
dt_save
);
zeta_last = zeta_current;
vpar_last = vpar_current;
} while(t < tmax && !stop);
Expand All @@ -466,12 +673,30 @@ solve(RHS rhs, typename RHS::State y, double tmax, double dt, double dtmax, doub
}

template<class RHS, class DENSE>
bool check_stopping_criteria(RHS rhs, typename RHS::State y, int iter, vector<array<double, RHS::Size+1>> &res, vector<array<double, RHS::Size+2>> &res_hits, DENSE dense, double t_last, double t_current, double zeta_last, double zeta_current, double vpar_last, double vpar_current, double abstol, vector<double> zetas, vector<double> omegas, vector<shared_ptr<StoppingCriterion>> stopping_criteria, vector<double> vpars, bool zetas_stop, bool vpars_stop, bool forget_exact_path, double dt_save) {
bool check_stopping_criteria(
RHS rhs,
typename RHS::State y,
int iter,
vector<array<double, RHS::Size+1>> &res,
vector<array<double, RHS::Size+2>> &res_hits,
DENSE dense,
double t_last,
double t_current,
double zeta_last,
double zeta_current,
double vpar_last,
double vpar_current,
double abstol,
vector<double> zetas,
vector<double> omegas,
vector<shared_ptr<StoppingCriterion>> stopping_criteria,
vector<double> vpars,
bool zetas_stop,
bool vpars_stop,
bool forget_exact_path,
double dt_save) {

typedef typename RHS::State State;
// abstol?
boost::math::tools::eps_tolerance<double> roottol(-int(std::log2(abstol)));
uintmax_t rootmaxit = 200;
State temp;

bool stop = false;
Expand Down Expand Up @@ -506,10 +731,7 @@ bool check_stopping_criteria(RHS rhs, typename RHS::State y, int iter, vector<ar
dense.calc_state(t, temp);
return temp[3]-vpar;
};
auto root = toms748_solve(rootfun, t_last, t_current, vpar_last-vpar, vpar_current-vpar, roottol, rootmaxit);
double f0 = rootfun(root.first);
double f1 = rootfun(root.second);
double troot = std::abs(f0) < std::abs(f1) ? root.first : root.second;
double troot = bisection_method(rootfun, t_last, t_current, abstol);
dense.calc_state(troot, temp);
ykeep = temp;
if (rhs.axis==1) {
Expand Down Expand Up @@ -541,10 +763,7 @@ bool check_stopping_criteria(RHS rhs, typename RHS::State y, int iter, vector<ar
dense.calc_state(t, temp);
return temp[2] - omega*t - phase_shift;
};
auto root = toms748_solve(rootfun, t_last, t_current, phase_last - phase_shift, phase_current - phase_shift, roottol, rootmaxit);
double f0 = rootfun(root.first);
double f1 = rootfun(root.second);
double troot = std::abs(f0) < std::abs(f1) ? root.first : root.second;
double troot = bisection_method(rootfun, t_last, t_current, abstol);
dense.calc_state(troot, temp);
ykeep = temp;
if (rhs.axis==1) {
Expand Down

0 comments on commit 0f923ca

Please sign in to comment.