Skip to content

Commit

Permalink
Using classes in order to override methods
Browse files Browse the repository at this point in the history
  • Loading branch information
mikaem committed Dec 14, 2024
1 parent 91e4909 commit 3a3d53b
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 172 deletions.
6 changes: 3 additions & 3 deletions examples/poisson.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import matplotlib.pyplot as plt
import sympy as sp
import jax.numpy as jnp
from jaxfun import Chebyshev
from jaxfun.Legendre import Legendre as space
from jaxfun.composite import Composite
from jaxfun.inner import inner

Expand All @@ -13,13 +13,13 @@
f = ue.diff(s, 2)
N = 50
bcs = {'left': {'D': 0}, 'right': {'D': 0}}
C = Composite(Chebyshev, N, bcs)
C = Composite(space, N, bcs)
v = (C, 0)
u = (C, 2)
A = inner(v, u, sparse=True)
b = inner(v, f)
u = jnp.linalg.solve(A.todense(), b)
x = jnp.linspace(-1, 1, 100)
plt.plot(x, sp.lambdify(s, ue)(x), "r")
plt.plot(x, Chebyshev.evaluate(x, u @ C.S), "b") # u @ C.S return coefficients in the orthogonal basis
plt.plot(x, C.orthogonal.evaluate(x, u @ C.S), "b") # u @ C.S return coefficients in the orthogonal basis
plt.show()
145 changes: 70 additions & 75 deletions jaxfun/Chebyshev.py
Original file line number Diff line number Diff line change
@@ -1,106 +1,101 @@
from functools import partial

from typing import NamedTuple
import jax
import jax.numpy as jnp
from jax import Array
from jaxfun.utils.common import jacn
from jaxfun.Jacobi import Jacobi, Domain
import sympy as sp

n = sp.Symbol('n', integer=True, positive=True)

# Jacobi constants
alpha, beta = -sp.S.Half, -sp.S.Half

# Scaling function (see Eq. (2.28) of https://www.duo.uio.no/bitstream/handle/10852/99687/1/PGpaper.pdf)
def gn(alpha, beta, n):
return sp.S(1)/sp.jacobi(n, alpha, beta, 1)
n = sp.Symbol("n", integer=True, positive=True)

@jax.jit
def evaluate(x: float, c: Array) -> float:
"""
Evaluate a Chebyshev series at points x.

.. math:: p(x) = c_0 * T_0(x) + c_1 * T_1(x) + ... + c_n * T_n(x)
class Chebyshev(Jacobi):

Parameters
----------
x : float
c : Array
def __init__(self, N: int, domain: NamedTuple = Domain(-1, 1), **kw):
Jacobi.__init__(self, N, domain, -sp.S.Half, -sp.S.Half)

Returns
-------
values : Array

Notes
-----
The evaluation uses Clenshaw recursion, aka synthetic division.
# Scaling function (see Eq. (2.28) of https://www.duo.uio.no/bitstream/handle/10852/99687/1/PGpaper.pdf)
@staticmethod
def gn(alpha, beta, n):
return sp.S(1) / sp.jacobi(n, alpha, beta, 1)

"""
if len(c) == 1:
# Multiply by 0 * x for shape
return c[0] + 0 * x
if len(c) == 2:
return c[0] + c[1] * x

def body_fun(i: int, val: tuple[Array, Array]) -> tuple[Array, Array]:
c0, c1 = val
@partial(jax.jit, static_argnums=0)
def evaluate(self, x: float, c: Array) -> float:
"""
Evaluate a Chebyshev series at points x.
tmp = c0
c0 = c[-i] - c1
c1 = tmp + c1 * 2 * x
.. math:: p(x) = c_0 * L_0(x) + c_1 * L_1(x) + ... + c_n * L_n(x)
return c0, c1
Parameters
----------
x : float
c : Array
c0 = jnp.ones_like(x) * c[-2]
c1 = jnp.ones_like(x) * c[-1]
Returns
-------
values : Array
c0, c1 = jax.lax.fori_loop(3, len(c) + 1, body_fun, (c0, c1))
return c0 + c1 * x
Notes
-----
The evaluation uses Clenshaw recursion, aka synthetic division.
"""
if len(c) == 1:
# Multiply by 0 * x for shape
return c[0] + 0 * x
if len(c) == 2:
return c[0] + c[1] * x

