Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable open upper limits for the grids in interpolation code #440

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/simsopt/field/magneticfieldclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,10 +862,14 @@ def __init__(self, field, degree, rrange, phirange, zrange, extrapolate=True, nf
Args:
field: the underlying :mod:`simsopt.field.magneticfield.MagneticField` to be interpolated.
degree: the degree of the piecewise polynomial interpolant.
rrange: a 3-tuple of the form ``(rmin, rmax, nr)``. This mean that the interval :math:`[rmin, rmax]` is
split into ``nr`` many subintervals.
phirange: a 3-tuple of the form ``(phimin, phimax, nphi)``.
zrange: a 3-tuple of the form ``(zmin, zmax, nz)``.
rrange: a 3-tuple of the form ``(rmin, rmax, nr)`` or a 4-tuple of the form ``(rmin, rmax, nr, include_endpoint)``.
The default for `include_endpoint` is True. This mean that the interval :math:`[rmin, rmax]` is
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo, this means

split into ``nr`` many subintervals. When `include_endpoint` is False, interval :math:`[rmin, rmax)` is
split into ``nr`` many subintervals.
Copy link
Contributor

@andrewgiuliani andrewgiuliani Jul 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does it mean [rmin, rmax] is split into nr many subintervals as opposed to [rmin, rmax) is split into nr many subintervals ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To concretize this, lets suppose you want 5 points between 0 and 1. If you include the end point, the mesh would [0, 0.25, 0.5, 0.75, 1]. If you don't want to include the end point, the mesh generated would be [0, 0.2, 0.4, 0.6, 0.8]. So the mesh would be quite different. Unfortunately, BMW generates phi grid with end point not included.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But then what happens on the small subinterval (2pi/nfp - h, 2pi/nfp)? I guess the behavior would be dictated by the extrapolate keyword argument.

Since the BMW field is nfp-periodic, you can just rotate the B-field values from B(R, phi=0, Z) to B(R, phi=2pi/nfp, Z) and obtain the endpoint value.

phirange: a 3-tuple of the form ``(phimin, phimax, nphi)`` or a 4-tuple of the form
``(phimin, phimax, nphi, include_endpoint)``.
zrange: a 3-tuple of the form ``(zmin, zmax, nz)`` or a 4-tuple of the form
``(zmin, zmax, nz, include_endpoint)``.
extrapolate: whether to extrapolate the field when evaluate outside
the integration domain or to throw an error.
nfp: Whether to exploit rotational symmetry. In this case any angle
Expand Down Expand Up @@ -900,7 +904,7 @@ def skip(xs, ys, zs):
return [False for _ in xs]

sopp.InterpolatedField.__init__(self, field, degree, rrange, phirange, zrange, extrapolate, nfp, stellsym, skip)
self.__field = field
self._field = field

def to_vtk(self, filename):
"""Export the field evaluated on a regular grid for visualisation with e.g. Paraview."""
Expand Down
24 changes: 22 additions & 2 deletions src/simsoptpp/magneticfield_interpolated.h
Original file line number Diff line number Diff line change
Expand Up @@ -138,17 +138,27 @@ class InterpolatedField : public MagneticField<T> {

public:
const shared_ptr<MagneticField<T>> field;
const RangeTriplet r_range, phi_range, z_range;
// const RangeTriplet r_range, phi_range, z_range;
const RangeParams r_range, phi_range, z_range;
using MagneticField<T>::npoints;
const InterpolationRule rule;

InterpolatedField(
shared_ptr<MagneticField<T>> field, InterpolationRule rule,
RangeTriplet r_range, RangeTriplet phi_range, RangeTriplet z_range,
bool extrapolate, int nfp, bool stellsym, std::function<std::vector<bool>(Vec, Vec, Vec)> skip) :
InterpolatedField(field, rule,
std::make_tuple(std::get<0>(r_range), std::get<1>(r_range), std::get<2>(r_range), true),
std::make_tuple(std::get<0>(phi_range), std::get<1>(phi_range), std::get<2>(phi_range), true),
std::make_tuple(std::get<0>(z_range), std::get<1>(z_range), std::get<2>(z_range), true),
extrapolate, nfp, stellsym, skip) {}

InterpolatedField(
shared_ptr<MagneticField<T>> field, InterpolationRule rule,
RangeParams r_range, RangeParams phi_range, RangeParams z_range,
bool extrapolate, int nfp, bool stellsym, std::function<std::vector<bool>(Vec, Vec, Vec)> skip) :
field(field), rule(rule), r_range(r_range), phi_range(phi_range), z_range(z_range), extrapolate(extrapolate), nfp(nfp), stellsym(stellsym),
skip(skip)

{
fbatch_B = [this](Vec r, Vec phi, Vec z) {
int npoints = r.size();
Expand Down Expand Up @@ -184,6 +194,16 @@ class InterpolatedField : public MagneticField<T> {
InterpolatedField(
shared_ptr<MagneticField<T>> field, int degree,
RangeTriplet r_range, RangeTriplet phi_range, RangeTriplet z_range,
bool extrapolate, int nfp, bool stellsym, std::function<std::vector<bool>(Vec, Vec, Vec)> skip) :
InterpolatedField(field, UniformInterpolationRule(degree),
std::make_tuple(std::get<0>(r_range), std::get<1>(r_range), std::get<2>(r_range), true),
std::make_tuple(std::get<0>(phi_range), std::get<1>(phi_range), std::get<2>(phi_range), true),
std::make_tuple(std::get<0>(z_range), std::get<1>(z_range), std::get<2>(z_range), true),
extrapolate, nfp, stellsym, skip) {}

InterpolatedField(
shared_ptr<MagneticField<T>> field, int degree,
RangeParams r_range, RangeParams phi_range, RangeParams z_range,
bool extrapolate, int nfp, bool stellsym, std::function<std::vector<bool>(Vec, Vec, Vec)> skip) : InterpolatedField(field, UniformInterpolationRule(degree), r_range, phi_range, z_range, extrapolate, nfp, stellsym, skip) {}

std::pair<double, double> estimate_error_B(int samples) {
Expand Down
2 changes: 2 additions & 0 deletions src/simsoptpp/python_magneticfield.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ void init_magneticfields(py::module_ &m){
auto ifield = py::class_<PyInterpolatedField, shared_ptr<PyInterpolatedField>, PyMagneticField>(m, "InterpolatedField")
.def(py::init<shared_ptr<PyMagneticField>, InterpolationRule, RangeTriplet, RangeTriplet, RangeTriplet, bool, int, bool, std::function<std::vector<bool>(Vec, Vec, Vec)>>())
.def(py::init<shared_ptr<PyMagneticField>, int, RangeTriplet, RangeTriplet, RangeTriplet, bool, int, bool, std::function<std::vector<bool>(Vec, Vec, Vec)>>())
.def(py::init<shared_ptr<PyMagneticField>, InterpolationRule, RangeParams, RangeParams, RangeParams, bool, int, bool, std::function<std::vector<bool>(Vec, Vec, Vec)>>())
.def(py::init<shared_ptr<PyMagneticField>, int, RangeParams, RangeParams, RangeParams, bool, int, bool, std::function<std::vector<bool>(Vec, Vec, Vec)>>())
.def("estimate_error_B", &PyInterpolatedField::estimate_error_B)
.def("estimate_error_GradAbsB", &PyInterpolatedField::estimate_error_GradAbsB)
.def_readonly("r_range", &PyInterpolatedField::r_range)
Expand Down
40 changes: 27 additions & 13 deletions src/simsoptpp/regular_grid_interpolant_3d.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@

using Vec = std::vector<double>;
using RangeTriplet = std::tuple<double, double, int>;
using RangeParams = std::tuple<double, double, int, bool>;

Vec linspace(double min, double max, int n, bool endpoint);
double linspace(double min, double max, int n, bool endpoint, Vec& res);

class InterpolationRule {
/* An InterpolationRule consists of a list of interpolation nodes and then
Expand Down Expand Up @@ -86,7 +87,8 @@ class RegularGridInterpolant3D {
const int nx, ny, nz; // number of cells in x, y, and z direction
double hx, hy, hz; // gridsize in x, y, and z direction
const double xmin, ymin, zmin; // lower bounds of the x, y, and z coordinates
const double xmax, ymax, zmax; // lower bounds of the x, y, and z coordinates
const double xmax, ymax, zmax; // Upper bounds of the x, y, and z coordinates
const bool x_endpoint, y_endpoint, z_endpoint; // Include upper bound endpoint or not
const int value_size; // number of output dimensions of the interpolant, i.e. space that is mapped into
const InterpolationRule rule; // the interpolation rule to use on each cell in the grid
const bool out_of_bounds_ok; // whether to do nothing or throw an error when the interpolant is queried at an out-of-bounds point
Expand Down Expand Up @@ -144,26 +146,38 @@ class RegularGridInterpolant3D {
void evaluate_local(double x, double y, double z, int cell_idx, double* res);

public:

RegularGridInterpolant3D(InterpolationRule rule, RangeTriplet xrange, RangeTriplet yrange, RangeTriplet zrange, int value_size, bool out_of_bounds_ok, std::function<std::vector<bool>(Vec, Vec, Vec)> skip) :
RegularGridInterpolant3D(InterpolationRule rule, RangeTriplet xrange, RangeTriplet yrange, RangeTriplet zrange, int value_size, bool out_of_bounds_ok, std::function<std::vector<bool>(Vec, Vec, Vec)> skip) :
RegularGridInterpolant3D(
rule,
std::make_tuple(std::get<0>(xrange), std::get<1>(xrange), std::get<2>(xrange), true), // Default is to include endpoint
std::make_tuple(std::get<0>(yrange), std::get<1>(yrange), std::get<2>(yrange), true), // Default is to include endpoint
std::make_tuple(std::get<0>(zrange), std::get<1>(zrange), std::get<2>(zrange), true), // Default is to include endpoint
value_size,
out_of_bounds_ok,
skip
)
{ }

RegularGridInterpolant3D(InterpolationRule rule, RangeParams xrange, RangeParams yrange, RangeParams zrange, int value_size, bool out_of_bounds_ok, std::function<std::vector<bool>(Vec, Vec, Vec)> skip) :
rule(rule),
xmin(std::get<0>(xrange)), xmax(std::get<1>(xrange)), nx(std::get<2>(xrange)),
ymin(std::get<0>(yrange)), ymax(std::get<1>(yrange)), ny(std::get<2>(yrange)),
zmin(std::get<0>(zrange)), zmax(std::get<1>(zrange)), nz(std::get<2>(zrange)),
xmin(std::get<0>(xrange)), xmax(std::get<1>(xrange)), nx(std::get<2>(xrange)), x_endpoint(std::get<3>(xrange)),
ymin(std::get<0>(yrange)), ymax(std::get<1>(yrange)), ny(std::get<2>(yrange)), y_endpoint(std::get<3>(yrange)),
zmin(std::get<0>(zrange)), zmax(std::get<1>(zrange)), nz(std::get<2>(zrange)), z_endpoint(std::get<3>(zrange)),
value_size(value_size), out_of_bounds_ok(out_of_bounds_ok)
{
int degree = rule.degree;
pkxs = Vec(degree+1, 0.);
pkys = Vec(degree+1, 0.);
pkzs = Vec(degree+1, 0.);
hx = (xmax-xmin)/nx;
hy = (ymax-ymin)/ny;
hz = (zmax-zmin)/nz;

// build a regular mesh on [xmin, xmax] x [ymin, ymax] x [zmin, zmax]
xmesh = linspace(xmin, xmax, nx+1, true);
ymesh = linspace(ymin, ymax, ny+1, true);
zmesh = linspace(zmin, zmax, nz+1, true);
xmesh.reserve(nx+1);
ymesh.reserve(ny+1);
zmesh.reserve(nz+1);

hx = linspace(xmin, xmax, nx+1, x_endpoint, xmesh);
hy = linspace(ymin, ymax, ny+1, y_endpoint, ymesh);
hz = linspace(zmin, zmax, nz+1, z_endpoint, zmesh);

int nmesh = (nx+1)*(ny+1)*(nz+1);
Vec xmeshtensor(nmesh, 0.);
Expand Down
10 changes: 5 additions & 5 deletions src/simsoptpp/regular_grid_interpolant_3d_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,16 +237,16 @@ std::pair<double, double> RegularGridInterpolant3D<Array>::estimate_error(std::f



Vec linspace(double min, double max, int n, bool endpoint) {
Vec res(n, 0.);
double linspace(double min, double max, int n, bool endpoint, Vec& res) {
double h;
if(endpoint) {
double h = (max-min)/(n-1);
h = (max-min)/(n-1);
for (int i = 0; i < n; ++i)
res[i] = min + i*h;
} else {
double h = (max-min)/n;
h = (max-min)/n;
for (int i = 0; i < n; ++i)
res[i] = min + i*h;
}
return res;
return h;
}
Loading