Getting started¶
In pardax, users are expected to handle the spatial discretisation of
their PDE, reducing it to a system of ODEs, and pardax handles the rest.
```python
import jax.numpy as jnp

import pardax as pdx

# 1. Define your discretised PDE as an ODE
# NOTE: must be functionally pure (JAX-compatible)
def my_pde_rhs(t, y, params):
    """Right-hand side: dy/dt = f(t, y, params).

    Implement your spatial discretisation here.
    Handle boundary conditions within this function.
    """
    ...

# 2. Set initial condition
y0 = ...

# 3. Choose time-stepping method
method = pdx.RK4()

# 4. Integrate
t, y = pdx.solve_ivp(
    my_pde_rhs,
    t_span=(0.0, 1.0),
    y0=y0,
    stepper=method,
    step_size=0.001,
    params={...},
    num_checkpoints=10,
)
```
pardax also provides integrate, which takes a step-size
callback that can depend on the current solution state. It uses
jax.lax.while_loop internally, so it does not support reverse-mode
automatic differentiation with jax.grad.
```python
def cfl_step_size(t, u, params):
    # Diffusive CFL condition: dt proportional to dx**2 / nu
    return 0.5 * params["dx"]**2 / params["nu"]

t, y = pdx.integrate(
    my_pde_rhs,
    t_eval=jnp.linspace(0.0, 5.0, 11),
    y0=y0,
    stepper=pdx.RK4(),
    step_size_fn=cfl_step_size,
    params={"nu": nu, "dx": dx},
)
```
JAX transformations¶
Because pardax is built on JAX and Equinox,
you can apply JAX transformations directly to solve_ivp.
Vectorisation¶
```python
import jax

y0_batch = jnp.stack([y0_1, y0_2, y0_3])  # (batch, n)

solve_batch = jax.vmap(
    lambda y_: pdx.solve_ivp(fun, t_span, y_, stepper, step_size, params)
)
t, y_batch = solve_batch(y0_batch)
```
Differentiation¶
```python
def loss(params):
    t, y = pdx.solve_ivp(fun, t_span, y0, stepper, step_size, params=params)
    return jnp.mean((y[-1] - y_target)**2)

grads = jax.grad(loss)(params)
```
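The same pattern works for any solver built from fixed-length JAX control flow. As a self-contained illustration (plain JAX, not pardax), here is a loss differentiated through a hand-rolled forward Euler loop; the function names are illustrative.

```python
import jax
import jax.numpy as jnp

def solve(rate, y0, dt=0.01, n_steps=100):
    """Forward Euler for dy/dt = -rate * y, written with jax.lax.scan.

    A fixed number of steps, so reverse-mode AD is supported.
    """
    def step(y, _):
        y = y + dt * (-rate * y)
        return y, y

    y_final, _ = jax.lax.scan(step, y0, None, length=n_steps)
    return y_final

def loss(rate):
    # Penalise distance of the final state from a target value
    return (solve(rate, jnp.array(1.0)) - 0.5) ** 2

g = jax.grad(loss)(jnp.array(1.0))
```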
JIT compilation¶
import jax
```python
solve_jit = jax.jit(lambda y_: pdx.solve_ivp(
    fun, t_span, y_, stepper, step_size, params
))
t, y = solve_jit(y0)
```
Time-stepping methods¶
Explicit methods¶
Explicit methods compute the next state directly from the current state. They are simple and efficient per step but require small time steps for stiff problems.
```python
method = pdx.ForwardEuler()  # first-order
method = pdx.RK4()           # fourth-order
```
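For reference, a single classical RK4 step looks like the following, sketched in plain JAX. This is the standard textbook scheme, not pardax's internal implementation.

```python
import jax.numpy as jnp

def rk4_step(fun, t, y, dt, params):
    """One classical fourth-order Runge-Kutta step."""
    k1 = fun(t, y, params)
    k2 = fun(t + dt / 2, y + dt * k1 / 2, params)
    k3 = fun(t + dt / 2, y + dt * k2 / 2, params)
    k4 = fun(t + dt, y + dt * k3, params)
    return y + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6

# dy/dt = -y: one step from y = 1 should be close to exp(-0.1)
y1 = rk4_step(lambda t, y, p: -y, 0.0, jnp.array(1.0), 0.1, None)
```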
Implicit methods¶
Implicit methods solve a non-linear or linear system at each time step. They are more expensive per step but can take much larger time steps for stiff problems (e.g. diffusion-dominated PDEs).
The implicit solver is assembled from composable components:
```python
# Linear solver -> Lineariser -> Root finder -> Time stepper
method = pdx.BackwardEuler(
    root_finder=pdx.NewtonRaphson(
        lineariser=pdx.AutoJVP(linsolver=pdx.GMRES()),
        tol=1e-6,
    )
)
```
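To make the composition concrete, here is a minimal plain-JAX sketch of what a backward Euler step with Newton iteration computes. It is restricted to the scalar case for clarity; a real solver replaces the scalar division by the Jacobian with a linear solve (e.g. GMRES).

```python
import jax
import jax.numpy as jnp

def backward_euler_step(fun, t, y, dt, n_newton=5):
    """Solve r(z) = z - y - dt*fun(t + dt, z) = 0 by Newton's method."""
    def residual(z):
        return z - y - dt * fun(t + dt, z)

    z = y  # initial guess: the previous state
    for _ in range(n_newton):
        r = residual(z)
        dr = jax.grad(residual)(z)  # scalar Jacobian
        z = z - r / dr
    return z

# Stiff decay dy/dt = -50*y: stable even with the large step dt = 0.1
y1 = backward_euler_step(lambda t, z: -50.0 * z, 0.0, jnp.array(1.0), 0.1)
```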
For linear problems, you can skip Newton iteration entirely and solve the implicit system in a single step using a LinearRootFinder. See the Burgers' equation tutorial for an example.
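For intuition, with a linear right-hand side dy/dt = A @ y the backward Euler update reduces to the single linear solve (I - dt*A) y_next = y. A plain-JAX sketch, using a dense direct solve purely for illustration (this is not pardax's LinearRootFinder API):

```python
import jax.numpy as jnp

def backward_euler_linear(A, y, dt):
    """One backward Euler step for dy/dt = A @ y:
    solve (I - dt*A) y_next = y in a single linear solve."""
    n = y.shape[0]
    return jnp.linalg.solve(jnp.eye(n) - dt * A, y)

# Small symmetric diffusion-like operator
A = jnp.array([[-2.0, 1.0],
               [1.0, -2.0]])
y1 = backward_euler_linear(A, jnp.array([1.0, 0.0]), 0.1)
```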
Advanced methods¶
pardax is designed so that time-stepping schemes are composable,
which allows users to implement their own schemes that treat PDE terms
separately. See Extending the solver for
more information or the
Burgers' equation tutorial for a
worked example.
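As a flavour of such schemes, a first-order IMEX (implicit-explicit) Euler step treats a stiff linear term implicitly and the remaining terms explicitly. The following is a minimal scalar sketch in plain JAX, illustrative rather than pardax's API:

```python
import jax.numpy as jnp

def imex_euler_step(f_explicit, c_implicit, y, dt):
    """First-order IMEX Euler for dy/dt = f_explicit(y) + c*y.

    The stiff linear term c*y is treated implicitly, the rest explicitly:
    y_next = (y + dt*f_explicit(y)) / (1 - dt*c).
    """
    return (y + dt * f_explicit(y)) / (1.0 - dt * c_implicit)

# dy/dt = 0.5*y (explicit) - 100*y (implicit, stiff): stable at dt = 0.05
y1 = imex_euler_step(lambda y: 0.5 * y, -100.0, jnp.array(1.0), 0.05)
```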