def quad_points_and_weights(N: int) -> Array:
return jnp.array(
(
jnp.cos(jnp.pi + (2 * jnp.arange(N) + 1) * jnp.pi / (2 * N)),
jnp.ones(N) * jnp.pi / N,
)
)

def body_fun(i: int, val: tuple[Array, Array]) -> tuple[Array, Array]:
c0, c1 = val

@partial(jax.jit, static_argnums=(1, 2))
def evaluate_basis_derivative(x: Array, deg: int, k: int = 0) -> Array:
return jacn(eval_basis_functions, k)(x, deg)
tmp = c0
c0 = c[-i] - c1
c1 = tmp + c1 * 2 * x

return c0, c1

@partial(jax.jit, static_argnums=1)
def vandermonde(x: Array, deg: int) -> Array:
return evaluate_basis_derivative(x, deg, 0)
c0 = jnp.ones_like(x) * c[-2]
c1 = jnp.ones_like(x) * c[-1]

c0, c1 = jax.lax.fori_loop(3, len(c) + 1, body_fun, (c0, c1))
return c0 + c1 * x

@partial(jax.jit, static_argnums=1)
def eval_basis_function(x: float, i: int) -> Array:
# return jnp.cos(i * jnp.acos(x))
x0 = x * 0 + 1
if i == 0:
return x0
def quad_points_and_weights(self, N: int) -> Array:
return jnp.array(
(
jnp.cos(jnp.pi + (2 * jnp.arange(N) + 1) * jnp.pi / (2 * N)),
jnp.ones(N) * jnp.pi / N,
)
)

def body_fun(i: int, val: tuple[Array, Array]) -> tuple[Array, Array]:
x0, x1 = val
x2 = 2 * x * x1 - x0
return x1, x2
@partial(jax.jit, static_argnums=(0, 2))
def eval_basis_function(self, x: float, i: int) -> float:
# return jnp.cos(i * jnp.acos(x))
x0 = x * 0 + 1
if i == 0:
return x0

return jax.lax.fori_loop(1, i, body_fun, (x0, x))[-1]
def body_fun(i: int, val: tuple[Array, Array]) -> tuple[Array, Array]:
x0, x1 = val
x2 = 2 * x * x1 - x0
return x1, x2

return jax.lax.fori_loop(1, i, body_fun, (x0, x))[-1]

@partial(jax.jit, static_argnums=1)
def eval_basis_functions(x: float, deg: int) -> Array:
x0 = x * 0 + 1
@partial(jax.jit, static_argnums=(0, 2))
def eval_basis_functions(self, x: float, deg: int) -> Array:
x0 = x * 0 + 1

def inner_loop(carry: tuple[float, float], _) -> tuple[tuple[float, float], Array]:
x0, x1 = carry
x2 = 2 * x * x1 - x0
return (x1, x2), x1
def inner_loop(
carry: tuple[float, float], _
) -> tuple[tuple[float, float], Array]:
x0, x1 = carry
x2 = 2 * x * x1 - x0
return (x1, x2), x1

_, xs = jax.lax.scan(inner_loop, init=(x0, x), xs=None, length=deg - 1)
_, xs = jax.lax.scan(inner_loop, init=(x0, x), xs=None, length=deg - 1)

return jnp.hstack((x0, xs))
return jnp.hstack((x0, xs))
117 changes: 50 additions & 67 deletions jaxfun/Legendre.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,96 +3,79 @@
import jax
import jax.numpy as jnp
from jax import Array
from jaxfun.utils.common import jacn
from jaxfun.utils.fastgl import leggauss
from jaxfun.Jacobi import Jacobi, Domain, NamedTuple
import sympy as sp

n = sp.Symbol("n", integer=True, positive=True)

# Jacobi constants
alpha, beta = 0, 0

class Legendre(Jacobi):

# Scaling function (see Eq. (2.28) of https://www.duo.uio.no/bitstream/handle/10852/99687/1/PGpaper.pdf)
def gn(alpha, beta, n):
return 1
def __init__(self, N: int, domain: NamedTuple = Domain(-1, 1), **kw):
Jacobi.__init__(self, N, domain, 0, 0)


