-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Using classes in order to override methods
- Loading branch information
Showing
5 changed files
with
144 additions
and
172 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,106 +1,101 @@ | ||
from functools import partial | ||
|
||
from typing import NamedTuple | ||
import jax | ||
import jax.numpy as jnp | ||
from jax import Array | ||
from jaxfun.utils.common import jacn | ||
from jaxfun.Jacobi import Jacobi, Domain | ||
import sympy as sp | ||
|
||
n = sp.Symbol('n', integer=True, positive=True) | ||
|
||
# Jacobi constants | ||
alpha, beta = -sp.S.Half, -sp.S.Half | ||
|
||
# Scaling function (see Eq. (2.28) of https://www.duo.uio.no/bitstream/handle/10852/99687/1/PGpaper.pdf) | ||
def gn(alpha, beta, n): | ||
return sp.S(1)/sp.jacobi(n, alpha, beta, 1) | ||
n = sp.Symbol("n", integer=True, positive=True) | ||
|
||
@jax.jit | ||
def evaluate(x: float, c: Array) -> float: | ||
""" | ||
Evaluate a Chebyshev series at points x. | ||
|
||
.. math:: p(x) = c_0 * T_0(x) + c_1 * T_1(x) + ... + c_n * T_n(x) | ||
class Chebyshev(Jacobi): | ||
|
||
Parameters | ||
---------- | ||
x : float | ||
c : Array | ||
def __init__(self, N: int, domain: NamedTuple = Domain(-1, 1), **kw): | ||
Jacobi.__init__(self, N, domain, -sp.S.Half, -sp.S.Half) | ||
|
||
Returns | ||
------- | ||
values : Array | ||
|
||
Notes | ||
----- | ||
The evaluation uses Clenshaw recursion, aka synthetic division. | ||
# Scaling function (see Eq. (2.28) of https://www.duo.uio.no/bitstream/handle/10852/99687/1/PGpaper.pdf) | ||
@staticmethod | ||
def gn(alpha, beta, n): | ||
return sp.S(1) / sp.jacobi(n, alpha, beta, 1) | ||
|
||
""" | ||
if len(c) == 1: | ||
# Multiply by 0 * x for shape | ||
return c[0] + 0 * x | ||
if len(c) == 2: | ||
return c[0] + c[1] * x | ||
|
||
def body_fun(i: int, val: tuple[Array, Array]) -> tuple[Array, Array]: | ||
c0, c1 = val | ||
@partial(jax.jit, static_argnums=0) | ||
def evaluate(self, x: float, c: Array) -> float: | ||
""" | ||
Evaluate a Chebyshev series at points x. | ||
tmp = c0 | ||
c0 = c[-i] - c1 | ||
c1 = tmp + c1 * 2 * x | ||
.. math:: p(x) = c_0 * L_0(x) + c_1 * L_1(x) + ... + c_n * L_n(x) | ||
return c0, c1 | ||
Parameters | ||
---------- | ||
x : float | ||
c : Array | ||
c0 = jnp.ones_like(x) * c[-2] | ||
c1 = jnp.ones_like(x) * c[-1] | ||
Returns | ||
------- | ||
values : Array | ||
c0, c1 = jax.lax.fori_loop(3, len(c) + 1, body_fun, (c0, c1)) | ||
return c0 + c1 * x | ||
Notes | ||
----- | ||
The evaluation uses Clenshaw recursion, aka synthetic division. | ||
""" | ||
if len(c) == 1: | ||
# Multiply by 0 * x for shape | ||
return c[0] + 0 * x | ||
if len(c) == 2: | ||
return c[0] + c[1] * x | ||
|
||
def quad_points_and_weights(N: int) -> Array: | ||
return jnp.array( | ||
( | ||
jnp.cos(jnp.pi + (2 * jnp.arange(N) + 1) * jnp.pi / (2 * N)), | ||
jnp.ones(N) * jnp.pi / N, | ||
) | ||
) | ||
|
||
def body_fun(i: int, val: tuple[Array, Array]) -> tuple[Array, Array]: | ||
c0, c1 = val | ||
|
||
@partial(jax.jit, static_argnums=(1, 2)) | ||
def evaluate_basis_derivative(x: Array, deg: int, k: int = 0) -> Array: | ||
return jacn(eval_basis_functions, k)(x, deg) | ||
tmp = c0 | ||
c0 = c[-i] - c1 | ||
c1 = tmp + c1 * 2 * x | ||
|
||
return c0, c1 | ||
|
||
@partial(jax.jit, static_argnums=1) | ||
def vandermonde(x: Array, deg: int) -> Array: | ||
return evaluate_basis_derivative(x, deg, 0) | ||
c0 = jnp.ones_like(x) * c[-2] | ||
c1 = jnp.ones_like(x) * c[-1] | ||
|
||
c0, c1 = jax.lax.fori_loop(3, len(c) + 1, body_fun, (c0, c1)) | ||
return c0 + c1 * x | ||
|
||
@partial(jax.jit, static_argnums=1) | ||
def eval_basis_function(x: float, i: int) -> Array: | ||
# return jnp.cos(i * jnp.acos(x)) | ||
x0 = x * 0 + 1 | ||
if i == 0: | ||
return x0 | ||
def quad_points_and_weights(self, N: int) -> Array: | ||
return jnp.array( | ||
( | ||
jnp.cos(jnp.pi + (2 * jnp.arange(N) + 1) * jnp.pi / (2 * N)), | ||
jnp.ones(N) * jnp.pi / N, | ||
) | ||
) | ||
|
||
def body_fun(i: int, val: tuple[Array, Array]) -> tuple[Array, Array]: | ||
x0, x1 = val | ||
x2 = 2 * x * x1 - x0 | ||
return x1, x2 | ||
@partial(jax.jit, static_argnums=(0, 2)) | ||
def eval_basis_function(self, x: float, i: int) -> float: | ||
# return jnp.cos(i * jnp.acos(x)) | ||
x0 = x * 0 + 1 | ||
if i == 0: | ||
return x0 | ||
|
||
return jax.lax.fori_loop(1, i, body_fun, (x0, x))[-1] | ||
def body_fun(i: int, val: tuple[Array, Array]) -> tuple[Array, Array]: | ||
x0, x1 = val | ||
x2 = 2 * x * x1 - x0 | ||
return x1, x2 | ||
|
||
return jax.lax.fori_loop(1, i, body_fun, (x0, x))[-1] | ||
|
||
@partial(jax.jit, static_argnums=1) | ||
def eval_basis_functions(x: float, deg: int) -> Array: | ||
x0 = x * 0 + 1 | ||
@partial(jax.jit, static_argnums=(0, 2)) | ||
def eval_basis_functions(self, x: float, deg: int) -> Array: | ||
x0 = x * 0 + 1 | ||
|
||
def inner_loop(carry: tuple[float, float], _) -> tuple[tuple[float, float], Array]: | ||
x0, x1 = carry | ||
x2 = 2 * x * x1 - x0 | ||
return (x1, x2), x1 | ||
def inner_loop( | ||
carry: tuple[float, float], _ | ||
) -> tuple[tuple[float, float], Array]: | ||
x0, x1 = carry | ||
x2 = 2 * x * x1 - x0 | ||
return (x1, x2), x1 | ||
|
||
_, xs = jax.lax.scan(inner_loop, init=(x0, x), xs=None, length=deg - 1) | ||
_, xs = jax.lax.scan(inner_loop, init=(x0, x), xs=None, length=deg - 1) | ||
|
||
return jnp.hstack((x0, xs)) | ||
return jnp.hstack((x0, xs)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.