API reference
Solvers
Main entry points for integrating ODE systems.
pardax.solve_ivp
solve_ivp(
fun: Callable,
t_span: tuple[float, float],
y0: Float[Array, ...],
stepper: StepperLike,
step_size: float,
params: Any = None,
num_checkpoints: int = 0,
) -> tuple[
Float[Array, " steps"], Float[Array, "steps ..."]
]
Integrate dy/dt = fun(t, y, params) from t_span[0] to t_span[1].
Uses jax.lax.scan internally and is compatible with all JAX
transformations, including reverse-mode differentiation (jax.grad).
Output times are equally spaced: the interval is divided into
num_checkpoints + 1 segments and a snapshot is saved at the end of
each segment, plus the initial state.
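Below is a minimal pure-JAX sketch of the fixed-step scheme described above, for illustration only. fixed_step_solve is a hypothetical helper, not pardax's implementation. The stepper is hard-coded to forward Euler. Rounding the number of sub-steps per segment up, so that the actual step never exceeds step_size, is an assumption. Nesting jax.lax.scan loops keeps the whole solve reverse-mode differentiable.

```python
import math

import jax
import jax.numpy as jnp


def fixed_step_solve(fun, t_span, y0, step_size, params=None, num_checkpoints=0):
    """Hypothetical sketch: forward Euler with equally spaced snapshots."""
    t0, t1 = t_span
    num_segments = num_checkpoints + 1
    seg_len = (t1 - t0) / num_segments
    # Assumption: round up so that the actual step h never exceeds step_size.
    n_sub = math.ceil(seg_len / step_size)
    h = seg_len / n_sub

    def substep(carry, unused):
        t, y = carry
        # Forward Euler update: y_next = y + h * fun(t, y, params)
        return (t + h, y + h * fun(t, y, params)), None

    def segment(carry, unused):
        carry, _ = jax.lax.scan(substep, carry, None, length=n_sub)
        return carry, carry  # record (t, y) at the end of each segment

    init = (jnp.asarray(t0, jnp.float32), jnp.asarray(y0, jnp.float32))
    _, (ts, ys) = jax.lax.scan(segment, init, None, length=num_segments)
    # Prepend the initial state: num_checkpoints + 2 snapshots in total.
    ts = jnp.concatenate([init[0][None], ts])
    ys = jnp.concatenate([init[1][None], ys])
    return ts, ys


# Exponential decay dy/dt = -k * y, with the rate k passed as params.
def decay(t, y, k):
    return -k * y


ts, ys = fixed_step_solve(decay, (0.0, 1.0), jnp.array([1.0]), 0.01,
                          params=1.0, num_checkpoints=1)


# Reverse-mode gradient of y(t_end) with respect to k.
def final_value(k):
    _, ys_k = fixed_step_solve(decay, (0.0, 1.0), jnp.array([1.0]), 0.01, params=k)
    return ys_k[-1, 0]


dy_dk = jax.grad(final_value)(1.0)
```

With num_checkpoints=1 the sketch returns three snapshots, at t = 0, 0.5 and 1.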
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fun | Callable | Right-hand side dy/dt = fun(t, y, params) | required |
| t_span | tuple[float, float] | Tuple of start and end times, e.g. (t_start, t_end) | required |
| y0 | Float[Array, ...] | Initial condition at t_span[0] | required |
| stepper | StepperLike | Time-stepper instance (e.g., RK4(), BackwardEuler()) | required |
| step_size | float | Maximum time step size. The actual step may be smaller to fit an integer number of steps | required |
| params | Any | Parameters pytree passed to fun | None |
| num_checkpoints | int | Number of equally spaced intermediate snapshots to store between t_start and t_end. If 0 (default), only the initial and final states are returned. | 0 |
Returns:
| Name | Type | Description |
|---|---|---|
| t | Float[Array, ' steps'] | Time points, shape (num_checkpoints + 2,) |
| y | Float[Array, 'steps ...'] | Solution snapshots, shape (num_checkpoints + 2, *y0.shape) |
pardax.integrate
integrate(
fun: Callable,
t_eval: Float[Array, " steps"],
y0: Float[Array, ...],
stepper: StepperLike,
step_size_fn: Callable[..., float],
params: Any = None,
) -> tuple[
Float[Array, " steps"], Float[Array, "steps ..."]
]
Integrate dy/dt = fun(t, y, params), returning states at times t_eval.
The time step size at each sub-step is given by
step_size_fn(t, y, params) -> step_size,
which may depend on the current time or solution (e.g. a CFL condition).
Steps are clipped to avoid overshooting each target time.
Uses jax.lax.while_loop internally, which is not supported by
reverse-mode automatic differentiation with jax.grad.
Use solve_ivp when reverse-mode autodiff is required.
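The sub-stepping loop can be sketched in pure JAX as below. This sketch assumes forward Euler in place of a configurable stepper. adaptive_integrate is a hypothetical helper, not pardax's implementation. Because the number of sub-steps depends on the data, the inner loop has to be a jax.lax.while_loop, and that is what rules out jax.grad.

```python
import jax
import jax.numpy as jnp


def adaptive_integrate(fun, t_eval, y0, step_size_fn, params=None):
    """Hypothetical sketch: forward Euler with a state-dependent step size."""

    def advance(y, bounds):
        t_start, t_end = bounds

        def not_done(state):
            t, _ = state
            return t < t_end

        def substep(state):
            t, y = state
            # Clip the proposed step so we never overshoot the target time.
            h = jnp.minimum(step_size_fn(t, y, params), t_end - t)
            return t + h, y + h * fun(t, y, params)

        _, y_end = jax.lax.while_loop(not_done, substep, (t_start, y))
        return y_end, y_end

    # Walk through consecutive target pairs (t_eval[i], t_eval[i + 1]).
    _, ys = jax.lax.scan(advance, y0, (t_eval[:-1], t_eval[1:]))
    return t_eval, jnp.concatenate([y0[None], ys])


# A CFL-like rule: smaller steps when the solution is large.
def cfl_like(t, y, params):
    return 0.01 / (1.0 + jnp.max(jnp.abs(y)))


t_eval = jnp.linspace(0.0, 1.0, 5)
ts, ys = adaptive_integrate(lambda t, y, p: -y, t_eval, jnp.array([1.0]), cfl_like)
```

In this sketch step_size_fn must return a strictly positive value. If it does not, the while loop never terminates.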
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fun | Callable | Right-hand side dy/dt = fun(t, y, params) | required |
| t_eval | Float[Array, ' steps'] | Sorted array of output times | required |
| y0 | Float[Array, ...] | Initial condition at t_eval[0] (any JAX pytree) | required |
| stepper | StepperLike | Time-stepper instance (e.g., RK4(), BackwardEuler()) | required |
| step_size_fn | Callable[..., float] | Callable step_size_fn(t, y, params) -> step_size giving the size of each sub-step | required |
| params | Any | Parameters pytree passed to fun and step_size_fn | None |
Returns:
| Name | Type | Description |
|---|---|---|
| t | Float[Array, ' steps'] | Time points, one per entry of t_eval |
| y | Float[Array, 'steps ...'] | Solution snapshots, one per entry of t_eval |
Time-stepping schemes
Explicit and implicit methods. All built-in steppers inherit from AbstractStepper. Custom steppers can also be used via the StepperLike protocol (see Extending the solver for IMEX and other examples).
Explicit
pardax.ForwardEuler
Bases: AbstractStepper
Forward Euler method.
__call__
__call__(
fun: Callable,
t: Float[Array, ""],
y: Float[Array, ...],
step_size: Float[Array, ""],
params: Any = None,
) -> tuple[Float[Array, ...], ForwardEuler]
Perform a single Forward Euler step.
Computes y_next = y + step_size * fun(t, y, params).
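To illustrate the calling convention, here is a hypothetical stand-in, ForwardEulerSketch (not the pardax class). It computes the same update and returns the (y, stepper) pair. Returning the stepper presumably lets stateful steppers hand updated state back to the driver loop. A stateless stepper like this one simply returns itself.

```python
import jax.numpy as jnp


class ForwardEulerSketch:
    """Hypothetical stand-in mirroring the documented ForwardEuler.__call__."""

    def __call__(self, fun, t, y, step_size, params=None):
        # y_next = y + step_size * fun(t, y, params)
        y_next = y + step_size * fun(t, y, params)
        # Return (new solution, stepper); this stepper has no state to update.
        return y_next, self


stepper = ForwardEulerSketch()
y1, stepper = stepper(lambda t, y, p: -2.0 * y, jnp.array(0.0), jnp.array([1.0]), jnp.array(0.1))
# Arithmetic: 1 + 0.1 * (-2 * 1) = 0.8
```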
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fun | Callable | Right-hand side fun(t, y, params) -> dy/dt | required |
| t | Float[Array, ''] | Current time | required |
| y | Float[Array, ...] | Current solution | required |
| step_size | Float[Array, ''] | Time step size | required |
| params | Any | Parameters passed through to fun | None |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, ...], ForwardEuler] | Tuple of the new solution and stepper instance (y, stepper) |
pardax.RK4
Bases: AbstractStepper
Fourth-order Runge-Kutta method.
__call__
__call__(
fun: Callable,
t: Float[Array, ""],
y: Float[Array, ...],
step_size: Float[Array, ""],
params: Any = None,
) -> tuple[Float[Array, ...], RK4]
Perform a single RK4 step.
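This page does not spell out the RK4 stages. Assuming the classical Butcher tableau (weights 1/6, 1/3, 1/3, 1/6), one step can be sketched as below. rk4_step is a hypothetical helper, and unlike the documented method it returns only the new state, not the (y, stepper) pair.

```python
import jax.numpy as jnp


def rk4_step(fun, t, y, step_size, params=None):
    """Classical RK4 step (hypothetical helper; returns only the new state)."""
    h = step_size
    k1 = fun(t, y, params)
    k2 = fun(t + h / 2, y + h / 2 * k1, params)
    k3 = fun(t + h / 2, y + h / 2 * k2, params)
    k4 = fun(t + h, y + h * k3, params)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


# Integrate dy/dt = y from t = 0 to t = 1 in ten steps; the exact answer is e.
y = jnp.array([1.0])
for i in range(10):
    y = rk4_step(lambda t, u, p: u, 0.1 * i, y, 0.1)
```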
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fun | Callable | Right-hand side fun(t, y, params) -> dy/dt | required |
| t | Float[Array, ''] | Current time | required |
| y | Float[Array, ...] | Current solution | required |
| step_size | Float[Array, ''] | Time step size | required |
| params | Any | Parameters passed through to fun | None |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, ...], RK4] | Tuple of the new solution and stepper instance (y, stepper) |
Implicit
pardax.BackwardEuler
BackwardEuler(
root_finder: AbstractRootFinder = NewtonRaphson(),
)
Bases: AbstractStepper
Backward Euler time stepper.
__call__
__call__(
fun: Callable,
t: Float[Array, ""],
y: Float[Array, ...],
step_size: Float[Array, ""],
params: Any = None,
) -> tuple[Float[Array, ...], BackwardEuler]
Perform a single Backward Euler step.
Solves y_next = y + step_size * fun(t + step_size, y_next, params).
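Here is a minimal dense sketch of this implicit solve. It assumes a fixed number of Newton iterations with a dense Jacobian from jax.jacfwd. By contrast, pardax's default pairs NewtonRaphson with AutoJVP and GMRES. backward_euler_step is a hypothetical helper. For dy/dt = -y**2 the step has the closed form y_next = (-1 + sqrt(1 + 4 h y)) / (2 h), which makes the result easy to check.

```python
import jax
import jax.numpy as jnp


def backward_euler_step(fun, t, y, step_size, params=None, iters=8):
    """Hypothetical sketch: solve y_next = y + h * fun(t + h, y_next, params)."""
    t_next = t + step_size

    def residual(z):
        return z - y - step_size * fun(t_next, z, params)

    z = y  # initial guess: the current state
    for _ in range(iters):  # fixed iteration count instead of a tolerance test
        jac = jax.jacfwd(residual)(z)  # dense (n, n) Jacobian dR/dz
        z = z - jnp.linalg.solve(jac, residual(z))
    return z


y_next = backward_euler_step(lambda t, y, p: -y**2, 0.0, jnp.array([1.0, 2.0]), 0.1)
```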
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fun | Callable | Right-hand side fun(t, y, params) -> dy/dt | required |
| t | Float[Array, ''] | Current time | required |
| y | Float[Array, ...] | Current solution | required |
| step_size | Float[Array, ''] | Time step size | required |
| params | Any | Parameters passed through to fun | None |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, ...], BackwardEuler] | Tuple of the new solution and stepper instance (y, stepper) |
Base classes
pardax.AbstractStepper
Bases: Module
Base class for single-term time-stepping methods.
__call__ (abstractmethod)
__call__(
fun: Callable,
t: Float[Array, ""],
y: Float[Array, ...],
step_size: Float[Array, ""],
params: Any = None,
) -> tuple[Float[Array, ...], AbstractStepper]
Advance the solution by one time step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| fun | Callable | Right-hand side function fun(t, y, params) -> dy/dt | required |
| t | Float[Array, ''] | Current time (0-dimensional JAX array) | required |
| y | Float[Array, ...] | Current solution | required |
| step_size | Float[Array, ''] | Time step size (0-dimensional JAX array) | required |
| params | Any | Parameters passed through to fun | None |
Returns:
| Type | Description |
|---|---|
| tuple[Float[Array, ...], AbstractStepper] | Tuple of the new solution and stepper instance (y, stepper) |
pardax.StepperLike
Bases: Protocol
Protocol for any object with a __call__ method compatible with AbstractStepper.__call__.
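Here is a sketch of a custom stepper that satisfies such a protocol structurally, without inheriting from AbstractStepper. StepperProtocolSketch is a hypothetical protocol written from the documented __call__ signature, not pardax's StepperLike definition. ExplicitMidpoint is a hypothetical example class. The built-in steppers are Module subclasses, and this sketch does not establish whether a plain class like this works unchanged inside the solvers.

```python
from typing import Any, Callable, Protocol, runtime_checkable

import jax.numpy as jnp


@runtime_checkable
class StepperProtocolSketch(Protocol):
    """Hypothetical protocol written from the documented __call__ signature."""

    def __call__(self, fun: Callable, t: Any, y: Any, step_size: Any,
                 params: Any = None) -> tuple[Any, "StepperProtocolSketch"]: ...


class ExplicitMidpoint:
    """Hypothetical second-order stepper that does not inherit from AbstractStepper."""

    def __call__(self, fun, t, y, step_size, params=None):
        h = step_size
        k1 = fun(t, y, params)
        k2 = fun(t + h / 2, y + h / 2 * k1, params)  # slope at the midpoint
        return y + h * k2, self  # (new solution, stepper)


stepper = ExplicitMidpoint()
y1, stepper = stepper(lambda t, y, p: y, 0.0, jnp.array([1.0]), 0.1)
# One step of dy/dt = y from y = 1 gives 1 + h + h**2 / 2 = 1.105
```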
Root finders
Root finders solve the non-linear or linear system that arises at each implicit time step. NewtonRaphson handles general non-linear problems; LinearRootFinder solves linear systems in a single step.
pardax.NewtonRaphson
NewtonRaphson(
lineariser: AbstractLineariser = AutoJVP(),
tol: float = 1e-06,
maxiter: int = 50,
)
Bases: AbstractRootFinder
Newton-Raphson root-finding algorithm.
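Iterative update: \(y_\text{new} = y_\text{curr} - J(y_\text{curr})^{-1} R(y_\text{curr})\), where \(J = \partial R / \partial y\) is the Jacobian of the residual.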
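With the default AutoJVP lineariser and GMRES, this update amounts to a matrix-free Newton-Krylov iteration. Below is a minimal sketch that uses jax.jvp for Jacobian-vector products and jax.scipy.sparse.linalg.gmres for the linear solve. newton_krylov is a hypothetical helper. Its stopping rule (residual norm at most tol, or maxiter reached) is an assumption about how tol and maxiter are used.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres


def newton_krylov(residual_fn, y_guess, tol=1e-6, maxiter=50):
    """Hypothetical sketch: matrix-free Newton iteration for residual_fn(y) = 0."""

    def not_converged(state):
        k, _, r_norm = state
        return (k < maxiter) & (r_norm > tol)

    def newton_update(state):
        k, y, _ = state
        r = residual_fn(y)

        def matvec(v):
            # J(y) v by forward-mode autodiff; J is never formed explicitly.
            return jax.jvp(residual_fn, (y,), (v,))[1]

        delta, _ = gmres(matvec, r)  # approximately solve J(y) delta = R(y)
        y_new = y - delta
        return k + 1, y_new, jnp.linalg.norm(residual_fn(y_new))

    init = (0, y_guess, jnp.linalg.norm(residual_fn(y_guess)))
    _, y, _ = jax.lax.while_loop(not_converged, newton_update, init)
    return y


# Backward Euler residual for dy/dt = -y**2 from y_n with step h.
y_n, h = jnp.array([1.0, 2.0]), 0.1
root = newton_krylov(lambda z: z - y_n - h * (-z**2), y_guess=y_n)
```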
Attributes:
| Name | Type | Description |
|---|---|---|
| lineariser | AbstractLineariser | Strategy for linearising and solving the Newton system |
| tol | float | Convergence tolerance for the residual norm |
| maxiter | int | Maximum number of Newton-Raphson iterations |
__call__
__call__(
residual_fn: Callable[
[Float[Array, "*state"]], Float[Array, "*state"]
],
y_guess: Float[Array, "*state"],
fun: Callable[..., Float[Array, "*state"]],
t: Float[Array, ""],
step_size: Float[Array, ""],
args: tuple,
theta: float = 1.0,
) -> Float[Array, "*state"]
Find the root of residual_fn(y) = 0 using the Newton-Raphson method.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| residual_fn | Callable[[Float[Array, '*state']], Float[Array, '*state']] | Residual function R(y) | required |
| y_guess | Float[Array, '*state'] | Initial guess | required |
| fun | Callable[..., Float[Array, '*state']] | Right-hand side function (passed to the lineariser) | required |
| t | Float[Array, ''] | Time at which to evaluate the linearisation | required |
| step_size | Float[Array, ''] | Time step size | required |
| args | tuple | Additional arguments to pass to fun | required |
| theta | float | Implicit coefficient (see AbstractRootFinder) | 1.0 |
Returns:
| Type | Description |
|---|---|
| Float[Array, '*state'] | Solution y |
pardax.LinearRootFinder
LinearRootFinder(
linsolver: AbstractLinearSolver,
operator: AbstractLinearOperator,
)
Bases: AbstractRootFinder
Single-step root finder for linear systems.
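A single step suffices for the following reason. When R(y) = y - y_n - theta * step_size * L(y) with L linear, R is affine: R(y) = A y - y_n with A = I - theta * step_size * L. One Newton correction from any guess therefore lands exactly on the root A^{-1} y_n. Here is a minimal dense sketch with a hypothetical 2x2 matrix L.

```python
import jax.numpy as jnp

# Hypothetical linear problem dy/dt = L y.
L = jnp.array([[-2.0, 1.0], [1.0, -2.0]])
y_n = jnp.array([1.0, 0.0])
h, theta = 0.1, 1.0


def residual(y):
    # R(y) = y - y_n - theta * h * L(y), which is affine in y
    return y - y_n - theta * h * (L @ y)


A = jnp.eye(2) - theta * h * L  # system matrix (I - theta * h * L)
y_guess = jnp.zeros(2)          # any guess works for an affine residual
y_next = y_guess - jnp.linalg.solve(A, residual(y_guess))
```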
Attributes:
| Name | Type | Description |
|---|---|---|
| linsolver | AbstractLinearSolver | Linear solver to use |
| operator | AbstractLinearOperator | Constructs the linear system of equations |
__call__
__call__(
residual_fn: Callable[
[Float[Array, "*state"]], Float[Array, "*state"]
],
y_guess: Float[Array, "*state"],
fun: Any,
t: Float[Array, ""],
step_size: Float[Array, ""],
args: tuple,
theta: float = 1.0,
) -> Float[Array, "*state"]
Solve for R(y) = 0 in a single step.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| residual_fn | Callable[[Float[Array, '*state']], Float[Array, '*state']] | Residual function R(y) = y - y_n - theta * step_size * L(y) | required |
| y_guess | Float[Array, '*state'] | Initial guess (used for shape / initial value) | required |
| fun | Any | Right-hand side of the system (unused) | required |
| t | Float[Array, ''] | Current time | required |
| step_size | Float[Array, ''] | Time step size | required |
| args | tuple | Additional arguments | required |
| theta | float | Implicit coefficient (see AbstractRootFinder) | 1.0 |
Returns:
| Type | Description |
|---|---|
| Float[Array, '*state'] | Solution y such that residual_fn(y) ≈ 0 |
pardax.AbstractRootFinder
Bases: Module
Base class for root-finding algorithms.
__call__ (abstractmethod)
__call__(
residual_fn: Callable[
[Float[Array, "*state"]], Float[Array, "*state"]
],
y_guess: Float[Array, "*state"],
fun: Callable[..., Float[Array, "*state"]],
t: Float[Array, ""],
step_size: Float[Array, ""],
args: tuple,
theta: float = 1.0,
) -> Float[Array, "*state"]
Find the root of residual_fn(y) = 0.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| residual_fn | Callable[[Float[Array, '*state']], Float[Array, '*state']] | Function mapping y -> R(y), where we seek R(y) = 0 | required |
| y_guess | Float[Array, '*state'] | Initial guess for the solution | required |
| fun | Callable[..., Float[Array, '*state']] | Right-hand side function dy/dt = fun(t, y, *args) | required |
| t | Float[Array, ''] | Time at which to evaluate the linearisation | required |
| step_size | Float[Array, ''] | Time step size | required |
| args | tuple | Additional arguments to pass to fun | required |
| theta | float | Implicit coefficient. The system matrix is built as (I - theta * step_size * L). For backward Euler theta=1.0, for Crank-Nicolson theta=0.5, etc. Passed by the time stepper; users do not set this directly. | 1.0 |
Returns:
| Type | Description |
|---|---|
| Float[Array, '*state'] | Solution y such that residual_fn(y) ≈ 0 |
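The theta coefficient selects the familiar theta-method. This page specifies only the system matrix (I - theta * step_size * L). The sketch below additionally assumes that, for a linear problem dy/dt = L y, the matrix is paired with the standard explicit part (I + (1 - theta) * step_size * L) y_n. theta_step is a hypothetical helper.

```python
import jax.numpy as jnp


def theta_step(L, y, h, theta):
    """One theta-method step for dy/dt = L y (a sketch, not pardax's stepper)."""
    eye = jnp.eye(L.shape[0])
    lhs = eye - theta * h * L                # implicit system matrix (I - theta h L)
    rhs = (eye + (1.0 - theta) * h * L) @ y  # explicit part, known at the start of the step
    return jnp.linalg.solve(lhs, rhs)


L = jnp.array([[-1.0]])
y_be = theta_step(L, jnp.array([1.0]), 0.1, theta=1.0)  # backward Euler: 1 / 1.1

y_cn = jnp.array([1.0])
for _ in range(10):  # Crank-Nicolson from t = 0 to t = 1
    y_cn = theta_step(L, y_cn, 0.1, theta=0.5)
```

With theta=1.0 the explicit part vanishes and the step reduces to backward Euler. With theta=0.5 the implicit and explicit halves are balanced.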
Linearisers
Linearisers construct the Newton system inside NewtonRaphson. They bundle a linearisation strategy (autodiff, user-provided JVP, or dense Jacobian) with a linear solver.
pardax.AutoJVP
AutoJVP(linsolver: AbstractLinearSolver = GMRES())
pardax.JVP
JVP(
jvp_fn: Callable,
linsolver: AbstractLinearSolver = GMRES(),
)
pardax.Jacobian
Jacobian(
jac_fn: Callable,
linsolver: AbstractLinearSolver = DirectDense(),
)
pardax.AbstractLineariser
Bases: Module
Base class for locally linearising a system of equations.
__call__ (abstractmethod)
__call__(
fun: Callable[..., Float[Array, "*state"]],
t: Float[Array, ""],
step_size: Float[Array, ""],
args: tuple,
) -> Callable[
[Float[Array, "*state"], Float[Array, "*state"]],
Float[Array, "*state"],
]
Return a solve function with signature (residual_at_y, y) -> delta_y.
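Below is a sketch of what a matrix-free lineariser in the spirit of AutoJVP might return. auto_jvp_sketch is a hypothetical name. The sketch makes two assumptions. First, the Newton matrix is I - step_size * J_f, with J_f = d fun / dy evaluated at y; this corresponds to theta = 1, since the signature carries no theta. Second, delta_y is the correction subtracted from y, matching the NewtonRaphson update. The JVP comes from jax.jvp, and the solve from jax.scipy.sparse.linalg.gmres, in line with the documented GMRES() default.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres


def auto_jvp_sketch(fun, t, step_size, args):
    """Hypothetical matrix-free lineariser returning solve(residual_at_y, y) -> delta_y."""

    def solve(residual_at_y, y):
        def matvec(v):
            # (I - h J) v, where J v is d fun / dy at y applied to v (forward-mode AD)
            jv = jax.jvp(lambda z: fun(t, z, *args), (y,), (v,))[1]
            return v - step_size * jv

        delta_y, _ = gmres(matvec, residual_at_y)
        return delta_y

    return solve


# Linear right-hand side fun(t, y, A) = A @ y, so J = A everywhere.
A = jnp.array([[-1.0, 0.5], [0.0, -2.0]])
solve = auto_jvp_sketch(lambda t, y, A: A @ y, 0.0, 0.1, (A,))
r = jnp.array([1.0, -1.0])
delta = solve(r, jnp.zeros(2))
```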
Linear solvers
Solve the linear system \(Ax = b\) that arises during root finding. Iterative solvers accept both dense matrices and matrix-free operators.
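The sketch below contrasts dense and matrix-free inputs. It calls the JAX routines that the built-in solvers are documented to dispatch to (jax.numpy.linalg.solve, jax.scipy.sparse.linalg.gmres, cg and bicgstab) directly, so it exercises JAX itself rather than the pardax wrappers. The small system is symmetric positive-definite so that CG applies as well.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import bicgstab, cg, gmres

# Small symmetric positive-definite system, so CG is applicable as well.
A = jnp.array([[4.0, -1.0, 0.0], [-1.0, 4.0, -1.0], [0.0, -1.0, 4.0]])
b = jnp.array([1.0, 2.0, 3.0])


def matvec(v):
    # Matrix-free form: only the action v -> A @ v is exposed.
    return A @ v


x_direct = jnp.linalg.solve(A, b)           # dense direct solve
x_gmres, _ = gmres(A, b, tol=1e-6)          # iterative, dense matrix input
x_cg, _ = cg(matvec, b, tol=1e-6)           # iterative, matrix-free input
x_bicg, _ = bicgstab(matvec, b, tol=1e-6)   # iterative, matrix-free input
```

DirectDense accepts only the dense form and raises TypeError if given the matvec callable.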
pardax.DirectDense
Bases: AbstractLinearSolver
Direct solver for dense linear systems.
Dispatches to jax.numpy.linalg.solve.
Only suitable for small systems where the Jacobian is provided explicitly.
__call__
__call__(
A: Union[
Float[Array, "n n"],
Callable[[Float[Array, " n"]], Float[Array, " n"]],
],
b: Float[Array, " n"],
x0: Optional[Float[Array, " n"]] = None,
) -> Float[Array, " n"]
Solve A*x = b.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| A | Union[Float[Array, 'n n'], Callable[[Float[Array, ' n']], Float[Array, ' n']]] | Dense matrix | required |
| b | Float[Array, ' n'] | Right-hand side vector | required |
| x0 | Optional[Float[Array, ' n']] | Ignored (kept for interface compatibility). Can be None. | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' n'] | Solution x |
Raises:
| Type | Description |
|---|---|
| TypeError | If A is a callable (linear operator) instead of a matrix |
pardax.GMRES
GMRES(tol: float = 1e-06, maxiter: int = 100)
Bases: AbstractLinearSolver
Generalised Minimal Residual (GMRES).
Dispatches to jax.scipy.sparse.linalg.gmres.
Suitable for general non-symmetric systems.
__call__
__call__(
A: Union[
Float[Array, "n n"],
Callable[[Float[Array, " n"]], Float[Array, " n"]],
],
b: Float[Array, " n"],
x0: Optional[Float[Array, " n"]] = None,
) -> Float[Array, " n"]
Solve A*x = b using GMRES.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| A | Union[Float[Array, 'n n'], Callable[[Float[Array, ' n']], Float[Array, ' n']]] | Dense matrix or linear operator with signature x -> A*x | required |
| b | Float[Array, ' n'] | Right-hand side vector | required |
| x0 | Optional[Float[Array, ' n']] | Initial guess vector | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' n'] | Approximate solution x |
pardax.CG
CG(tol: float = 1e-06, maxiter: int = 100)
Bases: AbstractLinearSolver
Conjugate Gradient (CG).
Dispatches to jax.scipy.sparse.linalg.cg.
Only suitable for symmetric and positive-definite systems.
Attributes:
| Name | Type | Description |
|---|---|---|
| tol | float | Convergence tolerance for the residual norm |
| maxiter | int | Maximum number of iterations |
__call__
__call__(
A: Union[
Float[Array, "n n"],
Callable[[Float[Array, " n"]], Float[Array, " n"]],
],
b: Float[Array, " n"],
x0: Optional[Float[Array, " n"]] = None,
) -> Float[Array, " n"]
Solve A*x = b.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| A | Union[Float[Array, 'n n'], Callable[[Float[Array, ' n']], Float[Array, ' n']]] | Dense matrix or linear operator with signature x -> A*x | required |
| b | Float[Array, ' n'] | Right-hand side vector | required |
| x0 | Optional[Float[Array, ' n']] | Initial guess vector | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' n'] | Approximate solution x |
pardax.BiCGStab
BiCGStab(tol: float = 1e-06, maxiter: int = 100)
Bases: AbstractLinearSolver
Stabilised Biconjugate Gradient (BiCGStab).
Dispatches to jax.scipy.sparse.linalg.bicgstab.
Suitable for non-symmetric systems.
Attributes:
| Name | Type | Description |
|---|---|---|
| tol | float | Convergence tolerance for the residual norm |
| maxiter | int | Maximum number of iterations |
__call__
__call__(
A: Union[
Float[Array, "n n"],
Callable[[Float[Array, " n"]], Float[Array, " n"]],
],
b: Float[Array, " n"],
x0: Optional[Float[Array, " n"]] = None,
) -> Float[Array, " n"]
Solve A*x = b.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| A | Union[Float[Array, 'n n'], Callable[[Float[Array, ' n']], Float[Array, ' n']]] | Dense matrix or linear operator with signature x -> A*x | required |
| b | Float[Array, ' n'] | Right-hand side vector | required |
| x0 | Optional[Float[Array, ' n']] | Initial guess vector | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' n'] | Approximate solution x |
pardax.SpectralSolver
SpectralSolver(
forward: Callable[[Array], Array],
backward: Callable[[Array], Array],
constraint: Callable[[Array], Array] = _identity,
)
Bases: AbstractLinearSolver
Spectral linear solver.
Solves a diagonalised system by transforming to the spectral basis, performing a pointwise division by the symbol, and transforming back.
Pair with SpectralOperator, which passes the precomputed symbol
array 1 - h * eigvals as A.
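Below is a sketch of the transform, divide pointwise, transform back pattern for a hypothetical 1D problem. Here L is the second-difference Laplacian on n interior points with homogeneous Dirichlet boundaries, which the DST-I diagonalises. To keep the example self-contained, dense sine-matrix products stand in for forward and backward. The pardax.transform functions would be the natural choice, but their normalisation is not documented on this page. The result is checked against a dense solve of (I - h L) x = b.

```python
import jax.numpy as jnp

n = 7                      # interior grid points (hypothetical)
dx = 1.0 / (n + 1)
j = jnp.arange(1, n + 1)

# Dense DST-I matrix S[k, i] = sin(pi * k * i / (n + 1)); it satisfies S @ S = (n + 1) / 2 * I.
S = jnp.sin(jnp.pi * jnp.outer(j, j) / (n + 1))


def forward(x):
    return S @ x


def backward(X):
    return (2.0 / (n + 1)) * (S @ X)


# Eigenvalues of the Dirichlet second-difference Laplacian (u[i-1] - 2 u[i] + u[i+1]) / dx**2.
eigvals = -4.0 / dx**2 * jnp.sin(jnp.pi * j / (2 * (n + 1))) ** 2
h = 1e-3
symbol = 1.0 - h * eigvals  # the documented SpectralOperator symbol 1 - h * eigvals

b = jnp.linspace(0.0, 1.0, n)
x = backward(forward(b) / symbol)  # transform, divide pointwise, transform back
```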
Attributes:
| Name | Type | Description |
|---|---|---|
| forward | Callable[[Array], Array] | Forward transformation to diagonal basis |
| backward | Callable[[Array], Array] | Inverse transformation from diagonal basis |
| constraint | Callable[[Array], Array] | Pre-processing to enforce compatibility (e.g. mean removal) |
__call__
__call__(
A: Union[
Float[Array, " n"],
Callable[[Float[Array, " n"]], Float[Array, " n"]],
],
b: Float[Array, " n"],
x0: Optional[Float[Array, " n"]] = None,
) -> Float[Array, " n"]
Solve the system given the spectral symbol.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| A | Union[Float[Array, ' n'], Callable[[Float[Array, ' n']], Float[Array, ' n']]] | Symbol array | required |
| b | Float[Array, ' n'] | Right-hand side vector | required |
| x0 | Optional[Float[Array, ' n']] | Ignored | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' n'] | Solution x |
pardax.AbstractLinearSolver
Bases: Module
Base class for linear solvers.
__call__ (abstractmethod)
__call__(
A: Union[
Float[Array, ...],
Callable[[Float[Array, " n"]], Float[Array, " n"]],
],
b: Float[Array, " n"],
x0: Optional[Float[Array, " n"]] = None,
) -> Float[Array, " n"]
Solve the linear system A*x = b.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| A | Union[Float[Array, ...], Callable[[Float[Array, ' n']], Float[Array, ' n']]] | Dense matrix or linear operator with signature x -> A*x | required |
| b | Float[Array, ' n'] | Right-hand side vector | required |
| x0 | Optional[Float[Array, ' n']] | Initial guess vector | None |
Returns:
| Type | Description |
|---|---|
| Float[Array, ' n'] | Solution vector x such that A*x ≈ b |
Linear operators
Linear operators build the implicit system \((I - h L)\) for use with LinearRootFinder. Each operator returns the system in the form expected by its paired linear solver.
pardax.DenseOperator
DenseOperator(op_fn: Callable)
Bases: AbstractLinearOperator
Linear operator that returns a dense system matrix (I - h * L).
Pair with DirectDense.
Attributes:
| Name | Type | Description |
|---|---|---|
| op_fn | Callable | Callable (t, *args) -> L, returning the operator as a dense matrix. |
pardax.MatrixFreeOperator
MatrixFreeOperator(op_fn: Callable)
Bases: AbstractLinearOperator
Linear operator that returns a matrix-free matvec for (I - h * L).
Pair with an iterative solver (GMRES, CG, BiCGStab).
Attributes:
| Name | Type | Description |
|---|---|---|
| op_fn | Callable | Callable (t, *args) -> Lv, where Lv is a matvec v -> L @ v. |
pardax.SpectralOperator
SpectralOperator(eigvals: Float[Array, ' n'])
Bases: AbstractLinearOperator
Linear operator diagonalised by a known spectral transform.
Builds and returns the spectral symbol 1 - h * eigvals.
Pair with SpectralSolver.
Attributes:
| Name | Type | Description |
|---|---|---|
| eigvals | Float[Array, ' n'] | Eigenvalues of the linear operator. |
pardax.AbstractLinearOperator
Bases: Module
Wraps a linear operator and builds the implicit system (I - h * L).
Subclasses return the system in the form expected by their paired
AbstractLinearSolver: a dense matrix, a matvec callable, or
a spectral symbol array.
system (abstractmethod)
system(
t: Float[Array, ""], h: Float[Array, ""], args: tuple
) -> Union[
Float[Array, ...],
Callable[[Float[Array, " n"]], Float[Array, " n"]],
]
Build and return (I - h * L) for the current time step.
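Below is a sketch of how the dense and matrix-free variants might implement system(t, h, args), following the documented op_fn conventions ((t, *args) -> L and (t, *args) -> Lv). DenseSystemSketch and MatrixFreeSystemSketch are hypothetical stand-ins, not the pardax classes. Both forms should apply the same operator (I - h L) to a vector.

```python
import jax.numpy as jnp


class DenseSystemSketch:
    """Hypothetical analogue of DenseOperator: system() returns the matrix (I - h L)."""

    def __init__(self, op_fn):
        self.op_fn = op_fn  # (t, *args) -> L as a dense matrix

    def system(self, t, h, args):
        L = self.op_fn(t, *args)
        return jnp.eye(L.shape[0]) - h * L


class MatrixFreeSystemSketch:
    """Hypothetical analogue of MatrixFreeOperator: system() returns v -> (I - h L) v."""

    def __init__(self, op_fn):
        self.op_fn = op_fn  # (t, *args) -> matvec v -> L @ v

    def system(self, t, h, args):
        Lv = self.op_fn(t, *args)
        return lambda v: v - h * Lv(v)


K = jnp.array([[-2.0, 1.0], [1.0, -2.0]])
dense = DenseSystemSketch(lambda t, K: K).system(0.0, 0.1, (K,))
matvec = MatrixFreeSystemSketch(lambda t, K: (lambda v: K @ v)).system(0.0, 0.1, (K,))
v = jnp.array([1.0, 2.0])
# Both forms apply the same operator: (I - 0.1 K) v = [1.0, 2.3]
```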
Transforms
Discrete sine transforms for use with Dirichlet boundary conditions.
pardax.transform.dst1
dst1(x: Float[Array, ' n']) -> Float[Array, ' n']
Discrete sine transform (type 1).
pardax.transform.idst1
idst1(X: Float[Array, ' n']) -> Float[Array, ' n']
Inverse discrete sine transform (type 1).
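This page does not state the transforms' normalisation. Assuming scipy's unnormalised DST-I convention, X_k = 2 * sum_n x_n sin(pi (k + 1)(n + 1) / (N + 1)), both transforms can be sketched with an FFT of the odd extension of the input. dst1_sketch and idst1_sketch are hypothetical names, not the pardax functions. DST-I is its own inverse up to a factor 2(N + 1), so the inverse reuses the forward transform.

```python
import jax.numpy as jnp


def dst1_sketch(x):
    """DST-I via FFT of the odd extension (scipy's unnormalised convention assumed)."""
    n = x.shape[-1]
    zero = jnp.zeros(x.shape[:-1] + (1,), x.dtype)
    # Odd extension of length 2 (n + 1): [0, x, 0, -reversed(x)]
    ext = jnp.concatenate([zero, x, zero, -x[..., ::-1]], axis=-1)
    return -jnp.imag(jnp.fft.fft(ext, axis=-1))[..., 1 : n + 1]


def idst1_sketch(X):
    # Applying DST-I twice scales by 2 (n + 1), so divide that factor out.
    n = X.shape[-1]
    return dst1_sketch(X) / (2 * (n + 1))


x = jnp.array([1.0, -2.0, 0.5, 3.0, 0.0])
X = dst1_sketch(x)
x_back = idst1_sketch(X)  # recovers x up to rounding
```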