@jax.jit
def evaluate(x: float, c: Array) -> float:
"""
Evaluate a Legendre series at points x.
@partial(jax.jit, static_argnums=0)
def evaluate(self, x: float, c: Array) -> float:
"""
Evaluate a Legendre series at points x.
.. math:: p(x) = c_0 * L_0(x) + c_1 * L_1(x) + ... + c_n * L_n(x)
.. math:: p(x) = c_0 * L_0(x) + c_1 * L_1(x) + ... + c_n * L_n(x)
Parameters
----------
x : float
c : Array
Parameters
----------
x : float
c : Array
Returns
-------
values : Array
Returns
-------
values : Array
Notes
-----
The evaluation uses Clenshaw recursion, aka synthetic division.
Notes
-----
The evaluation uses Clenshaw recursion, aka synthetic division.
"""
if len(c) == 1:
# Multiply by 0 * x for shape
return c[0] + 0 * x
if len(c) == 2:
return c[0] + c[1] * x
"""
if len(c) == 1:
# Multiply by 0 * x for shape
return c[0] + 0 * x
if len(c) == 2:
return c[0] + c[1] * x

def body_fun(i: int, val: tuple[int, Array, Array]) -> tuple[int, Array, Array]:
nd, c0, c1 = val
def body_fun(i: int, val: tuple[int, Array, Array]) -> tuple[int, Array, Array]:
nd, c0, c1 = val

tmp = c0
nd = nd - 1
c0 = c[-i] - (c1 * (nd - 1)) / nd
c1 = tmp + (c1 * x * (2 * nd - 1)) / nd
tmp = c0
nd = nd - 1
c0 = c[-i] - (c1 * (nd - 1)) / nd
c1 = tmp + (c1 * x * (2 * nd - 1)) / nd

return nd, c0, c1
return nd, c0, c1

nd = len(c)
c0 = jnp.ones_like(x) * c[-2]
c1 = jnp.ones_like(x) * c[-1]
nd = len(c)
c0 = jnp.ones_like(x) * c[-2]
c1 = jnp.ones_like(x) * c[-1]

_, c0, c1 = jax.lax.fori_loop(3, len(c) + 1, body_fun, (nd, c0, c1))
return c0 + c1 * x
_, c0, c1 = jax.lax.fori_loop(3, len(c) + 1, body_fun, (nd, c0, c1))
return c0 + c1 * x


def quad_points_and_weights(N: int) -> Array:
return leggauss(N)
def quad_points_and_weights(self, N: int) -> Array:
return leggauss(N)


@partial(jax.jit, static_argnums=(1, 2))
def evaluate_basis_derivative(x: Array, deg: int, k: int = 0) -> Array:
return jacn(eval_basis_functions, k)(x, deg)
@partial(jax.jit, static_argnums=(0, 2))
def eval_basis_functions(self, x: float, deg: int) -> Array:
x0 = x * 0 + 1

def inner_loop(
carry: tuple[float, float], i: int
) -> tuple[tuple[float, float], Array]:
x0, x1 = carry
x2 = (x1 * x * (2 * i - 1) - x0 * (i - 1)) / i
return (x1, x2), x1

@partial(jax.jit, static_argnums=1)
def vandermonde(x: Array, deg: int) -> Array:
return evaluate_basis_derivative(x, deg, 0)
_, xs = jax.lax.scan(inner_loop, (x0, x), jnp.arange(2, deg + 1))


@partial(jax.jit, static_argnums=1)
def eval_basis_function(x: float, i: int) -> float:
return evaluate(x, (0,) * i + (1,))


@partial(jax.jit, static_argnums=1)
def eval_basis_functions(x: float, deg: int) -> Array:
x0 = x * 0 + 1

def inner_loop(
carry: tuple[float, float], i: int
) -> tuple[tuple[float, float], Array]:
x0, x1 = carry
x2 = (x1 * x * (2 * i - 1) - x0 * (i - 1)) / i
return (x1, x2), x1

_, xs = jax.lax.scan(inner_loop, (x0, x), jnp.arange(2, deg + 1))

return jnp.hstack((x0, xs))
return jnp.hstack((x0, xs))
Loading

0 comments on commit 3a3d53b

Please sign in to comment.