Integration (jax_fno.integrate)

jax_fno.integrate provides time integration methods for systems of the form $$ \frac{\partial y}{\partial t} = f(t, y), $$ such as ODEs and method-of-lines discretisations of PDEs.

jax_fno.integrate.solve_ivp(fun, t_span, y0, method, step_size, args=())

Integrate dy/dt = fun(t, y, *args) over the time interval t_span.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `fun` | `Callable` | Right-hand side of system dy/dt = fun(t, y, *args) | required |
| `t_span` | `Tuple[float, float]` | (t_start, t_end) time interval | required |
| `y0` | `Array` | Initial condition | required |
| `method` | `StepperProtocol` | Time-stepping method instance (e.g., RK4(), BackwardEuler()) | required |
| `step_size` | `float` | Time step size | required |
| `args` | `tuple` | Additional arguments to pass to fun (and jvp/jac, if provided) | `()` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `t_final` | `float` | Final time |
| `y_final` | `Array` | Solution at t_end |

Example usage:

import jax.numpy as jnp
from jax_fno.integrate import solve_ivp, RK4

# Define ODE: dy/dt = -k*y
def fun(t, y, k):
    return -k * y

# Solve with args
y0 = jnp.array([1.0])
t_span = (0.0, 2.0)
k = 0.5

t, y = solve_ivp(fun, t_span, y0, RK4(), step_size=0.01, args=(k,))

Example usage with an implicit method and user-defined parameters:

import jax.numpy as jnp
from jax_fno.integrate import solve_ivp, NewtonRaphson, GMRES, BackwardEuler

# Define ODE: dy/dt = -k*y
def fun(t, y, k):
    return -k * y

# Create linear solver
linsolver = GMRES(tol=1e-6, maxiter=20)

# Create non-linear solver
root_finder = NewtonRaphson(linsolver=linsolver, tol=1e-5, maxiter=50)

# Choose integration method
method = BackwardEuler(root_finder=root_finder)

# Solve
y0 = jnp.array([1.0])
t_span = (0.0, 2.0)
k = 0.5
t, y = solve_ivp(fun, t_span, y0, method, step_size=0.01, args=(k,))
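As a standalone sanity check (plain NumPy, not jax_fno itself), a fixed-step RK4 loop matching the documented update reproduces the exact solution y(t) = y0·exp(-k·t) of the decay ODE used above:

```python
import numpy as np

def rk4_step(f, t, y, h):
    # Classical fourth-order Runge-Kutta update
    k1 = f(t, y)
    k2 = f(t + 0.5 * h, y + 0.5 * h * k1)
    k3 = f(t + 0.5 * h, y + 0.5 * h * k2)
    k4 = f(t + h, y + h * k3)
    return y + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

k, h = 0.5, 0.01
f = lambda t, y: -k * y

y = np.array([1.0])
for n in range(200):                 # 200 steps of 0.01 cover [0, 2]
    y = rk4_step(f, n * h, y, h)

print(abs(y[0] - np.exp(-k * 2.0)))  # tiny global error for RK4
```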

Source code in src/jax_fno/integrate/solve.py
def solve_ivp(
    fun: Callable,
    t_span: Tuple[float, float],
    y0: Array,
    method: StepperProtocol,
    step_size: float,
    args: tuple = ()
) -> Tuple[float, Array]:
    """
    Integrate dy/dt = fun(t, y, *args) over the time interval t_span.

    Args:
        fun: Callable right-hand side of system dy/dt = fun(t, y, *args)
        t_span: (t_start, t_end) time interval
        y0: Initial condition
        method: Time-stepping method instance (e.g., RK4(), BackwardEuler())
        step_size: Time step size
        args: Additional arguments to pass to fun (and jvp/jac if provided)

    Returns:
        t_final: Final time
        y_final: Solution at t_end

    Example usage:
    ```python
    import jax.numpy as jnp
    from jax_fno.integrate import solve_ivp, RK4

    # Define ODE: dy/dt = -k*y
    def fun(t, y, k):
        return -k * y

    # Solve with args
    y0 = jnp.array([1.0])
    t_span = (0.0, 2.0)
    k = 0.5

    t, y = solve_ivp(fun, t_span, y0, RK4(), step_size=0.01, args=(k,))
    ```

    Example usage with an implicit method and user-defined parameters:
    ```python
    import jax.numpy as jnp
    from jax_fno.integrate import solve_ivp, NewtonRaphson, GMRES, BackwardEuler

    # Define ODE: dy/dt = -k*y
    def fun(t, y, k):
        return -k * y

    # Create linear solver
    linsolver = GMRES(tol=1e-6, maxiter=20)

    # Create non-linear solver
    root_finder = NewtonRaphson(linsolver=linsolver, tol=1e-5, maxiter=50)

    # Choose integration method
    method = BackwardEuler(root_finder=root_finder)

    # Solve
    y0 = jnp.array([1.0])
    t_span = (0.0, 2.0)
    k = 0.5
    t, y = solve_ivp(fun, t_span, y0, method, step_size=0.01, args=(k,))
    ```
    """
    t_start, t_end = t_span

    def cond_fn(carry):
        t, y, _ = carry
        return t < t_end

    def body_fn(carry):
        t, y, m = carry

        # Adjust final step to hit t_end exactly
        h = jax.lax.max(0.0, jax.lax.min(step_size, t_end - t))

        y_next = m.step(fun, t, y, h, args)

        t_next = t + h

        return (t_next, y_next, m)

    t_final, y_final, _ = jax.lax.while_loop(cond_fn, body_fn, (t_start, y0, method))

    return t_final, y_final
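The `h = jax.lax.max(0.0, jax.lax.min(step_size, t_end - t))` line in `body_fn` is what lets a fixed step size land exactly on `t_end`. The clamping behaviour can be illustrated in isolation (plain Python builtins standing in for `jax.lax.max`/`min`):

```python
def step_sizes(t_start, t_end, step_size):
    """Collect the effective step sizes the solve_ivp loop would take."""
    t, hs = t_start, []
    while t < t_end:
        h = max(0.0, min(step_size, t_end - t))  # clip the final step
        hs.append(h)
        t += h
    return hs

# Two full steps of 0.375, then a clipped final step of 0.25
print(step_sizes(0.0, 1.0, 0.375))  # [0.375, 0.375, 0.25]
```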

jax_fno.integrate.solve_with_history(fun, t_span, y0, method, step_size, t_eval=None, args=(), verbose=False)

Integrate dy/dt = fun(t, y, *args) over the time interval t_span.

This function allows users to return intermediate states at times t_eval, but is not compatible with JAX transformations. The integration is done in chunks by calling solve_ivp, where solve_ivp has been JIT-compiled.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `fun` | `Callable` | Right-hand side function with signature (t, y, *args) -> dydt | required |
| `t_span` | `Tuple[float, float]` | (t_start, t_end) time interval | required |
| `y0` | `Array` | Initial condition | required |
| `method` | `StepperProtocol` | Time-stepping method instance (e.g., RK4(), BackwardEuler()) | required |
| `step_size` | `float` | Time step size for integration | required |
| `t_eval` | `Optional[Array]` | Times at which to store the computed solution. If None, returns only the initial and final states. Must be sorted and lie within t_span. | `None` |
| `args` | `tuple` | Additional arguments to pass to fun (and jvp/jac, if provided) | `()` |
| `verbose` | `bool` | Print progress information | `False` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `t` | `Array` | Array of time points, shape (n_points,) |
| `y` | `Array` | Array of solution values at times t, shape (n_points, *y0.shape) |

Example usage:

import jax.numpy as jnp
from jax_fno.integrate import solve_with_history, RK4

# Define ODE: dy/dt = -k*y
def fun(t, y, k):
    return -k * y

# Solve with args
y0 = jnp.array([1.0])
t_span = (0.0, 2.0)
t_eval = jnp.linspace(0, 2, 5)
k = 0.5

t, y = solve_with_history(
    fun, t_span, y0, RK4(), step_size=0.01, t_eval=t_eval, args=(k,)
)
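The chunked-integration pattern used internally (advance between consecutive `t_eval` points, stacking the intermediate states) can be sketched with a stand-in integrator; here an exact-decay closure replaces the JIT-compiled `solve_ivp` call, and the resulting shapes match the documented return values:

```python
import numpy as np

def integrate_chunk(t0, t1, y, k=0.5):
    # Stand-in for the JIT-compiled solve_ivp call: exact decay over [t0, t1]
    return t1, y * np.exp(-k * (t1 - t0))

t_eval = np.linspace(0.0, 2.0, 5)
y = np.array([1.0])

t_save, y_save = [t_eval[0]], [y]
for t0, t1 in zip(t_eval[:-1], t_eval[1:]):
    t, y = integrate_chunk(t0, t1, y)
    t_save.append(t)
    y_save.append(y)

t_arr = np.stack(t_save)          # shape (5,)
y_arr = np.stack(y_save, axis=0)  # shape (5, 1)
print(t_arr.shape, y_arr.shape)
```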

Source code in src/jax_fno/integrate/solve.py
def solve_with_history(
    fun: Callable,
    t_span: Tuple[float, float],
    y0: Array,
    method: StepperProtocol,
    step_size: float,
    t_eval: Optional[Array] = None,
    args: tuple = (),
    verbose: bool = False
) -> Tuple[Array, Array]:
    """
    Integrate dy/dt = fun(t, y, *args) over the time interval t_span.

    This function allows users to return intermediate states at times `t_eval`,
    but is not compatible with JAX transformations. The integration is done in 
    chunks by calling `solve_ivp`, where `solve_ivp` has been JIT-compiled.

    Args:
        fun: Right-hand side function with signature (t, y, *args) -> dydt
        t_span: (t_start, t_end) time interval
        y0: Initial condition
        method: Time-stepping method instance (e.g., RK4(), BackwardEuler())
        step_size: Time step size for integration.
        t_eval: Times at which to store the computed solution.
            If None, returns only the initial and final states.
            Must be sorted and lie within t_span.
        args: Additional arguments to pass to fun (and jvp/jac if provided)
        verbose: Print progress information

    Returns:
        t: Array of time points, shape (n_points,)
        y: Array of solution values at times t, shape (n_points, *y0.shape)

    Example usage:
    ```python
    import jax.numpy as jnp
    from jax_fno.integrate import solve_with_history, RK4

    # Define ODE: dy/dt = -k*y
    def fun(t, y, k):
        return -k * y

    # Solve with args
    y0 = jnp.array([1.0])
    t_span = (0.0, 2.0)
    t_eval = jnp.linspace(0, 2, 5)
    k = 0.5

    t, y = solve_with_history(
        fun, t_span, y0, RK4(), step_size=0.01, t_eval=t_eval, args=(k,)
    )
    ```
    """
    t_start, t_end = t_span

    # Set up evaluation times
    if t_eval is None:
        # Only save initial and final states
        t_eval = jnp.array([t_start, t_end])
    else:
        # Validate t_eval
        t_eval = jnp.asarray(t_eval)
        if jnp.any(t_eval < t_start) or jnp.any(t_eval > t_end):
            raise ValueError("All values in t_eval must be within t_span")
        if jnp.any(jnp.diff(t_eval) < 0):
            raise ValueError("t_eval must be sorted in increasing order")

        # Ensure t_start is included
        if t_eval[0] != t_start:
            t_eval = jnp.concatenate([jnp.array([t_start]), t_eval])

    n_steps_total = int(jnp.ceil((t_end - t_start) / step_size))

    if verbose:
        method_name = type(method).__name__
        print(f"Solving with {method_name}")
        print(
            f"Time: [{t_start}, {t_end}], dt={step_size}, "
            f"~{n_steps_total} total steps"
        )
        print(f"Evaluating at {len(t_eval)} time points")

    integrate_jit = jax.jit(solve_ivp, static_argnames=['fun', 'method'])

    # Integrate between consecutive evaluation points:
    y_save = [y0]
    t_save = [t_start]
    y = y0

    start_wallclock = time.time()

    for i in range(len(t_eval) - 1):
        t_i = float(t_eval[i])
        t_ip1 = float(t_eval[i + 1])
        t, y = integrate_jit(fun, (t_i, t_ip1), y, method, step_size, args)
        t_save.append(t)
        y_save.append(y)

    t_arr = jnp.stack(t_save)
    y_arr = jnp.stack(y_save, axis=0)

    elapsed_wallclock = time.time() - start_wallclock

    if verbose:
        print(
            f"Completed in {elapsed_wallclock:.3f}s "
            f"({n_steps_total / elapsed_wallclock:.1f} steps/s)"
        )

    return t_arr, y_arr

Time-stepping schemes

jax_fno.integrate.ForwardEuler

Bases: Module

Forward Euler method.

Discretisation
\[ \frac{\partial y}{\partial t} \rightarrow \frac{(y_{n+1} - y_n)}{h} = f(t_n, y_n) \]
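Stripped of the class machinery, the update above is just `y + h * f(t, y)`; a minimal standalone sketch:

```python
def forward_euler_step(f, t, y, h):
    # y_{n+1} = y_n + h * f(t_n, y_n)
    return y + h * f(t, y)

f = lambda t, y: -y                          # dy/dt = -y
print(forward_euler_step(f, 0.0, 1.0, 0.1))  # 0.9
```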
Source code in src/jax_fno/integrate/timesteppers/explicit.py
class ForwardEuler(nnx.Module):
    """
    Forward Euler method.

    Discretisation:
        $$ \\frac{\\partial y}{\\partial t} \\rightarrow
        \\frac{(y_{n+1} - y_n)}{h} = f(t_n, y_n) $$
    """

    def step(
        self,
        fun: Callable,
        t: Array,
        y: Array,
        h: Array,
        args: tuple = ()
    ) -> Array:
        """
        Perform a single Forward Euler step.

        Computes $$ y_{n+1} = y_n + h f(t_n, y_n, *args). $$

        Args:
            fun: Right-hand side of system dydt = f(t, y, *args).
            t: Current time. Type: 0-dimensional JAX array.
            y: Current solution.
            h: Time step size. Type: 0-dimensional JAX array.
            args: Additional arguments to pass to fun.

        Returns:
            Solution at t + h.
        """
        return y + h * fun(t, y, *args)

step(fun, t, y, h, args=())

Perform a single Forward Euler step.

Computes $$ y_{n+1} = y_n + h f(t_n, y_n, *args). $$

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `fun` | `Callable` | Right-hand side of system dydt = f(t, y, *args) | required |
| `t` | `Array` | Current time (0-dimensional JAX array) | required |
| `y` | `Array` | Current solution | required |
| `h` | `Array` | Time step size (0-dimensional JAX array) | required |
| `args` | `tuple` | Additional arguments to pass to fun | `()` |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Solution at t + h |

Source code in src/jax_fno/integrate/timesteppers/explicit.py
def step(
    self,
    fun: Callable,
    t: Array,
    y: Array,
    h: Array,
    args: tuple = ()
) -> Array:
    """
    Perform a single Forward Euler step.

    Computes $$ y_{n+1} = y_n + h f(t_n, y_n, *args). $$

    Args:
        fun: Right-hand side of system dydt = f(t, y, *args).
        t: Current time. Type: 0-dimensional JAX array.
        y: Current solution.
        h: Time step size. Type: 0-dimensional JAX array.
        args: Additional arguments to pass to fun.

    Returns:
        Solution at t + h.
    """
    return y + h * fun(t, y, *args)

jax_fno.integrate.RK4

Bases: Module

Fourth-order Runge-Kutta method.

Implements: StepperProtocol

Source code in src/jax_fno/integrate/timesteppers/explicit.py
class RK4(nnx.Module):
    """
    Fourth (4th) order Runge-Kutta method.

    Implements: StepperProtocol
    """

    def step(
        self,
        fun: Callable,
        t: Array,
        y: Array,
        h: Array,
        args: tuple = ()
    ) -> Array:
        """
        Perform a single RK4 step.

        Args:
            fun: Right-hand side of system dy/dt = f(t, y, *args).
            t: Current time.
            y: Current solution.
            h: Time step size.
            args: Additional arguments to pass to fun.

        Returns:
            Solution at t + h.
        """
        k1 = fun(t, y, *args)
        k2 = fun(t + 0.5 * h, y + 0.5 * h * k1, *args)
        k3 = fun(t + 0.5 * h, y + 0.5 * h * k2, *args)
        k4 = fun(t + h, y + h * k3, *args)
        return y + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

step(fun, t, y, h, args=())

Perform a single RK4 step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `fun` | `Callable` | Right-hand side of system dy/dt = f(t, y, *args) | required |
| `t` | `Array` | Current time | required |
| `y` | `Array` | Current solution | required |
| `h` | `Array` | Time step size | required |
| `args` | `tuple` | Additional arguments to pass to fun | `()` |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Solution at t + h |

Source code in src/jax_fno/integrate/timesteppers/explicit.py
def step(
    self,
    fun: Callable,
    t: Array,
    y: Array,
    h: Array,
    args: tuple = ()
) -> Array:
    """
    Perform a single RK4 step.

    Args:
        fun: Right-hand side of system dy/dt = f(t, y, *args).
        t: Current time.
        y: Current solution.
        h: Time step size.
        args: Additional arguments to pass to fun.

    Returns:
        Solution at t + h.
    """
    k1 = fun(t, y, *args)
    k2 = fun(t + 0.5 * h, y + 0.5 * h * k1, *args)
    k3 = fun(t + 0.5 * h, y + 0.5 * h * k2, *args)
    k4 = fun(t + h, y + h * k3, *args)
    return y + (h / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)

jax_fno.integrate.BackwardEuler

Bases: Module

Backward Euler time stepper.

Discretisation
\[ \frac{\partial y}{\partial t} \rightarrow \frac{(y_{n+1} - y_n)}{h} = f(t_{n+1}, y_{n+1}) \]

If neither jvp nor jac are provided, defaults to root-finding using automatic differentiation with jax.jvp.
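For the linear decay ODE dy/dt = -k·y, the backward Euler equation y_{n+1} = y_n + h·f(t_{n+1}, y_{n+1}) has the closed form y_{n+1} = y_n / (1 + h·k), which makes a handy check of any root-finding setup. A standalone scalar sketch (not the library's NewtonRaphson/GMRES machinery): Newton's iteration on the residual converges in a single step because the residual is linear in y.

```python
k, h, y_n = 0.5, 0.1, 1.0

# Residual R(y) = y - y_n - h * f(t_{n+1}, y) with f(t, y) = -k * y
residual = lambda y: y - y_n + h * k * y
dR = 1.0 + h * k                      # dR/dy, constant for a linear ODE

y0 = y_n - h * k * y_n                # cheap forward Euler guess, as in step()
y1 = y0 - residual(y0) / dR           # one Newton update suffices here

print(abs(y1 - y_n / (1.0 + h * k)))  # agrees with the closed form
```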

Source code in src/jax_fno/integrate/timesteppers/implicit.py
class BackwardEuler(nnx.Module):
    """
    Backward Euler time stepper.

    Discretisation:
        $$ \\frac{\\partial y}{\\partial t} \\rightarrow
        \\frac{(y_{n+1} - y_n)}{h} = f(t_{n+1}, y_{n+1}) $$

    If neither jvp nor jac are provided, defaults to root-finding using
    automatic differentiation with `jax.jvp`.
    """

    def __init__(
        self,
        root_finder: RootFinderProtocol = NewtonRaphson(),
        jvp: Optional[Callable] = None,
        jac: Optional[Callable] = None,
    ):
        self.root_finder = root_finder
        self.jvp = jvp
        self.jac = jac

    def _residual(self, fun, t, y, h, args):
        return lambda y_np1: y_np1 - y - h * fun(t + h, y_np1, *args)

    def _build_linearisation(self, fun, t, h, args):
        """Return (jac_fn, jvp_fn) for the Jacobian J_f = I - h * df/dy."""

        if self.jac is not None:
            # Dense Jacobian supplied
            def jac_fn(y):
                return jnp.eye(y.size) - h * self.jac(t + h, y, *args)
            return jac_fn, None

        if self.jvp is not None:
            # Matrix-free Jacobian supplied
            def jvp_fn(y, v):
                return v - h * self.jvp(t + h, y, v, *args)
            return None, jvp_fn

        def jvp_autodiff(y, v):
            # Fall back to automatic differentiation
            _, df_v = jax.jvp(lambda y_: fun(t + h, y_, *args), (y,), (v,))
            return v - h * df_v

        return None, jvp_autodiff

    def step(self, fun, t: Array, y: Array, h: Array, args=()):
        """Advance one backward Euler step."""

        # If using a spectral solver, construct diagonal symbol (1 - h s(k))
        if hasattr(self.root_finder, 'linsolver'):
            linsolver = self.root_finder.linsolver
            if hasattr(linsolver, 'eigvals'):
                linsolver.set_symbol(1.0 - h * linsolver.eigvals)

        residual = self._residual(fun, t, y, h, args)

        # cheap forward Euler guess
        y0 = y + h * fun(t, y, *args)

        jac_fn, jvp_fn = self._build_linearisation(fun, t, h, args)

        return self.root_finder(residual, y0, jac_fn=jac_fn, jvp_fn=jvp_fn)

step(fun, t, y, h, args=())

Advance one backward Euler step.

Source code in src/jax_fno/integrate/timesteppers/implicit.py
def step(self, fun, t: Array, y: Array, h: Array, args=()):
    """Advance one backward Euler step."""

    # If using a spectral solver, construct diagonal symbol (1 - h s(k))
    if hasattr(self.root_finder, 'linsolver'):
        linsolver = self.root_finder.linsolver
        if hasattr(linsolver, 'eigvals'):
            linsolver.set_symbol(1.0 - h * linsolver.eigvals)

    residual = self._residual(fun, t, y, h, args)

    # cheap forward Euler guess
    y0 = y + h * fun(t, y, *args)

    jac_fn, jvp_fn = self._build_linearisation(fun, t, h, args)

    return self.root_finder(residual, y0, jac_fn=jac_fn, jvp_fn=jvp_fn)

jax_fno.integrate.IMEX

Bases: Module

Implicit-Explicit (IMEX) time-stepping scheme.

Splits the ODE into stiff (implicit) and non-stiff (explicit) parts: dy/dt = f_explicit(t, y) + f_implicit(t, y)

The scheme advances the solution in two steps:

1. Explicit step: u* = u^n + h * f_explicit(t^n, u^n)
2. Implicit step: u^{n+1} = u* + h * f_implicit(t^{n+1}, u^{n+1})

The implicit step is solved via root-finding:

R(u^{n+1}) = u^{n+1} - u* - h * f_implicit(t^{n+1}, u^{n+1}) = 0

This formulation allows you to:

- Use high-order explicit methods (RK4, etc.) for the non-stiff terms
- Use implicit methods (BackwardEuler, etc.) for the stiff terms
- Avoid overly restrictive time step constraints from stiff terms

Example
from jax_fno.integrate import (
    solve_ivp, IMEX, RK4, BackwardEuler,
    NewtonRaphson, Spectral
)

def explicit_term(t, u, ...):
    return ...

def implicit_term(t, u, ...):
    return ...

# Instantiate solver
solver = IMEX(implicit=BackwardEuler(), explicit=RK4())

# Define ODE as a dict
ode = {'implicit': implicit_term, 'explicit': explicit_term}

# Solve
t, y = solve_ivp(ode, t_span, y0, solver, step_size, args)
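When both terms are linear, one IMEX step has a closed form that is easy to sketch standalone: an explicit Euler advance with the non-stiff coefficient `a`, then a backward Euler solve with the stiff coefficient `b` (both values hypothetical, chosen so the stiff term dominates):

```python
a, b = 0.2, -50.0          # non-stiff (explicit) and stiff (implicit) coefficients
h, y_n = 0.01, 1.0

# Step 1: explicit advance, u* = y_n + h * a * y_n
u_star = y_n + h * a * y_n

# Step 2: implicit solve of u1 = u* + h * b * u1  =>  u1 = u* / (1 - h * b)
u_next = u_star / (1.0 - h * b)

print(u_next)  # (1 + h*a) / (1 - h*b) applied to y_n
```

Note how the stiff term enters through a division rather than a multiplication, which is why the step size is not constrained by `b`.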
Source code in src/jax_fno/integrate/timesteppers/imex.py
class IMEX(nnx.Module):
    """
    Implicit-Explicit (IMEX) time-stepping scheme.

    Splits the ODE into stiff (implicit) and non-stiff (explicit) parts:
        dy/dt = f_explicit(t, y) + f_implicit(t, y)

    The scheme advances the solution in two steps:
        1. Explicit step:  u* = u^n + h * f_explicit(t^n, u^n)
        2. Implicit step:  u^{n+1} = u* + h * f_implicit(t^{n+1}, u^{n+1})

    The implicit step is solved via root-finding:
        R(u^{n+1}) = u^{n+1} - u* - h * f_implicit(t^{n+1}, u^{n+1}) = 0

    This formulation allows you to:
    - Use high-order explicit methods (RK4, etc.) for the non-stiff terms
    - Use implicit methods (BackwardEuler, etc.) for the stiff terms
    - Avoid overly restrictive time step constraints from stiff terms

    Example:
        ```python
        from jax_fno.integrate import (
            solve_ivp, IMEX, RK4, BackwardEuler,
            NewtonRaphson, Spectral
        )

        def explicit_term(t, u, ...):
            return ...

        def implicit_term(t, u, ...):
            return ...

        # Instantiate solver
        solver = IMEX(implicit=BackwardEuler(), explicit=RK4())

        # Define ODE as a dict
        ode = {'implicit': implicit_term, 'explicit': explicit_term}

        # Solve
        t, y = solve_ivp(ode, t_span, y0, solver, step_size, args)
        ```
    """

    def __init__(
        self,
        implicit: StepperProtocol,
        explicit: StepperProtocol,
    ):
        self.implicit = implicit
        self.explicit = explicit

    def step(
        self,
        fun: Union[Callable, Dict[str, Callable]],
        t: Array,
        y: Array,
        h: Array,
        args: tuple = ()
    ) -> Array:
        """
        Advance one IMEX step.

        Args:
            fun: Either a dict with keys 'implicit' and 'explicit', or a callable.
                If a dict, fun['explicit'](t, y, *args) gives the non-stiff term
                and fun['implicit'](t, y, *args) gives the stiff term.
                If a callable, it's treated as the implicit term with zero explicit term.
            t: Current time.
            y: Current solution.
            h: Time step size.
            args: Additional arguments to pass to fun.

        Returns:
            Solution at t + h.
        """
        # Handle dict-based interface
        if isinstance(fun, dict):
            if 'explicit' not in fun or 'implicit' not in fun:
                raise ValueError(
                    "IMEX requires fun to be a dict with 'explicit' and 'implicit' keys, "
                    f"but got keys: {list(fun.keys())}"
                )
            fun_explicit = fun['explicit']
            fun_implicit = fun['implicit']
        else:
            # If fun is a callable, treat it as implicit-only
            fun_explicit = lambda t, y, *args: 0.0
            fun_implicit = fun

        # Step 1: Explicit advance to get intermediate state u*
        # u* = u^n + h * f_explicit(t^n, u^n)
        u_star = self.explicit.step(fun_explicit, t, y, h, args)

        # Step 2: Implicit solve from u* to get u^{n+1}
        # Solve: u^{n+1} = u* + h * f_implicit(t^{n+1}, u^{n+1})
        #
        # The implicit stepper expects to solve:
        #   u^{n+1} = u^n + h * f(t^{n+1}, u^{n+1})
        #
        # We substitute u^n -> u* to get the correct equation
        u_next = self.implicit.step(fun_implicit, t, u_star, h, args)

        return u_next

step(fun, t, y, h, args=())

Advance one IMEX step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `fun` | `Union[Callable, Dict[str, Callable]]` | Either a dict with keys 'implicit' and 'explicit', or a callable. If a dict, fun['explicit'](t, y, *args) gives the non-stiff term and fun['implicit'](t, y, *args) gives the stiff term. If a callable, it is treated as the implicit term with a zero explicit term. | required |
| `t` | `Array` | Current time | required |
| `y` | `Array` | Current solution | required |
| `h` | `Array` | Time step size | required |
| `args` | `tuple` | Additional arguments to pass to fun | `()` |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Solution at t + h |

Source code in src/jax_fno/integrate/timesteppers/imex.py
def step(
    self,
    fun: Union[Callable, Dict[str, Callable]],
    t: Array,
    y: Array,
    h: Array,
    args: tuple = ()
) -> Array:
    """
    Advance one IMEX step.

    Args:
        fun: Either a dict with keys 'implicit' and 'explicit', or a callable.
            If a dict, fun['explicit'](t, y, *args) gives the non-stiff term
            and fun['implicit'](t, y, *args) gives the stiff term.
            If a callable, it's treated as the implicit term with zero explicit term.
        t: Current time.
        y: Current solution.
        h: Time step size.
        args: Additional arguments to pass to fun.

    Returns:
        Solution at t + h.
    """
    # Handle dict-based interface
    if isinstance(fun, dict):
        if 'explicit' not in fun or 'implicit' not in fun:
            raise ValueError(
                "IMEX requires fun to be a dict with 'explicit' and 'implicit' keys, "
                f"but got keys: {list(fun.keys())}"
            )
        fun_explicit = fun['explicit']
        fun_implicit = fun['implicit']
    else:
        # If fun is a callable, treat it as implicit-only
        fun_explicit = lambda t, y, *args: 0.0
        fun_implicit = fun

    # Step 1: Explicit advance to get intermediate state u*
    # u* = u^n + h * f_explicit(t^n, u^n)
    u_star = self.explicit.step(fun_explicit, t, y, h, args)

    # Step 2: Implicit solve from u* to get u^{n+1}
    # Solve: u^{n+1} = u* + h * f_implicit(t^{n+1}, u^{n+1})
    #
    # The implicit stepper expects to solve:
    #   u^{n+1} = u^n + h * f(t^{n+1}, u^{n+1})
    #
    # We substitute u^n -> u* to get the correct equation
    u_next = self.implicit.step(fun_implicit, t, u_star, h, args)

    return u_next

jax_fno.integrate.StepperProtocol

Bases: Protocol

Protocol for time-stepping schemes.

Defines the interface for advancing an ODE one time step. Any class implementing a step() method with this signature can be used as a time-stepping method in solve_ivp.
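Because the protocol is structural, any object with a matching step() method can be passed to solve_ivp. As a sketch, a hypothetical Heun (explicit trapezoidal) stepper satisfying this interface would look like:

```python
import math

class Heun:
    """Second-order explicit trapezoidal rule (hypothetical custom stepper)."""

    def step(self, fun, t, y, h, args=()):
        k1 = fun(t, y, *args)
        k2 = fun(t + h, y + h * k1, *args)
        return y + 0.5 * h * (k1 + k2)

# A single step on dy/dt = -y is second-order accurate against exp(-h)
y1 = Heun().step(lambda t, y: -y, 0.0, 1.0, 0.1)
print(abs(y1 - math.exp(-0.1)))  # O(h**3) local error, ~2e-4
```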

Source code in src/jax_fno/integrate/timesteppers/protocol.py
@runtime_checkable
class StepperProtocol(Protocol):
    """
    Protocol for time-stepping schemes.

    Defines the interface for advancing an ODE one time step.
    Any class implementing a step() method with this signature can be used
    as a time-stepping method in solve_ivp.
    """

    def step(
        self,
        fun: Callable,
        t: Array,
        y: Array,
        h: Array,
        args: tuple = ()
    ) -> Array:
        """
        Take a single time step.

        Args:
            fun: Right-hand side function.
            t: Current time. Type: 0-dimensional JAX array.
            y: Current solution.
            h: Time step size. Type: 0-dimensional JAX array.
            args: Additional arguments to pass to fun.

        Returns:
            Solution at t + h.
        """
        ...

step(fun, t, y, h, args=())

Take a single time step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `fun` | `Callable` | Right-hand side function | required |
| `t` | `Array` | Current time (0-dimensional JAX array) | required |
| `y` | `Array` | Current solution | required |
| `h` | `Array` | Time step size (0-dimensional JAX array) | required |
| `args` | `tuple` | Additional arguments to pass to fun | `()` |

Returns:

| Type | Description |
| --- | --- |
| `Array` | Solution at t + h |

Source code in src/jax_fno/integrate/timesteppers/protocol.py
def step(
    self,
    fun: Callable,
    t: Array,
    y: Array,
    h: Array,
    args: tuple = ()
) -> Array:
    """
    Take a single time step.

    Args:
        fun: Right-hand side function.
        t: Current time. Type: 0-dimensional JAX array.
        y: Current solution.
        h: Time step size. Type: 0-dimensional JAX array.
        args: Additional arguments to pass to fun.

    Returns:
        Solution at t + h.
    """
    ...

Root-finding algorithms

jax_fno.integrate.NewtonRaphson

Bases: Module

Newton-Raphson root-finding algorithm.

Iterative update: \(y \leftarrow y - J^{-1}(y) R(y)\)

Implements: RootFinderProtocol

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `tol` | `float` | Convergence tolerance for residual norm |
| `maxiter` | `int` | Maximum number of Newton-Raphson iterations |
| `linsolver` | `LinearSolverProtocol` | Linear solver for inner iterations (default: GMRES) |
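In the scalar case the update above reduces to y ← y − R(y)/R′(y), with no linear solver needed. A minimal standalone sketch (not the library class) finding √2 as the root of R(y) = y² − 2:

```python
R = lambda y: y * y - 2.0     # residual
dR = lambda y: 2.0 * y        # scalar "Jacobian"

y, tol, maxiter = 1.0, 1e-12, 50
for _ in range(maxiter):
    r = R(y)
    if abs(r) <= tol:         # convergence check on the residual norm
        break
    y = y - r / dR(y)         # Newton update

print(y)  # converges quadratically to sqrt(2)
```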

Source code in src/jax_fno/integrate/rootfinders/newtonraphson.py
class NewtonRaphson(nnx.Module):
    """
    Newton-Raphson root-finding algorithm.

    Iterative update: $y \\leftarrow y - J^{-1}(y) R(y)$

    Implements: RootFinderProtocol

    Attributes:
        tol: Convergence tolerance for residual norm
        maxiter: Maximum number of Newton-Raphson iterations
        linsolver: Linear solver for inner iterations (default: GMRES)
    """

    def __init__(
        self,
        tol: float = 1e-6,
        maxiter: int = 50,
        linsolver: LinearSolverProtocol = GMRES()
    ):
        self.tol = tol
        self.maxiter = maxiter
        self.linsolver = linsolver

    def __call__(
        self,
        residual_fn: LinearMap,
        y_guess: Array,
        jvp_fn: Optional[JVPConstructor] = None,
        jac_fn: Optional[JacobianConstructor] = None,
    ) -> Array:
        """
        Find the root of residual_fn(y) = 0 using Newton-Raphson method.

        Args:
            residual_fn: Residual function R(y)
            y_guess: Initial guess
            jvp_fn: Matrix-free Jacobian-vector product with
                signature (y, v) -> J(y)*v
            jac_fn: Function returning a dense Jacobian matrix with
                signature y -> J(y)

        Returns:
            Solution y
        """
        if jac_fn is not None and jvp_fn is not None:
            raise ValueError("Provide either jac_fn OR jvp_fn, not both.")

        # State carried through Newton iterations
        y_k = y_guess
        r_k = residual_fn(y_k)
        state0 = (y_k, r_k, 0)

        # Define body for a single Newton iteration
        if jac_fn is not None:
            # Dense mode
            def body_fun(state):
                y_k, r_k, k = state
                J = jac_fn(y_k)  # construct dense matrix
                delta = self.linsolver(J, -r_k)
                y_kp1 = y_k + delta
                r_kp1 = residual_fn(y_kp1)
                return (y_kp1, r_kp1, k + 1)
        elif jvp_fn is not None:
            # Matrix-free mode
            def body_fun(state):
                y_k, r_k, k = state
                jvp = lambda v : jvp_fn(y_k, v)  # define matrix-free operator
                delta = self.linsolver(jvp, -r_k)
                y_kp1 = y_k + delta
                r_kp1 = residual_fn(y_kp1)
                return (y_kp1, r_kp1, k + 1)
        else:
            raise ValueError("Must provide either jvp_fn or jac_fn")

        # Convergence condition
        def cond_fun(state):
            _, r_k, k = state
            return (jnp.linalg.norm(r_k) > self.tol) & (k < self.maxiter)

        y_final, r_final, niters = jax.lax.while_loop(cond_fun, body_fun, state0)

        # Runtime warning
        def warn_callback(iters, maxiter, residual_norm, tol):
            if iters >= maxiter and residual_norm > tol:
                s1 = f"WARNING: Newton-Raphson did not converge within {int(maxiter)} iterations."
                s2 = f"Final residual norm: {float(residual_norm):.2e}."
                print(s1 + "\n" + s2)

        jax.debug.callback(
            warn_callback, 
            niters, self.maxiter, 
            jnp.linalg.norm(r_final), self.tol
        )

        return y_final

__call__(residual_fn, y_guess, jvp_fn=None, jac_fn=None)

Find the root of residual_fn(y) = 0 using the Newton-Raphson method.

Parameters:

Name Type Description Default
residual_fn LinearMap

Residual function R(y)

required
y_guess Array

Initial guess

required
jvp_fn Optional[JVPConstructor]

Matrix-free Jacobian-vector product with signature (y, v) -> J(y)*v

None
jac_fn Optional[JacobianConstructor]

Function returning a dense Jacobian matrix with signature y -> J(y)

None

Returns:

Type Description
Array

Solution y

Source code in src/jax_fno/integrate/rootfinders/newtonraphson.py
def __call__(
    self,
    residual_fn: LinearMap,
    y_guess: Array,
    jvp_fn: Optional[JVPConstructor] = None,
    jac_fn: Optional[JacobianConstructor] = None,
) -> Array:
    """
    Find the root of residual_fn(y) = 0 using the Newton-Raphson method.

    Args:
        residual_fn: Residual function R(y)
        y_guess: Initial guess
        jvp_fn: Matrix-free Jacobian-vector product with
            signature (y, v) -> J(y)*v
        jac_fn: Function returning a dense Jacobian matrix with
            signature y -> J(y)

    Returns:
        Solution y
    """
    if jac_fn is not None and jvp_fn is not None:
        raise ValueError("Provide either jac_fn OR jvp_fn, not both.")

    # State carried through Newton iterations
    y_k = y_guess
    r_k = residual_fn(y_k)
    state0 = (y_k, r_k, 0)

    # Define body for a single Newton iteration
    if jac_fn is not None:
        # Dense mode
        def body_fun(state):
            y_k, r_k, k = state
            J = jac_fn(y_k)  # construct dense matrix
            delta = self.linsolver(J, -r_k)
            y_kp1 = y_k + delta
            r_kp1 = residual_fn(y_kp1)
            return (y_kp1, r_kp1, k + 1)
    elif jvp_fn is not None:
        # Matrix-free mode
        def body_fun(state):
            y_k, r_k, k = state
            jvp = lambda v: jvp_fn(y_k, v)  # define matrix-free operator
            delta = self.linsolver(jvp, -r_k)
            y_kp1 = y_k + delta
            r_kp1 = residual_fn(y_kp1)
            return (y_kp1, r_kp1, k + 1)
    else:
        raise ValueError("Must provide either jvp_fn or jac_fn")

    # Convergence condition
    def cond_fun(state):
        _, r_k, k = state
        return (jnp.linalg.norm(r_k) > self.tol) & (k < self.maxiter)

    y_final, r_final, niters = jax.lax.while_loop(cond_fun, body_fun, state0)

    # Runtime warning
    def warn_callback(iters, maxiter, residual_norm, tol):
        if iters >= maxiter and residual_norm > tol:
            s1 = f"WARNING: Newton-Raphson did not converge within {int(maxiter)} iterations."
            s2 = f"Final residual norm: {float(residual_norm):.2e}."
            print(s1 + "\n" + s2)

    jax.debug.callback(
        warn_callback, 
        niters, self.maxiter, 
        jnp.linalg.norm(r_final), self.tol
    )

    return y_final
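The matrix-free branch of the iteration above can be sketched in plain JAX. The following is an illustrative stand-in, not the library implementation: it builds the Jacobian-vector product with `jax.jvp` and solves each Newton update with `jax.scipy.sparse.linalg.gmres`, using a plain Python loop instead of `jax.lax.while_loop`:

```python
import jax
import jax.numpy as jnp

def residual_fn(y):
    return y**3 - 8.0  # single real root at y = 2

def newton_sketch(residual_fn, y_guess, tol=1e-6, maxiter=50):
    y = y_guess
    for _ in range(maxiter):
        r = residual_fn(y)
        if jnp.linalg.norm(r) < tol:
            break
        # Matrix-free Jacobian-vector product v -> J(y) @ v via forward-mode AD
        jvp = lambda v, y=y: jax.jvp(residual_fn, (y,), (v,))[1]
        # Solve J(y) * delta = -R(y) for the Newton update
        delta, _ = jax.scipy.sparse.linalg.gmres(jvp, -r)
        y = y + delta
    return y

y = newton_sketch(residual_fn, jnp.array([3.0]))
print(y)  # close to [2.0]
```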

jax_fno.integrate.RootFinderProtocol

Bases: Protocol

Protocol for root-finding algorithms.

Defines the interface for finding roots of nonlinear equations. Used by implicit time-stepping schemes to solve the nonlinear systems that arise from implicit discretization.

Source code in src/jax_fno/integrate/rootfinders/protocol.py
@runtime_checkable
class RootFinderProtocol(Protocol):
    """
    Protocol for root-finding algorithms.

    Defines the interface for finding roots of nonlinear equations.
    Used by implicit time-stepping schemes to solve the nonlinear systems
    that arise from implicit discretization.
    """

    def __call__(
        self,
        residual_fn: LinearMap,
        y_guess: Array,
        jvp_fn: Optional[JVPConstructor] = None,
        jac_fn: Optional[JacobianConstructor] = None,
    ) -> Array:
        """
        Find the root of residual_fn(y) = 0.

        Args:
            residual_fn: Function mapping y -> R(y), where we seek R(y) = 0
            y_guess: Initial guess for the solution
            jvp_fn: Optional Jacobian-vector product function (y, v) -> J*v
            jac_fn: Optional dense Jacobian function y -> J

        Returns:
            Solution y such that residual_fn(y) ≈ 0
        """
        ...

__call__(residual_fn, y_guess, jvp_fn=None, jac_fn=None)

Find the root of residual_fn(y) = 0.

Parameters:

Name Type Description Default
residual_fn LinearMap

Function mapping y -> R(y), where we seek R(y) = 0

required
y_guess Array

Initial guess for the solution

required
jvp_fn Optional[JVPConstructor]

Optional Jacobian-vector product function (y, v) -> J*v

None
jac_fn Optional[JacobianConstructor]

Optional dense Jacobian function y -> J

None

Returns:

Type Description
Array

Solution y such that residual_fn(y) ≈ 0

Source code in src/jax_fno/integrate/rootfinders/protocol.py
def __call__(
    self,
    residual_fn: LinearMap,
    y_guess: Array,
    jvp_fn: Optional[JVPConstructor] = None,
    jac_fn: Optional[JacobianConstructor] = None,
) -> Array:
    """
    Find the root of residual_fn(y) = 0.

    Args:
        residual_fn: Function mapping y -> R(y), where we seek R(y) = 0
        y_guess: Initial guess for the solution
        jvp_fn: Optional Jacobian-vector product function (y, v) -> J*v
        jac_fn: Optional dense Jacobian function y -> J

    Returns:
        Solution y such that residual_fn(y) ≈ 0
    """
    ...
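Because the protocol is decorated with `@runtime_checkable`, conformance is structural and can be checked with `isinstance`. The sketch below is hypothetical: `RootFinderLike` mirrors the protocol's signature locally, and `FixedPointFinder` is a toy root finder used only to illustrate that any object with a matching `__call__` satisfies it:

```python
from typing import Protocol, runtime_checkable
import jax.numpy as jnp

# Hypothetical local mirror of RootFinderProtocol's signature
@runtime_checkable
class RootFinderLike(Protocol):
    def __call__(self, residual_fn, y_guess, jvp_fn=None, jac_fn=None): ...

class FixedPointFinder:
    """Toy root finder: damped fixed-point iteration y <- y - alpha * R(y)."""
    def __init__(self, alpha=0.5, maxiter=200):
        self.alpha = alpha
        self.maxiter = maxiter

    def __call__(self, residual_fn, y_guess, jvp_fn=None, jac_fn=None):
        y = y_guess
        for _ in range(self.maxiter):
            y = y - self.alpha * residual_fn(y)
        return y

finder = FixedPointFinder()
print(isinstance(finder, RootFinderLike))  # True: structural match

# Solve y = cos(y); the root is the Dottie number, approximately 0.739
root = finder(lambda y: y - jnp.cos(y), jnp.array(0.0))
```

Note that `isinstance` against a runtime-checkable protocol only verifies that the method exists, not that its signature matches.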

Linear solvers

jax_fno.integrate.GMRES

Bases: Module

Generalised Minimal Residual (GMRES).

Dispatches to jax.scipy.sparse.linalg.gmres. Suitable for general non-symmetric systems.

Source code in src/jax_fno/integrate/linsolvers/krylov.py
class GMRES(nnx.Module):
    """
    Generalised Minimal Residual (GMRES).

    Dispatches to `jax.scipy.sparse.linalg.gmres`.
    Suitable for general non-symmetric systems.
    """

    def __init__(self, tol: float = 1e-6, maxiter: int = 100):
        self.tol = tol
        self.maxiter = maxiter

    def __call__(
        self, 
        A: Union[LinearMap, Array], 
        b: Array,
        x0: Optional[Array] = None
    ) -> Array:
        """
        Solve A*x = b using GMRES.

        Args:
            A: Dense matrix or linear operator with signature x -> A*x
            b: Right-hand side vector
            x0: Initial guess vector

        Returns:
            Approximate solution x
        """
        solution, info = jax_sparse.gmres(
            A, b, x0=x0, tol=self.tol, maxiter=self.maxiter
        )
        return solution

__call__(A, b, x0=None)

Solve A*x = b using GMRES.

Parameters:

Name Type Description Default
A Union[LinearMap, Array]

Dense matrix or linear operator with signature x -> A*x

required
b Array

Right-hand side vector

required
x0 Optional[Array]

Initial guess vector

None

Returns:

Type Description
Array

Approximate solution x

Source code in src/jax_fno/integrate/linsolvers/krylov.py
def __call__(
    self, 
    A: Union[LinearMap, Array], 
    b: Array,
    x0: Optional[Array] = None
) -> Array:
    """
    Solve A*x = b using GMRES.

    Args:
        A: Dense matrix or linear operator with signature x -> A*x
        b: Right-hand side vector
        x0: Initial guess vector

    Returns:
        Approximate solution x
    """
    solution, info = jax_sparse.gmres(
        A, b, x0=x0, tol=self.tol, maxiter=self.maxiter
    )
    return solution
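Since `GMRES` dispatches to `jax.scipy.sparse.linalg.gmres`, `A` may be either a dense matrix or a matrix-free operator `x -> A @ x`. A minimal sketch of both forms, using the underlying `jax.scipy` call directly:

```python
import jax.numpy as jnp
from jax.scipy.sparse import linalg as jax_sparse

A = jnp.array([[4.0, 1.0],
               [2.0, 3.0]])   # non-symmetric
b = jnp.array([1.0, 2.0])

# Dense-matrix form
x_dense, _ = jax_sparse.gmres(A, b, tol=1e-6, maxiter=100)

# Equivalent matrix-free form: pass the matvec instead of the matrix
matvec = lambda x: A @ x
x_free, _ = jax_sparse.gmres(matvec, b, tol=1e-6, maxiter=100)

print(jnp.allclose(A @ x_dense, b, atol=1e-4))  # True
```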

jax_fno.integrate.CG

Bases: Module

Conjugate Gradients (CG).

Dispatches to jax.scipy.sparse.linalg.cg. Only suitable for symmetric and positive-definite systems.

Implements: LinearSolverProtocol

Attributes:

Name Type Description
tol

Convergence tolerance for residual norm

maxiter

Maximum number of iterations

Source code in src/jax_fno/integrate/linsolvers/krylov.py
class CG(nnx.Module):
    """
    Conjugate Gradients (CG).

    Dispatches to `jax.scipy.sparse.linalg.cg`.
    Only suitable for symmetric and positive-definite systems.

    Implements: LinearSolverProtocol

    Attributes:
        tol: Convergence tolerance for residual norm
        maxiter: Maximum number of iterations
    """

    def __init__(self, tol: float = 1e-6, maxiter: int = 100):
        self.tol = tol
        self.maxiter = maxiter

    def __call__(
        self, 
        A: Union[LinearMap, Array], 
        b: Array,
        x0: Optional[Array] = None
    ) -> Array:
        """
        Solve A*x = b.

        Args:
            A: Dense matrix or linear operator with signature x -> A*x
            b: Right-hand side vector
            x0: Initial guess vector

        Returns:
            Approximate solution x
        """
        solution, info = jax_sparse.cg(
            A, b, x0=x0, tol=self.tol, maxiter=self.maxiter
        )
        return solution

__call__(A, b, x0=None)

Solve A*x = b.

Parameters:

Name Type Description Default
A Union[LinearMap, Array]

Dense matrix or linear operator with signature x -> A*x

required
b Array

Right-hand side vector

required
x0 Optional[Array]

Initial guess vector

None

Returns:

Type Description
Array

Approximate solution x

Source code in src/jax_fno/integrate/linsolvers/krylov.py
def __call__(
    self, 
    A: Union[LinearMap, Array], 
    b: Array,
    x0: Optional[Array] = None
) -> Array:
    """
    Solve A*x = b.

    Args:
        A: Dense matrix or linear operator with signature x -> A*x
        b: Right-hand side vector
        x0: Initial guess vector

    Returns:
        Approximate solution x
    """
    solution, info = jax_sparse.cg(
        A, b, x0=x0, tol=self.tol, maxiter=self.maxiter
    )
    return solution
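`CG` dispatches to `jax.scipy.sparse.linalg.cg`, which assumes the system is symmetric positive-definite and may fail to converge otherwise. A minimal sketch with an SPD matrix, calling the underlying `jax.scipy` routine directly:

```python
import jax.numpy as jnp
from jax.scipy.sparse import linalg as jax_sparse

A = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])   # symmetric positive-definite
b = jnp.array([1.0, 2.0])

x, _ = jax_sparse.cg(A, b, tol=1e-6, maxiter=100)
print(jnp.allclose(A @ x, b, atol=1e-4))  # True
```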

jax_fno.integrate.BiCGStab

Bases: Module

Stabilised Biconjugate Gradients (BiCGStab).

Dispatches to jax.scipy.sparse.linalg.bicgstab. Suitable for non-symmetric systems.

Implements: LinearSolverProtocol

Attributes:

Name Type Description
tol

Convergence tolerance for residual norm

maxiter

Maximum number of iterations

Source code in src/jax_fno/integrate/linsolvers/krylov.py
class BiCGStab(nnx.Module):
    """
    Stabilised Biconjugate Gradients (BiCGStab).

    Dispatches to `jax.scipy.sparse.linalg.bicgstab`.
    Suitable for non-symmetric systems.

    Implements: LinearSolverProtocol

    Attributes:
        tol: Convergence tolerance for residual norm
        maxiter: Maximum number of iterations
    """

    def __init__(self, tol: float = 1e-6, maxiter: int = 100):
        self.tol = tol
        self.maxiter = maxiter

    def __call__(
        self, 
        A: Union[LinearMap, Array], 
        b: Array,
        x0: Optional[Array] = None
    ) -> Array:
        """
        Solve A*x = b.

        Args:
            A: Dense matrix or linear operator with signature x -> A*x
            b: Right-hand side vector
            x0: Initial guess vector

        Returns:
            Approximate solution x
        """
        solution, info = jax_sparse.bicgstab(
            A, b, x0=x0, tol=self.tol, maxiter=self.maxiter
        )
        return solution

__call__(A, b, x0=None)

Solve A*x = b.

Parameters:

Name Type Description Default
A Union[LinearMap, Array]

Dense matrix or linear operator with signature x -> A*x

required
b Array

Right-hand side vector

required
x0 Optional[Array]

Initial guess vector

None

Returns:

Type Description
Array

Approximate solution x

Source code in src/jax_fno/integrate/linsolvers/krylov.py
def __call__(
    self, 
    A: Union[LinearMap, Array], 
    b: Array,
    x0: Optional[Array] = None
) -> Array:
    """
    Solve A*x = b.

    Args:
        A: Dense matrix or linear operator with signature x -> A*x
        b: Right-hand side vector
        x0: Initial guess vector

    Returns:
        Approximate solution x
    """
    solution, info = jax_sparse.bicgstab(
        A, b, x0=x0, tol=self.tol, maxiter=self.maxiter
    )
    return solution
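`BiCGStab` handles non-symmetric systems where CG is not applicable. A short sketch using the underlying `jax.scipy.sparse.linalg.bicgstab` call that this class dispatches to:

```python
import jax.numpy as jnp
from jax.scipy.sparse import linalg as jax_sparse

A = jnp.array([[2.0, 1.0],
               [0.0, 3.0]])   # non-symmetric (upper triangular)
b = jnp.array([3.0, 3.0])

x, _ = jax_sparse.bicgstab(A, b, tol=1e-6, maxiter=100)
print(x)  # approximately [1.0, 1.0]
```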

jax_fno.integrate.DirectDense

Bases: Module

Direct solver for dense linear systems.

Dispatches to jax.numpy.linalg.solve. Only suitable for small systems where the Jacobian is provided explicitly.

Source code in src/jax_fno/integrate/linsolvers/direct.py
class DirectDense(nnx.Module):
    """
    Direct solver for dense linear systems.

    Dispatches to `jax.numpy.linalg.solve`.
    Only suitable for small systems where the Jacobian is provided explicitly.
    """

    def __call__(
        self, 
        A: Union[LinearMap, Array], 
        b: Array,
        x0: Optional[Array] = None
    ) -> Array:
        """
        Solve A*x = b.

        Args:
            A: Dense matrix
            b: Right-hand side vector
            x0: Ignored (kept for interface compatibility). Can be None.

        Returns:
            Solution x

        Raises:
            TypeError: If A is a callable (linear operator) instead of a matrix
        """
        if callable(A):
            raise TypeError(
                "DirectDense requires a dense matrix, not a linear operator. "
                "Please provide the Jacobian as a dense matrix (jac_fn), "
                "or use an iterative solver (GMRES, CG, BiCGStab)."
            )

        return jnp.linalg.solve(A, b)

__call__(A, b, x0=None)

Solve A*x = b.

Parameters:

Name Type Description Default
A Union[LinearMap, Array]

Dense matrix

required
b Array

Right-hand side vector

required
x0 Optional[Array]

Ignored (kept for interface compatibility). Can be None.

None

Returns:

Type Description
Array

Solution x

Raises:

Type Description
TypeError

If A is a callable (linear operator) instead of a matrix

Source code in src/jax_fno/integrate/linsolvers/direct.py
def __call__(
    self, 
    A: Union[LinearMap, Array], 
    b: Array,
    x0: Optional[Array] = None
) -> Array:
    """
    Solve A*x = b.

    Args:
        A: Dense matrix
        b: Right-hand side vector
        x0: Ignored (kept for interface compatibility). Can be None.

    Returns:
        Solution x

    Raises:
        TypeError: If A is a callable (linear operator) instead of a matrix
    """
    if callable(A):
        raise TypeError(
            "DirectDense requires a dense matrix, not a linear operator. "
            "Please provide the Jacobian as a dense matrix (jac_fn), "
            "or use an iterative solver (GMRES, CG, BiCGStab)."
        )

    return jnp.linalg.solve(A, b)
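`DirectDense` wraps `jnp.linalg.solve`, an exact LU-based solve. A sketch of the equivalent call, and of the callable case the class rejects:

```python
import jax.numpy as jnp

A = jnp.array([[3.0, 1.0],
               [1.0, 2.0]])
b = jnp.array([9.0, 8.0])

# What DirectDense does under the hood
x = jnp.linalg.solve(A, b)
print(x)  # approximately [2. 3.]

# Passing an operator instead of a matrix is the TypeError case:
matvec = lambda v: A @ v
assert callable(matvec)  # DirectDense raises TypeError for callables
```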

jax_fno.integrate.Spectral

jax_fno.integrate.LinearSolverProtocol

Bases: Protocol

Protocol for linear solvers.

Defines the interface for solving linear systems of the form A*x = b. Any class implementing a __call__() method with this signature can be used as a linear solver in implicit integration schemes.

Source code in src/jax_fno/integrate/linsolvers/protocol.py
@runtime_checkable
class LinearSolverProtocol(Protocol):
    """
    Protocol for linear solvers.

    Defines the interface for solving linear systems of the form A*x = b.
    Any class implementing a __call__() method with this signature can be used
    as a linear solver in implicit integration schemes.
    """

    def __call__(
        self, 
        A: Union[LinearMap, Array], 
        b: Array,
        x0: Optional[Array] = None
    ) -> Array:
        """
        Solve the linear system A*x = b.

        Args:
            A: Dense matrix or linear operator with signature x -> A*x
            b: Right-hand side vector
            x0: Initial guess vector

        Returns:
            Solution vector x such that A*x ≈ b
        """
        ...

__call__(A, b, x0=None)

Solve the linear system A*x = b.

Parameters:

Name Type Description Default
A Union[LinearMap, Array]

Dense matrix or linear operator with signature x -> A*x

required
b Array

Right-hand side vector

required
x0 Optional[Array]

Initial guess vector

None

Returns:

Type Description
Array

Solution vector x such that A*x ≈ b

Source code in src/jax_fno/integrate/linsolvers/protocol.py
def __call__(
    self, 
    A: Union[LinearMap, Array], 
    b: Array,
    x0: Optional[Array] = None
) -> Array:
    """
    Solve the linear system A*x = b.

    Args:
        A: Dense matrix or linear operator with signature x -> A*x
        b: Right-hand side vector
        x0: Initial guess vector

    Returns:
        Solution vector x such that A*x ≈ b
    """
    ...
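As with the root-finder protocol, any object whose `__call__` matches this signature conforms structurally. The sketch below is hypothetical: `LinearSolverLike` is a local mirror of the protocol, and `JacobiSolver` is a toy iterative solver written only to demonstrate conformance:

```python
from typing import Protocol, runtime_checkable
import jax.numpy as jnp

# Hypothetical local mirror of LinearSolverProtocol's signature
@runtime_checkable
class LinearSolverLike(Protocol):
    def __call__(self, A, b, x0=None): ...

class JacobiSolver:
    """Toy Jacobi iteration for diagonally dominant dense A."""
    def __init__(self, maxiter=200):
        self.maxiter = maxiter

    def __call__(self, A, b, x0=None):
        x = jnp.zeros_like(b) if x0 is None else x0
        d = jnp.diag(A)
        for _ in range(self.maxiter):
            # x <- D^{-1} (b - R x), where R is A without its diagonal
            x = (b - (A @ x - d * x)) / d
        return x

solver = JacobiSolver()
print(isinstance(solver, LinearSolverLike))  # True

A = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])
b = jnp.array([1.0, 2.0])
x = solver(A, b)
```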