Skip to content

Understanding the ODE integrator

The jax_fno.integrate module provides JAX-compatible time integration methods for solving initial value problems (IVPs). The solver performs temporal discretisation and integration, while users need to handle spatial discretisation during setup.

Quick start

import jax.numpy as jnp
from jax_fno.integrate import solve_ivp, RK4

# 1. Define your discretised PDE as an ODE
# NOTE: This must be written in a JAX-compatible (functionally pure) way
def my_pde_rhs(t, y, args):
    """Right-hand side: dy/dt = f(t, y, ...)

    Implement your spatial discretisation here.
    Handle boundary conditions within this function.
    """
    ...

# 2. Set initial condition
y0 = ...

# 3. Choose time-stepping method
method = RK4()

# 4. Integrate
t_final, y_final = integrate.solve_ivp(
    my_pde_rhs,
    t_span=(0.0, 1.0),
    y0=y0,
    method=method,
    step_size=0.001,
    args=(...,)
)

Time-stepping methods

Explicit methods

Explicit methods compute the next state directly from the current state. They're simple and fast but can require very small time-steps for stiff problems.

Implicit methods

Implicit methods use a root-finding algorithm at each time step, and root-finding algorithms often use linear solvers at each iteration. Implicit methods are usually more expensive than explicit methods per step but can allow much larger time steps for stiff problems.

Extending the time-stepping methods

To use a custom time-stepping method, implement a step method according to StepperProtocol in a class that inherits from flax.nnx.Module:

from jax import Array
from flax import nnx
from jax_fno.integrate import solve_ivp

class MyMethod(nnx.Module):

    def step(
        self,
        fun: Callable,
        t: Array,
        y: Array,
        h: Array,
        args: tuple = ()
    ) -> Array:
        """Advance one time step."""
        ...

t, y = solve_ivp(fun, t_span, y0, MyMethod(), step_size, args)

Extending the root-finding algorithms

Currently only NewtonRaphson is provided as root-finding algorithm.

Users can extend the root finders by writing their own implementation by following RootFinderProtocol.

Extending the linear solvers

The root-finders often use a linear solver as a subroutine.

The linear solvers currently available are:

Users can extend the linear solvers by writing their own implementation and inheriting from LinearSolverProtocol.

JAX transformations

As long as fun and method are JAX-compatible, solve_ivp should support most JAX transformations, however these features have not been properly tested yet.

JIT Compilation:

import jax
from jax_fno.integrate import solve_ivp

# JIT-compile the entire integration
solve_jit = jax.jit(
    solve_ivp(fun, t_span, y0, method, h, args),
    static_argnames=['fun', 'method']
)

y_final = solve_jit(y0)

Vectorisation (batching):

# Integrate multiple initial conditions in parallel
y0_batch = jnp.stack([y0_1, y0_2, y0_3])  # (batch, n)

solve_batch = jax.vmap(
    lambda y_: solve_ivp(fun, t_span, y_, method, dt, args)[1]
)

y_final_batch = solve_batch(y0_batch)  # (batch, n)