Integration (jax_fno.integrate)
jax_fno.integrate provides time-integration methods for (spatially discretised) PDEs written as systems of the form
$$ \frac{\partial y}{\partial t} = f(t, y). $$
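In this form, y is the state of the PDE after spatial discretisation (the method of lines). The sketch below is a minimal example that assumes a periodic 1D heat equation u_t = nu * u_xx and a second-order central-difference stencil. It shows how such a right-hand side fits the (t, y, *args) signature that solve_ivp expects. heat_rhs and its arguments are hypothetical and not part of jax_fno.

```python
import jax.numpy as jnp


def heat_rhs(t, u, nu, dx):
    """Hypothetical semi-discrete heat equation du/dt = nu * u_xx on a periodic grid."""
    # Second-order central difference; jnp.roll wraps around, giving periodic boundaries.
    u_xx = (jnp.roll(u, -1) - 2.0 * u + jnp.roll(u, 1)) / dx**2
    return nu * u_xx


n = 64
dx = 2.0 * jnp.pi / n
x = jnp.arange(n) * dx
u0 = jnp.sin(x)

# For u = sin(x) the exact u_xx is -sin(x); the stencil error is O(dx^2).
dudt = heat_rhs(0.0, u0, 1.0, dx)
```

A function like this would be passed as fun with args=(nu, dx). For an explicit stepper, step_size must be small enough to respect that scheme's stability limit.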
jax_fno.integrate.solve_ivp(fun, t_span, y0, method, step_size, args=())
Integrate dy/dt = fun(t, y, *args) over the time interval t_span.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `fun` | `Callable` | Callable right-hand side of system `dy/dt = fun(t, y, *args)` | required |
| `t_span` | `Tuple[float, float]` | `(t_start, t_end)` time interval | required |
| `y0` | `Array` | Initial condition | required |
| `method` | `StepperProtocol` | Time-stepping method instance (e.g., `RK4()`, `BackwardEuler()`) | required |
| `step_size` | `float` | Time step size | required |
| `args` | `tuple` | Additional arguments to pass to `fun` (and `jvp`/`jac` if provided) | `()` |

Returns:

| Name | Type | Description |
|---|---|---|
| `t_final` | `float` | Final time |
| `y_final` | `Array` | Solution at `t_end` |

Example usage:
```python
import jax.numpy as jnp
from jax_fno.integrate import solve_ivp, RK4

# Define ODE: dy/dt = -k*y
def fun(t, y, k):
    return -k * y

# Solve with args
y0 = jnp.array([1.0])
t_span = (0.0, 2.0)
k = 0.5
t, y = solve_ivp(fun, t_span, y0, RK4(), step_size=0.01, args=(k,))
```
Example usage with an implicit method and user-defined parameters:
```python
import jax.numpy as jnp
from jax_fno.integrate import solve_ivp, NewtonRaphson, GMRES, BackwardEuler

# Define ODE: dy/dt = -k*y
def fun(t, y, k):
    return -k * y

# Create linear solver
linsolver = GMRES(tol=1e-6, maxiter=20)
# Create non-linear solver
root_finder = NewtonRaphson(linsolver=linsolver, tol=1e-5, maxiter=50)
# Choose integration method
method = BackwardEuler(root_finder=root_finder)

# Solve
y0 = jnp.array([1.0])
t_span = (0.0, 2.0)
k = 0.5
t, y = solve_ivp(fun, t_span, y0, method, step_size=0.01, args=(k,))
```
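This reference does not show how solve_ivp is implemented. The sketch below is a hypothetical fixed-step driver, not the library's code. It uses jax.lax.fori_loop, so the whole integration is a single traced computation that can be wrapped in jax.jit or differentiated with jax.grad. The names rk4_step, fixed_step_integrate, decay and final_value are illustrative. The step count is assumed to be round((t_end - t_start) / step_size).

```python
import jax
import jax.numpy as jnp


def rk4_step(fun, t, y, h, args=()):
    # Classical RK4 stages (an assumed tableau, used here for illustration).
    k1 = fun(t, y, *args)
    k2 = fun(t + h / 2, y + h / 2 * k1, *args)
    k3 = fun(t + h / 2, y + h / 2 * k2, *args)
    k4 = fun(t + h, y + h * k3, *args)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def fixed_step_integrate(fun, t_span, y0, step_size, args=()):
    """Hypothetical fixed-step driver; n_steps is a static Python int."""
    t0, t1 = t_span
    n_steps = int(round((t1 - t0) / step_size))
    h = jnp.asarray(step_size)

    def body(i, y):
        return rk4_step(fun, t0 + i * h, y, h, args)

    # A fori_loop with static bounds lowers to a scan, which is reverse-differentiable.
    y_final = jax.lax.fori_loop(0, n_steps, body, y0)
    return t0 + n_steps * h, y_final


def decay(t, y, k):
    return -k * y


def final_value(k):
    _, y = fixed_step_integrate(decay, (0.0, 2.0), jnp.array([1.0]), 0.01, args=(k,))
    return y[0]


y_end = jax.jit(final_value)(0.5)   # approximately exp(-1)
dy_dk = jax.grad(final_value)(0.5)  # approximately -2 * exp(-1), since y(T) = exp(-k T)
```

Here jax.grad differentiates the final state with respect to the rate k passed through args. For this linear problem the exact derivative is -T exp(-kT), which equals -2 exp(-1) at k = 0.5 and T = 2.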
Source code in src/jax_fno/integrate/solve.py
jax_fno.integrate.solve_with_history(fun, t_span, y0, method, step_size, t_eval=None, args=(), verbose=False)
Integrate dy/dt = fun(t, y, *args) over the time interval t_span.
Unlike solve_ivp, this function can return intermediate states at the times in t_eval, but it is not compatible with JAX transformations. Integration proceeds in chunks, and each chunk is a call to a JIT-compiled solve_ivp.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `fun` | `Callable` | Right-hand side function with signature `(t, y, *args) -> dydt` | required |
| `t_span` | `Tuple[float, float]` | `(t_start, t_end)` time interval | required |
| `y0` | `Array` | Initial condition | required |
| `method` | `StepperProtocol` | Time-stepping method instance (e.g., `RK4()`, `BackwardEuler()`) | required |
| `step_size` | `float` | Time step size for integration. | required |
| `t_eval` | `Optional[Array]` | Times at which to store the computed solution. If `None`, returns only the initial and final states. Must be sorted and lie within `t_span`. | `None` |
| `args` | `tuple` | Additional arguments to pass to `fun` (and `jvp`/`jac` if provided) | `()` |
| `verbose` | `bool` | Print progress information | `False` |

Returns:

| Name | Type | Description |
|---|---|---|
| `t` | `Array` | Array of time points, shape `(n_points,)` |
| `y` | `Array` | Array of solution values at times `t`, shape `(n_points, *y0.shape)` |

Example usage:
```python
import jax.numpy as jnp
from jax_fno.integrate import solve_with_history, RK4

# Define ODE: dy/dt = -k*y
def fun(t, y, k):
    return -k * y

# Solve with args
y0 = jnp.array([1.0])
t_span = (0.0, 2.0)
t_eval = jnp.linspace(0, 2, 5)
k = 0.5
t, y = solve_with_history(
    fun, t_span, y0, RK4(), step_size=0.01, t_eval=t_eval, args=(k,)
)
```
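The sketch below is a minimal version of the chunked strategy described above. It assumes a Forward Euler stepper and one JIT-compiled fixed-step chunk per output interval. integrate_with_history and advance are hypothetical names, not the library's functions. The outer Python loop is what keeps a driver like this from being jitted or differentiated. Each chunk is still compiled, and it is recompiled only when n_steps changes, because n_steps is a static argument.

```python
from functools import partial

import jax
import jax.numpy as jnp


def euler_step(fun, t, y, h, args):
    return y + h * fun(t, y, *args)


@partial(jax.jit, static_argnums=(0, 4))
def advance(fun, t0, y0, h, n_steps, args):
    # JIT-compiled chunk: n_steps fixed-size steps starting from (t0, y0).
    def body(i, y):
        return euler_step(fun, t0 + i * h, y, h, args)
    return jax.lax.fori_loop(0, n_steps, body, y0)


def integrate_with_history(fun, t_span, y0, step_size, t_eval=None, args=()):
    """Hypothetical chunked driver; the Python loop is not itself jittable."""
    t0, t1 = float(t_span[0]), float(t_span[1])
    times = [t0, t1] if t_eval is None else [float(t) for t in t_eval]
    t_prev, y = t0, jnp.asarray(y0)
    ys = []
    for t_next in times:
        n_steps = int(round((t_next - t_prev) / step_size))
        if n_steps > 0:
            # Adjust h slightly if needed so the chunk ends exactly at t_next.
            y = advance(fun, t_prev, y, (t_next - t_prev) / n_steps, n_steps, args)
        ys.append(y)
        t_prev = t_next
    return jnp.asarray(times), jnp.stack(ys)


def fun(t, y, k):
    return -k * y


t, y = integrate_with_history(
    fun, (0.0, 2.0), jnp.array([1.0]), 0.01, t_eval=jnp.linspace(0, 2, 5), args=(0.5,)
)
```

The returned y has shape (n_points, *y0.shape), which matches the Returns table above.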
Source code in src/jax_fno/integrate/solve.py
Time-stepping schemes
jax_fno.integrate.ForwardEuler
Bases: Module
Forward Euler method.
Discretisation: $$ y_{n+1} = y_n + h\, f(t_n, y_n). $$
Source code in src/jax_fno/integrate/timesteppers/explicit.py
step(fun, t, y, h, args=())
Perform a single Forward Euler step.
Computes $$ y_{n+1} = y_n + h\, f(t_n, y_n), $$ with any extra args forwarded to fun.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `fun` | `Callable` | Right-hand side of system `dydt = f(t, y, *args)`. | required |
| `t` | `Array` | Current time. Type: 0-dimensional JAX array. | required |
| `y` | `Array` | Current solution. | required |
| `h` | `Array` | Time step size. Type: 0-dimensional JAX array. | required |
| `args` | `tuple` | Additional arguments to pass to `fun`. | `()` |

Returns:

| Type | Description |
|---|---|
| `Array` | Solution at `t + h`. |

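Forward Euler is first-order accurate, so halving h should roughly halve the global error. The self-contained check below assumes the test problem dy/dt = -y on [0, 1]. forward_euler_step mirrors the formula above but is not the library's implementation, and integrate is a hypothetical helper.

```python
import jax.numpy as jnp


def forward_euler_step(fun, t, y, h, args=()):
    # y_{n+1} = y_n + h f(t_n, y_n, *args)
    return y + h * fun(t, y, *args)


def decay(t, y, k):
    return -k * y


def integrate(step, fun, y0, t_end, n_steps, args=()):
    # Plain Python loop over equal steps from t = 0 (illustration only).
    h = t_end / n_steps
    y = y0
    for i in range(n_steps):
        y = step(fun, i * h, y, h, args)
    return y


y0 = jnp.array([1.0])
exact = jnp.exp(-1.0)  # y(1) for dy/dt = -y, y(0) = 1
err_coarse = jnp.abs(integrate(forward_euler_step, decay, y0, 1.0, 10, (1.0,)) - exact)[0]
err_fine = jnp.abs(integrate(forward_euler_step, decay, y0, 1.0, 20, (1.0,)) - exact)[0]
ratio = err_coarse / err_fine  # close to 2 for a first-order method
```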
Source code in src/jax_fno/integrate/timesteppers/explicit.py
jax_fno.integrate.RK4
Bases: Module
Fourth-order Runge-Kutta method.
Implements: StepperProtocol
Source code in src/jax_fno/integrate/timesteppers/explicit.py
step(fun, t, y, h, args=())
Perform a single RK4 step.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `fun` | `Callable` | Right-hand side of system `dy/dt = f(t, y, *args)`. | required |
| `t` | `Array` | Current time. | required |
| `y` | `Array` | Current solution. | required |
| `h` | `Array` | Time step size. | required |
| `args` | `tuple` | Additional arguments to pass to `fun`. | `()` |

Returns:

| Type | Description |
|---|---|
| `Array` | Solution at `t + h`. |

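This reference does not list RK4's coefficients. The sketch below assumes the classical RK4 tableau and the same (fun, t, y, h, args) calling convention. rk4_step is a hypothetical stand-alone function. The second call integrates dy/dt = cos(t). Because the right-hand side depends only on t, a single RK4 step reduces to Simpson's rule, which checks that t is advanced correctly in each stage.

```python
import jax.numpy as jnp


def rk4_step(fun, t, y, h, args=()):
    # Classical RK4: four slope evaluations combined with weights 1, 2, 2, 1.
    k1 = fun(t, y, *args)
    k2 = fun(t + h / 2, y + h / 2 * k1, *args)
    k3 = fun(t + h / 2, y + h / 2 * k2, *args)
    k4 = fun(t + h, y + h * k3, *args)
    return y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def decay(t, y, k):
    return -k * y


# One step of dy/dt = -y from y = 1 with h = 0.1 (exact value exp(-0.1))
y1 = rk4_step(decay, 0.0, jnp.array([1.0]), 0.1, (1.0,))
# One step of dy/dt = cos(t) from y = 0 (exact value sin(0.1))
y_sin = rk4_step(lambda t, y: jnp.cos(t), 0.0, jnp.array([0.0]), 0.1)
```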
Source code in src/jax_fno/integrate/timesteppers/explicit.py
jax_fno.integrate.BackwardEuler
Bases: Module
Backward Euler time stepper.
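Discretisation: $$ y_{n+1} = y_n + h\, f(t_{n+1}, y_{n+1}), $$ which is solved for \(y_{n+1}\) by finding the root of \(R(y_{n+1}) = y_{n+1} - y_n - h\, f(t_{n+1}, y_{n+1})\).
If neither jvp nor jac is provided, the root finder obtains Jacobian-vector products by automatic differentiation with jax.jvp.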
Source code in src/jax_fno/integrate/timesteppers/implicit.py
step(fun, t, y, h, args=())
Advance one backward Euler step.
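Implicit steps pay off on stiff problems. For the linear test equation dy/dt = lam * y the backward Euler residual is linear, so each step has the closed form y_{n+1} = y_n / (1 - h*lam). The sketch below compares that with forward Euler's y_{n+1} = (1 + h*lam) y_n, using the illustrative values lam = -1000 and h = 0.01. For a nonlinear f, BackwardEuler instead solves the residual with its root finder (e.g. NewtonRaphson).

```python
import jax.numpy as jnp

lam = -1000.0   # stiff decay rate in dy/dt = lam * y (illustrative value)
h = 0.01        # h * |lam| = 10
n_steps = 10

y_fe = jnp.array(1.0)
y_be = jnp.array(1.0)
for _ in range(n_steps):
    # Forward Euler: y_{n+1} = (1 + h*lam) y_n, amplification factor -9 here
    y_fe = (1.0 + h * lam) * y_fe
    # Backward Euler: y_{n+1} = y_n + h*lam*y_{n+1}  =>  y_{n+1} = y_n / (1 - h*lam)
    y_be = y_be / (1.0 - h * lam)
```

The exact solution decays like exp(-1000 t). Backward Euler's iterates decay by a factor of 11 per step here. Forward Euler's iterates grow by a factor of 9 in magnitude per step.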
Source code in src/jax_fno/integrate/timesteppers/implicit.py
jax_fno.integrate.IMEX
Bases: Module
Implicit-Explicit (IMEX) time-stepping scheme.
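Splits the ODE into a stiff part, treated implicitly, and a non-stiff part, treated explicitly:
$$ \frac{dy}{dt} = f_{\text{explicit}}(t, y) + f_{\text{implicit}}(t, y). $$
The scheme advances the solution in two stages. The explicit stage is written here with Forward Euler; with another explicit stepper such as RK4, it is one step of that scheme applied to \(f_{\text{explicit}}\).

- Explicit stage: \(u^* = u^n + h\, f_{\text{explicit}}(t^n, u^n)\)
- Implicit stage: \(u^{n+1} = u^* + h\, f_{\text{implicit}}(t^{n+1}, u^{n+1})\)

The implicit stage is solved by root-finding on the residual
$$ R(u^{n+1}) = u^{n+1} - u^* - h\, f_{\text{implicit}}(t^{n+1}, u^{n+1}) = 0. $$
This formulation allows you to:

- use high-order explicit methods (RK4, etc.) for the non-stiff terms;
- use implicit methods (BackwardEuler, etc.) for the stiff terms;
- avoid the overly restrictive time-step constraints that stiff terms impose on fully explicit schemes.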
Example
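```python
import jax.numpy as jnp
from jax_fno.integrate import solve_ivp, IMEX, RK4, BackwardEuler

# Illustrative (hypothetical) terms; both receive the same extra args (assumed)
# Non-stiff term, advanced by the explicit stepper
def explicit_term(t, u, k):
    return jnp.sin(t) * jnp.ones_like(u)

# Stiff term, advanced by the implicit stepper
def implicit_term(t, u, k):
    return -k * u

# Instantiate solver
solver = IMEX(implicit=BackwardEuler(), explicit=RK4())
# Define ODE as a dict
ode = {'implicit': implicit_term, 'explicit': explicit_term}
# Solve
y0 = jnp.array([1.0])
t_span = (0.0, 2.0)
step_size = 0.01
args = (100.0,)  # stiff decay rate k
t, y = solve_ivp(ode, t_span, y0, solver, step_size, args)
```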
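The sketch below is a minimal single IMEX step with a Forward Euler explicit stage and a backward Euler implicit stage. It assumes the stiff term is linear, f_implicit(t, u) = A u, so the residual R(u^{n+1}) = 0 reduces to the linear system (I - hA) u^{n+1} = u*. imex_euler_step and the test values are hypothetical. With h = 0.01, the stiff rate of 1000 gives h * 1000 = 10, well beyond Forward Euler's stability limit. The iteration still settles at the steady state u = -A^{-1} 1 = (0.001, 1).

```python
import jax
import jax.numpy as jnp


def imex_euler_step(f_explicit, A, t, u, h):
    # Explicit stage (Forward Euler on the non-stiff term): u* = u^n + h f_explicit(t^n, u^n)
    u_star = u + h * f_explicit(t, u)
    # Implicit stage for a linear stiff term f_implicit(t, u) = A u:
    # R(u) = u - u* - h A u = 0  <=>  (I - h A) u = u*
    identity = jnp.eye(u.shape[0], dtype=u.dtype)
    return jnp.linalg.solve(identity - h * A, u_star)


A = jnp.diag(jnp.array([-1000.0, -1.0]))   # one stiff and one non-stiff decay rate
forcing = lambda t, u: jnp.ones_like(u)      # non-stiff explicit term
h = 0.01
u0 = jnp.array([1.0, 1.0])

u1 = imex_euler_step(forcing, A, 0.0, u0, h)
u_final = jax.lax.fori_loop(
    0, 2000, lambda i, u: imex_euler_step(forcing, A, i * h, u, h), u0
)
```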
Source code in src/jax_fno/integrate/timesteppers/imex.py
step(fun, t, y, h, args=())
Advance one IMEX step.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `fun` | `Union[Callable, Dict[str, Callable]]` | Either a dict with keys `'implicit'` and `'explicit'`, or a callable. If a dict, `fun['explicit']` gives the non-stiff term and `fun['implicit']` gives the stiff term. If a callable, it is treated as the implicit term with a zero explicit term. | required |
| `t` | `Array` | Current time. | required |
| `y` | `Array` | Current solution. | required |
| `h` | `Array` | Time step size. | required |
| `args` | `tuple` | Additional arguments to pass to `fun`. | `()` |

Returns:

| Type | Description |
|---|---|
| `Array` | Solution at `t + h`. |

Source code in src/jax_fno/integrate/timesteppers/imex.py
jax_fno.integrate.StepperProtocol
Bases: Protocol
Protocol for time-stepping schemes.
Defines the interface for advancing an ODE one time step. Any class implementing a step() method with this signature can be used as a time-stepping method in solve_ivp.
Source code in src/jax_fno/integrate/timesteppers/protocol.py
step(fun, t, y, h, args=())
Take a single time step.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `fun` | `Callable` | Right-hand side function. | required |
| `t` | `Array` | Current time. Type: 0-dimensional JAX array. | required |
| `y` | `Array` | Current solution. | required |
| `h` | `Array` | Time step size. Type: 0-dimensional JAX array. | required |
| `args` | `tuple` | Additional arguments to pass to `fun`. | `()` |

Returns:

| Type | Description |
|---|---|
| `Array` | Solution at `t + h`. |

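StepperProtocol only prescribes the step() signature, so a new scheme can be written as a class with that method. The sketch below implements Heun's method (the explicit trapezoidal rule), which is not part of jax_fno. The built-in steppers subclass Module. It is an assumption, not confirmed by this reference, that a plain Python class like this one works inside solve_ivp's compiled loop.

```python
import jax.numpy as jnp


class Heun:
    """Heun's method (explicit trapezoidal rule), written against the StepperProtocol signature."""

    def step(self, fun, t, y, h, args=()):
        k1 = fun(t, y, *args)
        y_pred = y + h * k1               # Forward Euler predictor
        k2 = fun(t + h, y_pred, *args)    # slope at the predicted end point
        return y + 0.5 * h * (k1 + k2)    # average the two slopes


def decay(t, y, k):
    return -k * y


y1 = Heun().step(decay, jnp.array(0.0), jnp.array([1.0]), jnp.array(0.1), (1.0,))
# One step for dy/dt = -y from y = 1 with h = 0.1: 1 - 0.05 * (1 + 0.9) = 0.905
```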
Source code in src/jax_fno/integrate/timesteppers/protocol.py
Root-finding algorithms
jax_fno.integrate.NewtonRaphson
Bases: Module
Newton-Raphson root-finding algorithm.
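Iterative update: \(y \leftarrow y - J(y)^{-1} R(y)\), where \(J = \partial R / \partial y\) is the Jacobian of the residual. The correction \(J(y)^{-1} R(y)\) is computed by solving a linear system with linsolver, not by forming the inverse.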
Implements: RootFinderProtocol
Attributes:

| Name | Description |
|---|---|
| `tol` | Convergence tolerance for residual norm |
| `maxiter` | Maximum number of Newton-Raphson iterations |
| `linsolver` | Linear solver for inner iterations (default: `GMRES`) |

Source code in src/jax_fno/integrate/rootfinders/newtonraphson.py
__call__(residual_fn, y_guess, jvp_fn=None, jac_fn=None)
Find the root of residual_fn(y) = 0 using Newton-Raphson method.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `residual_fn` | `LinearMap` | Residual function `R(y)` | required |
| `y_guess` | `Array` | Initial guess | required |
| `jvp_fn` | `Optional[JVPConstructor]` | Matrix-free Jacobian-vector product with signature `(y, v) -> J(y)*v` | `None` |
| `jac_fn` | `Optional[JacobianConstructor]` | Function returning a dense Jacobian matrix with signature `y -> J(y)` | `None` |

Returns:

| Type | Description |
|---|---|
| `Array` | Solution `y` |

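The sketch below is a minimal Newton-Raphson iteration with matrix-free Jacobian-vector products. When no jvp_fn is given it uses jax.jvp, mirroring the default described for BackwardEuler. Each correction J(y) δ = -R(y) is solved with jax.scipy.sparse.linalg.gmres, the routine the GMRES class dispatches to. newton_krylov is hypothetical. Its Python loop with an early exit cannot be traced by jax.jit; a compiled variant would use jax.lax.while_loop instead.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres


def newton_krylov(residual_fn, y_guess, jvp_fn=None, tol=1e-5, maxiter=50):
    y = y_guess
    for _ in range(maxiter):
        r = residual_fn(y)
        if jnp.linalg.norm(r) < tol:
            break
        if jvp_fn is None:
            # Matrix-free J(y) v via forward-mode autodiff of the residual.
            matvec = lambda v, y=y: jax.jvp(residual_fn, (y,), (v,))[1]
        else:
            # User-supplied Jacobian-vector product with signature (y, v) -> J(y) v.
            matvec = lambda v, y=y: jvp_fn(y, v)
        # Solve J(y) dy = -R(y); gmres returns (solution, info).
        dy, _ = gmres(matvec, -r)
        y = y + dy
    return y


target = jnp.array([4.0, 9.0])
residual = lambda y: y**2 - target          # roots at y = [2, 3]
root = newton_krylov(residual, jnp.array([1.0, 1.0]))
root_jvp = newton_krylov(residual, jnp.array([1.0, 1.0]), jvp_fn=lambda y, v: 2.0 * y * v)
```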
Source code in src/jax_fno/integrate/rootfinders/newtonraphson.py
jax_fno.integrate.RootFinderProtocol
Bases: Protocol
Protocol for root-finding algorithms.
Defines the interface for finding roots of nonlinear equations. Used by implicit time-stepping schemes to solve the nonlinear systems that arise from implicit discretization.
Source code in src/jax_fno/integrate/rootfinders/protocol.py
__call__(residual_fn, y_guess, jvp_fn=None, jac_fn=None)
Find the root of residual_fn(y) = 0.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `residual_fn` | `LinearMap` | Function mapping `y -> R(y)`, where we seek `R(y) = 0` | required |
| `y_guess` | `Array` | Initial guess for the solution | required |
| `jvp_fn` | `Optional[JVPConstructor]` | Optional Jacobian-vector product function `(y, v) -> J*v` | `None` |
| `jac_fn` | `Optional[JacobianConstructor]` | Optional dense Jacobian function `y -> J` | `None` |

Returns:

| Type | Description |
|---|---|
| `Array` | Solution `y` such that `residual_fn(y) ≈ 0` |

Source code in src/jax_fno/integrate/rootfinders/protocol.py
Linear solvers
jax_fno.integrate.GMRES
Bases: Module
Generalised Minimal Residual (GMRES).
Dispatches to jax.scipy.sparse.linalg.gmres.
Suitable for general non-symmetric systems.
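jax.scipy.sparse.linalg.gmres accepts A either as a dense 2D array or as a function computing x -> A*x, and it returns a (solution, info) pair. The sketch below calls that JAX routine directly, not through jax_fno, on a small non-symmetric test matrix chosen for illustration. It checks both forms of A against a dense direct solve.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import gmres

# Small non-symmetric test system (illustrative values)
M = jnp.array([[4.0, 1.0, 0.0],
               [2.0, 5.0, 1.0],
               [0.0, 1.0, 3.0]])
b = jnp.array([1.0, 2.0, 3.0])

# Dense matrix form
x_dense, _ = gmres(M, b)
# Matrix-free form: any function computing x -> A*x
x_free, _ = gmres(lambda v: M @ v, b)
# Reference solution from a dense direct solve
x_ref = jnp.linalg.solve(M, b)
```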
Source code in src/jax_fno/integrate/linsolvers/krylov.py
__call__(A, b, x0=None)
Solve A*x = b using GMRES.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `A` | `Union[LinearMap, Array]` | Dense matrix or linear operator with signature `x -> A*x` | required |
| `b` | `Array` | Right-hand side vector | required |
| `x0` | `Optional[Array]` | Initial guess vector | `None` |

Returns:

| Type | Description |
|---|---|
| `Array` | Approximate solution `x` |

Source code in src/jax_fno/integrate/linsolvers/krylov.py
jax_fno.integrate.CG
Bases: Module
Conjugate Gradients (CG).
Dispatches to jax.scipy.sparse.linalg.cg.
Only suitable for symmetric and positive-definite systems.
Implements: LinearSolverProtocol
Attributes:

| Name | Description |
|---|---|
| `tol` | Convergence tolerance for residual norm |
| `maxiter` | Maximum number of iterations |

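CG requires a symmetric positive-definite operator. The sketch below calls jax.scipy.sparse.linalg.cg, which CG dispatches to, directly. It passes a matrix-free operator I - L, where L is a 1D second-difference stencil with zero Dirichlet ends. The operator and right-hand side are illustrative assumptions. In dense form the operator is tridiagonal with 3 on the diagonal and -1 off it, which is SPD.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def shifted_laplacian(v):
    # Matrix-free (I - L) v, where L is the 1D second-difference stencil
    # with zero Dirichlet ends; the dense form is tridiag(-1, 3, -1), which is SPD.
    lap = -2.0 * v
    lap = lap.at[1:].add(v[:-1])
    lap = lap.at[:-1].add(v[1:])
    return v - lap


b = jnp.ones(8)
x, _ = cg(shifted_laplacian, b)

# Dense reference for comparison
dense = 3.0 * jnp.eye(8) - jnp.eye(8, k=1) - jnp.eye(8, k=-1)
x_ref = jnp.linalg.solve(dense, b)
```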
Source code in src/jax_fno/integrate/linsolvers/krylov.py
__call__(A, b, x0=None)
Solve A*x = b.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `A` | `Union[LinearMap, Array]` | Dense matrix or linear operator with signature `x -> A*x` | required |
| `b` | `Array` | Right-hand side vector | required |
| `x0` | `Optional[Array]` | Initial guess vector | `None` |

Returns:

| Type | Description |
|---|---|
| `Array` | Approximate solution `x` |

Source code in src/jax_fno/integrate/linsolvers/krylov.py
jax_fno.integrate.BiCGStab
Bases: Module
Stabilised Biconjugate Gradients (BiCGStab).
Dispatches to jax.scipy.sparse.linalg.bicgstab.
Suitable for non-symmetric systems.
Implements: LinearSolverProtocol
Attributes:

| Name | Description |
|---|---|
| `tol` | Convergence tolerance for residual norm |
| `maxiter` | Maximum number of iterations |

Source code in src/jax_fno/integrate/linsolvers/krylov.py
__call__(A, b, x0=None)
Solve A*x = b.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `A` | `Union[LinearMap, Array]` | Dense matrix or linear operator with signature `x -> A*x` | required |
| `b` | `Array` | Right-hand side vector | required |
| `x0` | `Optional[Array]` | Initial guess vector | `None` |

Returns:

| Type | Description |
|---|---|
| `Array` | Approximate solution `x` |

Source code in src/jax_fno/integrate/linsolvers/krylov.py
jax_fno.integrate.DirectDense
Bases: Module
Direct solver for dense linear systems.
Dispatches to jax.numpy.linalg.solve.
Only suitable for small systems where the Jacobian is provided explicitly.
Source code in src/jax_fno/integrate/linsolvers/direct.py
__call__(A, b, x0=None)
Solve A*x = b.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `A` | `Union[LinearMap, Array]` | Dense matrix | required |
| `b` | `Array` | Right-hand side vector | required |
| `x0` | `Optional[Array]` | Ignored (kept for interface compatibility). Can be `None`. | `None` |

Returns:

| Type | Description |
|---|---|
| `Array` | Solution `x` |

Raises:

| Type | Description |
|---|---|
| `TypeError` | If `A` is a callable (linear operator) instead of a matrix |

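The stand-in below is a minimal sketch that mirrors the documented behaviour of DirectDense.__call__. It solves with jax.numpy.linalg.solve, ignores x0, and raises TypeError when A is a callable. DenseDirectSketch is hypothetical and is not the library class.

```python
import jax.numpy as jnp


class DenseDirectSketch:
    """Hypothetical stand-in mirroring the documented DirectDense behaviour."""

    def __call__(self, A, b, x0=None):
        # x0 is accepted only so the signature matches the other solvers; it is ignored.
        if callable(A):
            raise TypeError("A must be a dense matrix, not a linear operator")
        return jnp.linalg.solve(A, b)


solver = DenseDirectSketch()
M = jnp.array([[2.0, 0.0], [0.0, 4.0]])
x = solver(M, jnp.array([2.0, 8.0]))   # solution [1.0, 2.0]

try:
    solver(lambda v: M @ v, jnp.array([2.0, 8.0]))
    raised = False
except TypeError:
    raised = True
```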
Source code in src/jax_fno/integrate/linsolvers/direct.py
jax_fno.integrate.Spectral
jax_fno.integrate.LinearSolverProtocol
Bases: Protocol
Protocol for linear solvers.
Defines the interface for solving linear systems of the form A*x = b. Any class implementing a `__call__()` method with this signature can be used as a linear solver in implicit integration schemes.
Source code in src/jax_fno/integrate/linsolvers/protocol.py
__call__(A, b, x0=None)
Solve the linear system A*x = b.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `A` | `Union[LinearMap, Array]` | Dense matrix or linear operator with signature `x -> A*x` | required |
| `b` | `Array` | Right-hand side vector | required |
| `x0` | `Optional[Array]` | Initial guess vector | `None` |

Returns:

| Type | Description |
|---|---|
| `Array` | Solution vector `x` such that `A*x ≈ b` |

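Any class with a matching __call__(A, b, x0=None) can act as a linear solver. For small systems given only as an operator, one option is to materialise the matrix by applying the operator to each basis vector and then solve densely. MaterialisedDenseSolver is a hypothetical example and is not part of jax_fno. It costs n operator applications plus an O(n^3) solve, so it is only sensible for small n.

```python
import jax
import jax.numpy as jnp


class MaterialisedDenseSolver:
    """Hypothetical linear solver with the LinearSolverProtocol call signature."""

    def __call__(self, A, b, x0=None):
        if callable(A):
            # vmap applies A to each row of the identity, i.e. to each basis vector e_i,
            # and stacks the results A e_i as rows; transposing puts them in columns.
            A = jax.vmap(A)(jnp.eye(b.shape[0], dtype=b.dtype)).T
        # x0 is unused: a direct solve needs no initial guess.
        return jnp.linalg.solve(A, b)


def op(v):
    # Linear operator for the non-symmetric matrix [[2, 1], [0, 3]]
    return jnp.stack([2.0 * v[0] + v[1], 3.0 * v[1]])


solver = MaterialisedDenseSolver()
b = jnp.array([3.0, 6.0])
x_op = solver(op, b)                                      # from the operator
x_mat = solver(jnp.array([[2.0, 1.0], [0.0, 3.0]]), b)   # from the dense matrix
```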
Source code in src/jax_fno/integrate/linsolvers/protocol.py