
API reference

Solvers

Main entry points for integrating ODE systems.

pardax.solve_ivp

solve_ivp(
    fun: Callable,
    t_span: tuple[float, float],
    y0: Float[Array, ...],
    stepper: StepperLike,
    step_size: float,
    params: Any = None,
    num_checkpoints: int = 0,
) -> tuple[
    Float[Array, " steps"], Float[Array, "steps ..."]
]

Integrate dy/dt = fun(t, y, params) from t_span[0] to t_span[1].

Uses jax.lax.scan internally and is compatible with all JAX transformations, including reverse-mode differentiation (jax.grad).

Output times are equally spaced: the interval is divided into num_checkpoints + 1 segments and a snapshot is saved at the end of each segment, plus the initial state.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| fun | Callable | Right-hand side dy/dt = fun(t, y, params) | required |
| t_span | tuple[float, float] | Tuple of start and end time, e.g. (t_start, t_end) | required |
| y0 | Float[Array, ...] | Initial condition at t_span[0] | required |
| stepper | StepperLike | Time-stepper instance (e.g. RK4(), BackwardEuler()) | required |
| step_size | float | Maximum time step size. The actual step may be smaller so that an integer number of steps fits the interval. | required |
| params | Any | Parameters pytree passed to fun | None |
| num_checkpoints | int | Number of equally spaced intermediate snapshots to store between t_start and t_end. If 0 (the default), only the initial and final states are returned. | 0 |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| t | Float[Array, ' steps'] | Time points, shape (num_checkpoints + 2,) |
| y | Float[Array, 'steps ...'] | Solution snapshots, shape (num_checkpoints + 2, *y0.shape) |
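The output layout of solve_ivp can be mimicked with a short NumPy sketch. It is not pardax code: `fixed_step_solve` and `euler_step` are hypothetical helpers, a plain Python loop stands in for `jax.lax.scan`, and the rule for fitting the step size (rounding the number of sub-steps per segment up) is an assumption, since the page only says the actual step "may be smaller".

```python
import math
import numpy as np


def euler_step(fun, t, y, dt, params=None):
    # One forward Euler step, returning (y_next, stepper) like the documented steppers.
    return y + dt * fun(t, y, params), euler_step


def fixed_step_solve(fun, t_span, y0, stepper, step_size, params=None, num_checkpoints=0):
    """Hypothetical helper that mimics the documented solve_ivp output layout."""
    t0, t1 = t_span
    n_segments = num_checkpoints + 1
    # Initial time plus the end of each of the num_checkpoints + 1 segments.
    t_out = np.linspace(t0, t1, n_segments + 1)
    seg_len = (t1 - t0) / n_segments
    # Assumption: round the sub-step count up so dt never exceeds step_size
    # (the 1e-12 guards against round-off when seg_len / step_size is an integer).
    n_sub = max(1, math.ceil(seg_len / step_size - 1e-12))
    dt = seg_len / n_sub
    y = np.asarray(y0, dtype=float)
    snapshots = [y]
    for seg in range(n_segments):
        t = t_out[seg]
        for _ in range(n_sub):
            y, stepper = stepper(fun, t, y, dt, params)
            t = t + dt
        snapshots.append(y)
    return t_out, np.stack(snapshots)


# Exponential decay dy/dt = -k * y, with k passed through params.
t, y = fixed_step_solve(
    lambda t, y, k: -k * y, (0.0, 1.0), np.array([1.0, 2.0]),
    euler_step, step_size=0.01, params=1.0, num_checkpoints=3,
)
print(t.shape, y.shape)  # (5,) (5, 2)
```

With num_checkpoints=3 the interval is split into four segments, giving 3 + 2 = 5 snapshots, which matches the documented return shapes.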

pardax.integrate

integrate(
    fun: Callable,
    t_eval: Float[Array, " steps"],
    y0: Float[Array, ...],
    stepper: StepperLike,
    step_size_fn: Callable[..., float],
    params: Any = None,
) -> tuple[
    Float[Array, " steps"], Float[Array, "steps ..."]
]

Integrate dy/dt = fun(t, y, params), returning states at times t_eval.

The time step size at each sub-step is given by step_size_fn(t, y, params) -> step_size, which may depend on the current time or solution (e.g. a CFL condition). Steps are clipped to avoid overshooting each target time.

Uses jax.lax.while_loop internally. Because jax.lax.while_loop does not support reverse-mode automatic differentiation, integrate cannot be differentiated with jax.grad; use solve_ivp when reverse-mode autodiff is required.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| fun | Callable | Right-hand side dy/dt = fun(t, y, params) | required |
| t_eval | Float[Array, ' steps'] | Sorted array of output times | required |
| y0 | Float[Array, ...] | Initial condition at t_eval[0] (any JAX pytree) | required |
| stepper | StepperLike | Time-stepper instance (e.g. RK4(), BackwardEuler()) | required |
| step_size_fn | Callable[..., float] | Callable (t, y, params) -> dt returning the desired step size for the current state | required |
| params | Any | Parameters pytree passed to fun and step_size_fn | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| t | Float[Array, ' steps'] | Time points, shape (len(t_eval),) |
| y | Float[Array, 'steps ...'] | Solution snapshots, shape (len(t_eval), *y0.shape) |
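The sub-stepping loop described for integrate (ask step_size_fn for a step, clip it so it does not overshoot the next output time) can be sketched in NumPy as follows. `adaptive_integrate`, `euler_step` and the CFL-like `step_rule` are hypothetical; a Python `while` loop stands in for `jax.lax.while_loop`, and snapping the clock to the target on the clipped step is a design choice of this sketch.

```python
import numpy as np


def euler_step(fun, t, y, dt, params=None):
    # One forward Euler step, returning (y_next, stepper).
    return y + dt * fun(t, y, params), euler_step


def adaptive_integrate(fun, t_eval, y0, stepper, step_size_fn, params=None):
    """Hypothetical sketch: states at t_eval using state-dependent step sizes."""
    t_eval = np.asarray(t_eval, dtype=float)
    t = t_eval[0]
    y = np.asarray(y0, dtype=float)
    out = [y]
    for t_target in t_eval[1:]:
        while t < t_target:
            # Ask for a step, then clip it so we never overshoot the next output time.
            dt = min(step_size_fn(t, y, params), t_target - t)
            y, stepper = stepper(fun, t, y, dt, params)
            # On the clipped final sub-step, snap to t_target to avoid round-off drift.
            t = t_target if dt >= t_target - t else t + dt
        out.append(y)
    return t_eval, np.stack(out)


# Hypothetical CFL-like rule: take smaller steps when |y| is large.
step_rule = lambda t, y, p: 0.05 / (1.0 + np.max(np.abs(y)))
t, y = adaptive_integrate(lambda t, y, p: -y, [0.0, 0.5, 1.0], np.array([1.0]), euler_step, step_rule)
print(t.shape, y.shape)  # (3,) (3, 1)
```

Because the number of sub-steps depends on the solution, the loop length is not known in advance, which is why a while-loop (rather than a fixed-length scan) is needed here.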

Time-stepping schemes

Explicit and implicit methods. All built-in steppers inherit from AbstractStepper. Custom steppers can also be used via the StepperLike protocol (see Extending the solver for IMEX and other examples).

Explicit

pardax.ForwardEuler

Bases: AbstractStepper

Forward Euler method.

__call__

__call__(
    fun: Callable,
    t: Float[Array, ""],
    y: Float[Array, ...],
    step_size: Float[Array, ""],
    params: Any = None,
) -> tuple[Float[Array, ...], ForwardEuler]

Perform a single Forward Euler step.

Computes y_next = y + step_size * fun(t, y, params).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| fun | Callable | Right-hand side fun(t, y, params) -> dy/dt | required |
| t | Float[Array, ''] | Current time | required |
| y | Float[Array, ...] | Current solution | required |
| step_size | Float[Array, ''] | Time step size | required |
| params | Any | Parameters passed through to fun | None |

Returns:

| Type | Description |
| --- | --- |
| tuple[Float[Array, ...], ForwardEuler] | Tuple (y, stepper) of the new solution and the stepper instance |

pardax.RK4

Bases: AbstractStepper

Fourth-order Runge-Kutta method (RK4).

__call__

__call__(
    fun: Callable,
    t: Float[Array, ""],
    y: Float[Array, ...],
    step_size: Float[Array, ""],
    params: Any = None,
) -> tuple[Float[Array, ...], RK4]

Perform a single RK4 step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| fun | Callable | Right-hand side fun(t, y, params) -> dy/dt | required |
| t | Float[Array, ''] | Current time | required |
| y | Float[Array, ...] | Current solution | required |
| step_size | Float[Array, ''] | Time step size | required |
| params | Any | Parameters passed through to fun | None |

Returns:

| Type | Description |
| --- | --- |
| tuple[Float[Array, ...], RK4] | Tuple (y, stepper) of the new solution and the stepper instance |
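For reference, a single RK4 step with the documented call convention might look like the NumPy sketch below. It assumes the classical RK4 coefficients; this page does not state which fourth-order variant pardax implements, and `rk4_step` is a hypothetical stand-in rather than the library class.

```python
import numpy as np


def rk4_step(fun, t, y, step_size, params=None):
    """One classical RK4 step (assumed coefficients); returns (y_next, stepper)."""
    h = step_size
    k1 = fun(t, y, params)
    k2 = fun(t + h / 2, y + h / 2 * k1, params)
    k3 = fun(t + h / 2, y + h / 2 * k2, params)
    k4 = fun(t + h, y + h * k3, params)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4), rk4_step


# For dy/dt = y, one RK4 step reproduces the Taylor series of exp(h) up to h**4.
h = 0.1
y1, _ = rk4_step(lambda t, y, p: y, 0.0, np.array([1.0]), h)
taylor = 1 + h + h**2 / 2 + h**3 / 6 + h**4 / 24
```

Matching the Taylor polynomial through h**4 on this linear test problem is a quick sanity check that the stage weights are consistent with fourth-order accuracy.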

Implicit

pardax.BackwardEuler

BackwardEuler(
    root_finder: AbstractRootFinder = NewtonRaphson(),
)

Bases: AbstractStepper

Backward Euler time stepper.

__call__

__call__(
    fun: Callable,
    t: Float[Array, ""],
    y: Float[Array, ...],
    step_size: Float[Array, ""],
    params: Any = None,
) -> tuple[Float[Array, ...], BackwardEuler]

Perform a single Backward Euler step.

Solves the implicit equation y_next = y + step_size * fun(t + step_size, y_next, params) for y_next using the configured root_finder.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| fun | Callable | Right-hand side fun(t, y, params) -> dy/dt | required |
| t | Float[Array, ''] | Current time | required |
| y | Float[Array, ...] | Current solution | required |
| step_size | Float[Array, ''] | Time step size | required |
| params | Any | Parameters passed through to fun | None |

Returns:

| Type | Description |
| --- | --- |
| tuple[Float[Array, ...], BackwardEuler] | Tuple (y, stepper) of the new solution and the stepper instance |
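Why the implicit step is worth its cost can be seen on the linear test equation dy/dt = lam * y, where the Backward Euler equation has the closed-form solution y_next = y / (1 - h * lam). The plain-Python sketch below (no pardax calls; the values of lam, h and n_steps are illustrative) compares it with Forward Euler for a stiff rate.

```python
# Linear test problem dy/dt = lam * y with a stiff rate: h * lam = -10.
lam, h, n_steps = -1000.0, 0.01, 10

y_fe = y_be = 1.0
for _ in range(n_steps):
    y_fe = y_fe + h * lam * y_fe      # forward Euler: explicit update, factor (1 + h * lam) = -9
    y_be = y_be / (1.0 - h * lam)     # backward Euler: y_next = y + h * lam * y_next, solved exactly

# The last backward Euler state satisfies the implicit equation (residual ~ 0).
y_prev = y_be * (1.0 - h * lam)
residual = y_be - y_prev - h * lam * y_be
print(abs(y_fe), y_be)
```

Forward Euler grows by a factor of 9 in magnitude per step, while Backward Euler decays by a factor of 11, as the true solution does. For nonlinear fun no closed form exists, which is where the root finders below come in.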

Base classes

pardax.AbstractStepper

Bases: Module

Base class for single-term time-stepping methods.

__call__ abstractmethod

__call__(
    fun: Callable,
    t: Float[Array, ""],
    y: Float[Array, ...],
    step_size: Float[Array, ""],
    params: Any = None,
) -> tuple[Float[Array, ...], AbstractStepper]

Advance the solution by one time step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| fun | Callable | Right-hand side function fun(t, y, params) -> dy/dt | required |
| t | Float[Array, ''] | Current time (0-dimensional JAX array) | required |
| y | Float[Array, ...] | Current solution | required |
| step_size | Float[Array, ''] | Time step size (0-dimensional JAX array) | required |
| params | Any | Parameters passed through to fun | None |

Returns:

| Type | Description |
| --- | --- |
| tuple[Float[Array, ...], AbstractStepper] | Tuple (y, stepper) of the new solution and the stepper instance |

pardax.StepperLike

Bases: Protocol

Protocol for any object with a compatible __call__ method (see AbstractStepper.__call__ for the expected signature and return value).
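Structural typing means a custom stepper only needs the right call signature and the (y, stepper) return convention. The sketch below assumes StepperLike mirrors AbstractStepper.__call__; `StepperProtocol`, `Heun` and `advance` are hypothetical names written for illustration and are not part of pardax.

```python
from typing import Any, Callable, Protocol, Tuple

import numpy as np


class StepperProtocol(Protocol):
    """Hypothetical mirror of StepperLike: same call signature as AbstractStepper."""

    def __call__(self, fun: Callable, t: Any, y: Any, step_size: Any,
                 params: Any = None) -> Tuple[Any, "StepperProtocol"]: ...


class Heun:
    """Custom explicit trapezoidal (Heun) stepper; not a pardax class."""

    def __call__(self, fun, t, y, step_size, params=None):
        k1 = fun(t, y, params)
        k2 = fun(t + step_size, y + step_size * k1, params)
        return y + step_size / 2 * (k1 + k2), self   # (y_next, stepper)


def advance(stepper: StepperProtocol, fun, t, y, step_size, n, params=None):
    # Any object satisfying the protocol can be driven by the same loop.
    for _ in range(n):
        y, stepper = stepper(fun, t, y, step_size, params)
        t = t + step_size
    return y


y_end = advance(Heun(), lambda t, y, p: -y, 0.0, np.array([1.0]), 0.01, 100)
```

Returning the stepper alongside the state lets stateful schemes (for example, ones that cache data between steps) hand an updated instance back to the driver loop.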

Root finders

Root finders solve the non-linear or linear system that arises at each implicit time step. NewtonRaphson handles general non-linear problems; LinearRootFinder solves linear systems in a single step.

pardax.NewtonRaphson

NewtonRaphson(
    lineariser: AbstractLineariser = AutoJVP(),
    tol: float = 1e-06,
    maxiter: int = 50,
)

Bases: AbstractRootFinder

Newton-Raphson root-finding algorithm.

Iterative update: y_new = y_curr - J(y_curr)^{-1} R(y_curr), where J is the Jacobian of the residual R.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| lineariser | AbstractLineariser | Strategy for linearising and solving the Newton system |
| tol | float | Convergence tolerance on the residual norm |
| maxiter | int | Maximum number of Newton-Raphson iterations |

__call__

__call__(
    residual_fn: Callable[
        [Float[Array, "*state"]], Float[Array, "*state"]
    ],
    y_guess: Float[Array, "*state"],
    fun: Callable[..., Float[Array, "*state"]],
    t: Float[Array, ""],
    step_size: Float[Array, ""],
    args: tuple,
    theta: float = 1.0,
) -> Float[Array, "*state"]

Find the root of residual_fn(y) = 0 using the Newton-Raphson method.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| residual_fn | Callable[[Float[Array, '*state']], Float[Array, '*state']] | Residual function R(y) | required |
| y_guess | Float[Array, '*state'] | Initial guess | required |
| fun | Callable[..., Float[Array, '*state']] | Right-hand side function (passed to the lineariser) | required |
| t | Float[Array, ''] | Time at which to evaluate the linearisation | required |
| step_size | Float[Array, ''] | Time step size | required |
| args | tuple | Additional arguments to pass to fun | required |
| theta | float | Implicit coefficient (see AbstractRootFinder) | 1.0 |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, '*state'] | Solution y |
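The Newton-Raphson iteration with tol and maxiter can be sketched in NumPy with a dense Jacobian. This is closer to pairing the Jacobian lineariser with DirectDense than to the default AutoJVP with GMRES, and `newton` is a hypothetical helper, not the library implementation.

```python
import numpy as np


def newton(residual_fn, jac_fn, y_guess, tol=1e-6, maxiter=50):
    """Dense Newton-Raphson sketch: y_new = y - J(y)^{-1} R(y)."""
    y = np.asarray(y_guess, dtype=float)
    for _ in range(maxiter):
        r = residual_fn(y)
        if np.linalg.norm(r) < tol:   # stop once the residual norm is below tol
            break
        y = y - np.linalg.solve(jac_fn(y), r)
    return y   # returned as-is if maxiter is reached without convergence


# One backward Euler step for dy/dt = -y**3 (h = 0.5, y_n = 1.0):
# R(y) = y - y_n + h * y**3, with Jacobian dR/dy = 1 + 3 * h * y**2.
h, y_n = 0.5, np.array([1.0])
y_next = newton(
    lambda y: y - y_n + h * y**3,
    lambda y: np.diag(1.0 + 3.0 * h * y**2),
    y_guess=y_n,
)
```

Using the previous state y_n as the initial guess is a natural choice for time stepping, since the solution changes little over one step.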

pardax.LinearRootFinder

LinearRootFinder(
    linsolver: AbstractLinearSolver,
    operator: AbstractLinearOperator,
)

Bases: AbstractRootFinder

Single-step root finder for linear systems.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| linsolver | AbstractLinearSolver | Linear solver to use |
| operator | AbstractLinearOperator | Constructs the linear system of equations |

__call__

__call__(
    residual_fn: Callable[
        [Float[Array, "*state"]], Float[Array, "*state"]
    ],
    y_guess: Float[Array, "*state"],
    fun: Any,
    t: Float[Array, ""],
    step_size: Float[Array, ""],
    args: tuple,
    theta: float = 1.0,
) -> Float[Array, "*state"]

Solve for R(y) = 0 in a single step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| residual_fn | Callable[[Float[Array, '*state']], Float[Array, '*state']] | Residual function R(y) = y - y_n - theta * step_size * L(y) | required |
| y_guess | Float[Array, '*state'] | Initial guess (used for shape / initial value) | required |
| fun | Any | Right-hand side of system (unused) | required |
| t | Float[Array, ''] | Current time | required |
| step_size | Float[Array, ''] | Time step size | required |
| args | tuple | Additional arguments | required |
| theta | float | Implicit coefficient (see AbstractRootFinder) | 1.0 |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, '*state'] | Solution y such that residual_fn(y) ≈ 0 |
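Because the residual R(y) = y - y_n - theta * step_size * L(y) is affine in y when L is linear, its root is a single linear solve with the matrix (I - theta * step_size * L). The NumPy sketch below plays the roles of DenseOperator and DirectDense; the 1-D Dirichlet Laplacian used for L is a hypothetical example operator.

```python
import numpy as np

n, h, theta = 8, 0.1, 1.0
dx = 1.0 / (n + 1)
# Hypothetical linear operator L: second-difference Laplacian, homogeneous Dirichlet ends.
L = (np.diag(-2.0 * np.ones(n)) + np.diag(np.ones(n - 1), 1) + np.diag(np.ones(n - 1), -1)) / dx**2
y_n = np.sin(np.pi * dx * np.arange(1, n + 1))


def residual_fn(y):
    # R(y) = y - y_n - theta * h * L(y), as in the LinearRootFinder docstring.
    return y - y_n - theta * h * (L @ y)


# R is affine in y, so a single solve with (I - theta * h * L) finds the root.
A = np.eye(n) - theta * h * L
y_next = np.linalg.solve(A, y_n)
```

No iteration, tolerance or Jacobian evaluation is needed, which is the saving LinearRootFinder offers over NewtonRaphson for linear problems.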

pardax.AbstractRootFinder

Bases: Module

Base class for root-finding algorithms.

__call__ abstractmethod

__call__(
    residual_fn: Callable[
        [Float[Array, "*state"]], Float[Array, "*state"]
    ],
    y_guess: Float[Array, "*state"],
    fun: Callable[..., Float[Array, "*state"]],
    t: Float[Array, ""],
    step_size: Float[Array, ""],
    args: tuple,
    theta: float = 1.0,
) -> Float[Array, "*state"]

Find the root of residual_fn(y) = 0.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| residual_fn | Callable[[Float[Array, '*state']], Float[Array, '*state']] | Function mapping y -> R(y), where we seek R(y) = 0 | required |
| y_guess | Float[Array, '*state'] | Initial guess for the solution | required |
| fun | Callable[..., Float[Array, '*state']] | Right-hand side function dy/dt = fun(t, y, *args) | required |
| t | Float[Array, ''] | Time at which to evaluate the linearisation | required |
| step_size | Float[Array, ''] | Time step size | required |
| args | tuple | Additional arguments to pass to fun | required |
| theta | float | Implicit coefficient. The system matrix is built as (I - theta * step_size * L): theta = 1.0 for backward Euler, theta = 0.5 for Crank-Nicolson. Passed by the time stepper; users do not set this directly. | 1.0 |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, '*state'] | Solution y such that residual_fn(y) ≈ 0 |
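The effect of theta can be illustrated on the scalar test equation y' = lam * y with z = h * lam. This page only specifies the implicit system matrix (I - theta * step_size * L); the explicit factor (1 + (1 - theta) * z) below is the standard theta-method and is an assumption about how a theta-based stepper would combine with it. `theta_amplification` is a hypothetical helper.

```python
def theta_amplification(z, theta):
    """Growth factor of the theta-method on y' = lam * y, with z = h * lam."""
    # Denominator mirrors the (I - theta * h * L) system matrix; numerator is the explicit part.
    return (1.0 + (1.0 - theta) * z) / (1.0 - theta * z)


z = -10.0  # stiff: h * lam = -10
g_be = theta_amplification(z, 1.0)   # backward Euler: 1/11
g_cn = theta_amplification(z, 0.5)   # Crank-Nicolson: -2/3 (stable but oscillating)
g_fe = theta_amplification(z, 0.0)   # forward Euler: -9.0 (unstable)
```

A growth factor with magnitude below 1 means the step is stable; Crank-Nicolson stays stable here but flips sign every step, which is why backward Euler (theta = 1) is often preferred for very stiff problems.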

Linearisers

Linearisers construct the Newton system inside NewtonRaphson. They bundle a linearisation strategy (autodiff, user-provided JVP, or dense Jacobian) with a linear solver.

pardax.AutoJVP

AutoJVP(linsolver: AbstractLinearSolver = GMRES())

Bases: AbstractLineariser

Linearise using automatic differentiation via jax.jvp.

pardax.JVP

JVP(
    jvp_fn: Callable,
    linsolver: AbstractLinearSolver = GMRES(),
)

Bases: AbstractLineariser

User-provided matrix-free Jacobian-vector product.

pardax.Jacobian

Jacobian(
    jac_fn: Callable,
    linsolver: AbstractLinearSolver = DirectDense(),
)

Bases: AbstractLineariser

User-provided dense Jacobian.

pardax.AbstractLineariser

Bases: Module

Base class for locally linearising a system of equations.

__call__ abstractmethod

__call__(
    fun: Callable[..., Float[Array, "*state"]],
    t: Float[Array, ""],
    step_size: Float[Array, ""],
    args: tuple,
) -> Callable[
    [Float[Array, "*state"], Float[Array, "*state"]],
    Float[Array, "*state"],
]

Return a solve function (residual_at_y, y) -> delta_y.
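To make the (residual_at_y, y) -> delta_y contract concrete, here is a NumPy sketch of a dense-Jacobian lineariser for a Backward Euler residual R(y) = y - y_n - h * fun(t, y, *args), whose Jacobian is I - h * df/dy. `DenseJacobianLineariser` is hypothetical, theta is fixed at 1, and the sign convention (the caller updates y_new = y - delta_y, matching the NewtonRaphson update) is an assumption.

```python
import numpy as np


class DenseJacobianLineariser:
    """Hypothetical sketch of a Jacobian-based lineariser for theta = 1 (backward Euler)."""

    def __init__(self, jac_fn):
        self.jac_fn = jac_fn   # jac_fn(t, y, *args) -> df/dy as a dense matrix

    def __call__(self, fun, t, step_size, args):
        def solve(residual_at_y, y):
            # R(y) = y - y_n - h * fun(t, y, *args) has Jacobian I - h * df/dy.
            J = np.eye(y.size) - step_size * self.jac_fn(t, y, *args)
            # Assumed sign convention: the caller updates y_new = y - delta_y.
            return np.linalg.solve(J, residual_at_y)
        return solve


A = np.array([[-2.0, 1.0], [1.0, -2.0]])
fun = lambda t, y: A @ y
h, t_next, y_n = 0.1, 0.1, np.array([1.0, 0.0])
R = lambda y: y - y_n - h * fun(t_next, y)

solve = DenseJacobianLineariser(lambda t, y: A)(fun, t_next, h, ())
y1 = y_n - solve(R(y_n), y_n)   # one Newton update from y_n
```

For a linear right-hand side, one Newton update lands exactly on the Backward Euler solution (I - h * A)^{-1} y_n; for nonlinear fun, NewtonRaphson repeats the update until the residual norm drops below tol.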

Linear solvers

Solve the linear system \(Ax = b\) that arises during root finding. Iterative solvers accept both dense matrices and matrix-free operators.

pardax.DirectDense

Bases: AbstractLinearSolver

Direct solver for dense linear systems.

Dispatches to jax.numpy.linalg.solve. Only suitable for small systems whose matrix is available in dense form (e.g. a Jacobian provided explicitly).

__call__

__call__(
    A: Union[
        Float[Array, "n n"],
        Callable[[Float[Array, " n"]], Float[Array, " n"]],
    ],
    b: Float[Array, " n"],
    x0: Optional[Float[Array, " n"]] = None,
) -> Float[Array, " n"]

Solve A*x = b.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| A | Union[Float[Array, 'n n'], Callable[[Float[Array, ' n']], Float[Array, ' n']]] | Dense matrix | required |
| b | Float[Array, ' n'] | Right-hand side vector | required |
| x0 | Optional[Float[Array, ' n']] | Ignored (kept for interface compatibility); can be None. | None |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, ' n'] | Solution x |

Raises:

| Type | Description |
| --- | --- |
| TypeError | If A is a callable (linear operator) instead of a matrix |

pardax.GMRES

GMRES(tol: float = 1e-06, maxiter: int = 100)

Bases: AbstractLinearSolver

Generalised Minimal Residual (GMRES).

Dispatches to jax.scipy.sparse.linalg.gmres. Suitable for general non-symmetric systems.

__call__

__call__(
    A: Union[
        Float[Array, "n n"],
        Callable[[Float[Array, " n"]], Float[Array, " n"]],
    ],
    b: Float[Array, " n"],
    x0: Optional[Float[Array, " n"]] = None,
) -> Float[Array, " n"]

Solve A*x = b using GMRES.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| A | Union[Float[Array, 'n n'], Callable[[Float[Array, ' n']], Float[Array, ' n']]] | Dense matrix or linear operator with signature x -> A*x | required |
| b | Float[Array, ' n'] | Right-hand side vector | required |
| x0 | Optional[Float[Array, ' n']] | Initial guess vector | None |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, ' n'] | Approximate solution x |

pardax.CG

CG(tol: float = 1e-06, maxiter: int = 100)

Bases: AbstractLinearSolver

Conjugate Gradient (CG).

Dispatches to jax.scipy.sparse.linalg.cg. Only suitable for symmetric and positive-definite systems.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| tol | float | Convergence tolerance on the residual norm |
| maxiter | int | Maximum number of iterations |

__call__

__call__(
    A: Union[
        Float[Array, "n n"],
        Callable[[Float[Array, " n"]], Float[Array, " n"]],
    ],
    b: Float[Array, " n"],
    x0: Optional[Float[Array, " n"]] = None,
) -> Float[Array, " n"]

Solve A*x = b.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| A | Union[Float[Array, 'n n'], Callable[[Float[Array, ' n']], Float[Array, ' n']]] | Dense matrix or linear operator with signature x -> A*x | required |
| b | Float[Array, ' n'] | Right-hand side vector | required |
| x0 | Optional[Float[Array, ' n']] | Initial guess vector | None |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, ' n'] | Approximate solution x |
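The iterative solvers accept either a dense matrix or a matrix-free callable. The standalone NumPy conjugate gradient below shows how both forms can be handled by reducing everything to a matvec. pardax itself dispatches to jax.scipy.sparse.linalg.cg; `as_matvec` and `cg_solve` are hypothetical, and for simplicity tol is treated as an absolute threshold on the residual norm.

```python
import numpy as np


def as_matvec(A):
    # Accept either a dense matrix or a matrix-free callable x -> A @ x.
    return A if callable(A) else (lambda x: A @ x)


def cg_solve(A, b, x0=None, tol=1e-6, maxiter=100):
    """Plain conjugate gradient for symmetric positive-definite A."""
    matvec = as_matvec(A)
    x = np.zeros_like(b, dtype=float) if x0 is None else np.array(x0, dtype=float)
    r = b - matvec(x)          # initial residual
    p = r.copy()               # first search direction
    rs = r @ r
    for _ in range(maxiter):
        if np.sqrt(rs) < tol:
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)  # step length along p
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        p = r + (rs_new / rs) * p   # next A-conjugate direction
        rs = rs_new
    return x


M = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])
x_dense = cg_solve(M, b)                # dense matrix input
x_free = cg_solve(lambda v: M @ v, b)   # matrix-free input
```

Only products A @ v are ever needed, which is what makes matrix-free operators (such as those returned by MatrixFreeOperator) usable with CG, GMRES and BiCGStab.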

pardax.BiCGStab

BiCGStab(tol: float = 1e-06, maxiter: int = 100)

Bases: AbstractLinearSolver

Stabilised Biconjugate Gradient (BiCGStab).

Dispatches to jax.scipy.sparse.linalg.bicgstab. Suitable for non-symmetric systems.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| tol | float | Convergence tolerance on the residual norm |
| maxiter | int | Maximum number of iterations |

__call__

__call__(
    A: Union[
        Float[Array, "n n"],
        Callable[[Float[Array, " n"]], Float[Array, " n"]],
    ],
    b: Float[Array, " n"],
    x0: Optional[Float[Array, " n"]] = None,
) -> Float[Array, " n"]

Solve A*x = b.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| A | Union[Float[Array, 'n n'], Callable[[Float[Array, ' n']], Float[Array, ' n']]] | Dense matrix or linear operator with signature x -> A*x | required |
| b | Float[Array, ' n'] | Right-hand side vector | required |
| x0 | Optional[Float[Array, ' n']] | Initial guess vector | None |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, ' n'] | Approximate solution x |

pardax.SpectralSolver

SpectralSolver(
    forward: Callable[[Array], Array],
    backward: Callable[[Array], Array],
    constraint: Callable[[Array], Array] = _identity,
)

Bases: AbstractLinearSolver

Spectral linear solver.

Solves a diagonalised system by transforming to the spectral basis, performing a pointwise division by the symbol, and transforming back.

Pair with SpectralOperator, which passes the precomputed symbol array 1 - h * eigvals as A.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| forward | Callable[[Array], Array] | Forward transformation to the diagonal basis |
| backward | Callable[[Array], Array] | Inverse transformation from the diagonal basis |
| constraint | Callable[[Array], Array] | Pre-processing to enforce compatibility (e.g. mean removal) |

__call__

__call__(
    A: Union[
        Float[Array, " n"],
        Callable[[Float[Array, " n"]], Float[Array, " n"]],
    ],
    b: Float[Array, " n"],
    x0: Optional[Float[Array, " n"]] = None,
) -> Float[Array, " n"]

Solve the system given the spectral symbol.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| A | Union[Float[Array, ' n'], Callable[[Float[Array, ' n']], Float[Array, ' n']]] | Symbol array 1 - h * eigvals from SpectralOperator. | required |
| b | Float[Array, ' n'] | Right-hand side vector. | required |
| x0 | Optional[Float[Array, ' n']] | Ignored. | None |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, ' n'] | Solution x. |
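The transform, divide, transform-back recipe can be checked against a dense solve. In this NumPy sketch the forward and backward transforms are an unnormalised DST-I built as an explicit O(n^2) sine matrix (not a fast transform), the operator is a hypothetical 1-D Dirichlet Laplacian whose eigenvalues are known in closed form, and the constraint step is omitted.

```python
import numpy as np

n, h = 16, 0.01
dx = 1.0 / (n + 1)
j = np.arange(1, n + 1)
S = np.sin(np.pi * np.outer(j, j) / (n + 1))   # unnormalised DST-I matrix (symmetric)

forward = lambda x: S @ x                        # to the sine (eigenvector) basis
backward = lambda X: (2.0 / (n + 1)) * (S @ X)   # inverse, since S @ S = (n + 1) / 2 * I

# Eigenvalues of the Dirichlet second-difference Laplacian, ordered like the DST modes.
eigvals = -(4.0 / dx**2) * np.sin(np.pi * j / (2 * (n + 1))) ** 2
symbol = 1.0 - h * eigvals                       # the array SpectralOperator passes as A


def spectral_solve(A, b):
    # Transform, divide pointwise by the symbol, transform back.
    return backward(forward(b) / A)


b = np.cos(3.0 * np.pi * dx * j)
x = spectral_solve(symbol, b)

# Reference: the same system (I - h * L) x = b solved densely.
L_dense = (np.diag(-2.0 * np.ones(n)) + np.diag(np.ones(n - 1), 1) + np.diag(np.ones(n - 1), -1)) / dx**2
x_dense = np.linalg.solve(np.eye(n) - h * L_dense, b)
```

The solve itself is a pointwise division, so its cost is dominated by the two transforms; with a fast sine transform that is O(n log n) instead of the O(n^3) of a dense factorisation.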

pardax.AbstractLinearSolver

Bases: Module

Base class for linear solvers.

__call__ abstractmethod

__call__(
    A: Union[
        Float[Array, ...],
        Callable[[Float[Array, " n"]], Float[Array, " n"]],
    ],
    b: Float[Array, " n"],
    x0: Optional[Float[Array, " n"]] = None,
) -> Float[Array, " n"]

Solve the linear system A*x = b.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| A | Union[Float[Array, ...], Callable[[Float[Array, ' n']], Float[Array, ' n']]] | Dense matrix or linear operator with signature x -> A*x | required |
| b | Float[Array, ' n'] | Right-hand side vector | required |
| x0 | Optional[Float[Array, ' n']] | Initial guess vector | None |

Returns:

| Type | Description |
| --- | --- |
| Float[Array, ' n'] | Solution vector x such that A*x ≈ b |

Linear operators

Linear operators build the implicit system \((I - h L)\) for use with LinearRootFinder. Each operator returns the system in the form expected by its paired linear solver.

pardax.DenseOperator

DenseOperator(op_fn: Callable)

Bases: AbstractLinearOperator

Linear operator that returns a dense system matrix (I - h * L).

Pair with DirectDense.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| op_fn | Callable | Callable (t, *args) -> L, returning the operator as a dense matrix. |

pardax.MatrixFreeOperator

MatrixFreeOperator(op_fn: Callable)

Bases: AbstractLinearOperator

Linear operator that returns a matrix-free matvec for (I - h * L).

Pair with an iterative solver (GMRES, CG, BiCGStab).

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| op_fn | Callable | Callable (t, *args) -> Lv, where Lv is a matvec v -> L @ v. |

pardax.SpectralOperator

SpectralOperator(eigvals: Float[Array, ' n'])

Bases: AbstractLinearOperator

Linear operator diagonalised by a known spectral transform.

Builds and returns the spectral symbol 1 - h * eigvals. Pair with SpectralSolver.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| eigvals | Float[Array, ' n'] | Eigenvalues of the linear operator. |

pardax.AbstractLinearOperator

Bases: Module

Wraps a linear operator and builds the implicit system (I - h * L).

Subclasses return the system in the form expected by their paired AbstractLinearSolver: a dense matrix, a matvec callable, or a spectral symbol array.

system abstractmethod

system(
    t: Float[Array, ""], h: Float[Array, ""], args: tuple
) -> Union[
    Float[Array, ...],
    Callable[[Float[Array, " n"]], Float[Array, " n"]],
]

Build and return (I - h * L) for the current time step.
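The system(t, h, args) contract can be sketched with two small NumPy classes that follow the op_fn conventions described for DenseOperator and MatrixFreeOperator. `DenseSystemSketch` and `MatrixFreeSystemSketch` are hypothetical stand-ins, not the pardax classes, and the time-dependent operator is an illustrative example.

```python
import numpy as np


class DenseSystemSketch:
    """Hypothetical dense operator: system() returns the matrix I - h * L."""

    def __init__(self, op_fn):
        self.op_fn = op_fn   # op_fn(t, *args) -> dense L

    def system(self, t, h, args):
        L = self.op_fn(t, *args)
        return np.eye(L.shape[0]) - h * L


class MatrixFreeSystemSketch:
    """Hypothetical matrix-free operator: system() returns v -> v - h * (L @ v)."""

    def __init__(self, op_fn):
        self.op_fn = op_fn   # op_fn(t, *args) -> matvec v -> L @ v

    def system(self, t, h, args):
        Lv = self.op_fn(t, *args)
        return lambda v: v - h * Lv(v)


# Illustrative time-dependent operator L(t) = -(1 + t) * k * I with parameter k.
L_dense = lambda t, k: -(1.0 + t) * k * np.eye(3)
L_matvec = lambda t, k: (lambda v: -(1.0 + t) * k * v)

A_dense = DenseSystemSketch(L_dense).system(0.5, 0.1, (2.0,))
A_free = MatrixFreeSystemSketch(L_matvec).system(0.5, 0.1, (2.0,))
v = np.array([1.0, -2.0, 3.0])
```

Both forms represent the same system; the dense one suits a direct solver, while the matvec form never materialises L and suits the iterative solvers.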

Transforms

Discrete sine transforms for use with Dirichlet boundary conditions.

pardax.transform.dst1

dst1(x: Float[Array, ' n']) -> Float[Array, ' n']

Discrete sine transform (type 1).

pardax.transform.idst1

idst1(X: Float[Array, ' n']) -> Float[Array, ' n']

Inverse discrete sine transform (type 1).
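A reference DST-I pair is handy for checking normalisation. The NumPy sketch below assumes the unnormalised convention X_k = sum_j x_j sin(pi * j * k / (n + 1)) for j, k = 1..n, with the inverse scaled by 2 / (n + 1); this page does not state which scaling pardax.transform uses, so compare against dst1/idst1 before relying on it. The functions `dst1_sketch` and `idst1_sketch` are hypothetical.

```python
import numpy as np


def dst1_sketch(x):
    """Unnormalised DST-I: X_k = sum_j x_j sin(pi * j * k / (n + 1)), j, k = 1..n."""
    n = x.shape[-1]
    idx = np.arange(1, n + 1)
    S = np.sin(np.pi * np.outer(idx, idx) / (n + 1))
    return S @ x


def idst1_sketch(X):
    """Inverse of dst1_sketch: DST-I is its own inverse up to a factor 2 / (n + 1)."""
    n = X.shape[-1]
    return (2.0 / (n + 1)) * dst1_sketch(X)


x = np.random.default_rng(0).standard_normal(7)
roundtrip = idst1_sketch(dst1_sketch(x))

# A single sine mode (k = 2) maps to a spike of height (n + 1) / 2 at that mode.
mode = dst1_sketch(np.sin(np.pi * 2 * np.arange(1, 8) / 8))
```

Every basis function sin(pi * j * k / (n + 1)) vanishes at j = 0 and j = n + 1, which is why this transform diagonalises operators with homogeneous Dirichlet boundary conditions.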