Learning (jax_fno.learn)

jax_fno.learn provides Fourier neural operator (FNO) models and layers.

Models

jax_fno.learn.FNO1D

Bases: Module

1D Fourier Neural Operator as described by Li et al. (2020).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| key | Array | Key for pseudo-random initialisation | required |
| input_dim | int | Number of input features | required |
| output_dim | int | Number of output features | required |
| width | int | Dimension of Fourier layers | 64 |
| n_modes | int | Number of modes to keep in each Fourier layer | 16 |
| n_layers | int | Number of Fourier layers | 4 |
| projection_hidden | Optional[int] | Dimension of projection layer (default: width) | None |

Input/Output:

Input: (batch, input_dim, n_points)

Output: (batch, output_dim, n_points)

Source code in src/jax_fno/learn/fno1d.py
class FNO1D(nnx.Module):
    """
    1D Fourier Neural Operator as described by Li et al. (2020).

    Args:
        key: Key for pseudo-random initialisation
        input_dim: Number of input features
        output_dim: Number of output features
        width: Dimension of Fourier layers
        n_modes: Number of modes to keep in each Fourier layer
        n_layers: Number of Fourier layers
        projection_hidden: Dimension of projection layer (default: width)

    Input/Output:
        Input: (batch, input_dim, n_points)
        Output: (batch, output_dim, n_points)
    """

    def __init__(
        self,
        key: jax.Array,
        input_dim: int,
        output_dim: int,
        width: int = 64,
        n_modes: int = 16,
        n_layers: int = 4,
        projection_hidden: Optional[int] = None,
    ):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.width = width
        self.n_modes = n_modes
        self.n_layers = n_layers
        self.activation = jax.nn.gelu

        if projection_hidden is None:
            projection_hidden = width

        self.projection_hidden = projection_hidden

        # Lifting FCNN: input_dim -> width
        key, subkey = jax.random.split(key, 2)
        self.lift = Linear1D(subkey, input_dim, width)

        # Fourier layers
        self.fourier_layers = nnx.List([])
        for i in range(n_layers):
            key, subkey = jax.random.split(key)
            layer = FourierLayer1D(
                key=subkey,
                channels_in=width,
                channels_out=width,
                n_modes=n_modes,
            )
            self.fourier_layers.append(layer)

        # Projection FCNN: width -> hidden -> output_dim
        key, subkey1, subkey2 = jax.random.split(key, 3)
        self.proj1 = Linear1D(subkey1, width, projection_hidden)
        self.proj2 = Linear1D(subkey2, projection_hidden, output_dim)

    def __call__(self, x: jax.Array) -> jax.Array:
        """
        Forward pass through the FNO.

        Args:
            x: Input tensor of shape (batch, input_dim, n_points)

        Returns:
            Output tensor of shape (batch, output_dim, n_points)
        """
        # (batch, input_dim, n_points) -> (batch, width, n_points)
        x = self.lift(x)

        # (batch, width, n_points) -> (batch, width, n_points)
        for fourier_layer in self.fourier_layers:
            x = self.activation(fourier_layer(x))

        # (batch, width, n_points) -> (batch, output_dim, n_points)
        x = self.proj1(x)
        x = self.activation(x)
        x = self.proj2(x)

        return x

__call__(x)

Forward pass through the FNO.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input tensor of shape (batch, input_dim, n_points) | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Output tensor of shape (batch, output_dim, n_points) |
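The shape flow of the forward pass (lift, Fourier layers with activation, two-step projection) can be traced with a small standalone sketch. This is not the library's implementation: it uses NumPy in place of jax.numpy so that it runs without jax_fno installed, `pointwise_linear` is a hypothetical stand-in for `Linear1D` (whose source is not shown here), and the initialisation is purely illustrative.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def pointwise_linear(x, w, b):
    # Hypothetical stand-in for Linear1D: mixes channels at each grid point.
    # x: (batch, c_in, n_points), w: (c_out, c_in), b: (c_out,)
    return np.einsum("oi,bin->bon", w, x) + b[None, :, None]

def spectral_conv(x, weights):
    # rfft -> keep lowest modes -> per-mode channel mixing -> zero-pad -> irfft
    n = x.shape[-1]
    Fx = np.fft.rfft(x, n=n, axis=2, norm="ortho")
    k_max = min(weights.shape[2], Fx.shape[2])
    out = np.einsum("oim,bim->bom", weights[:, :, :k_max], Fx[:, :, :k_max])
    out = np.pad(out, ((0, 0), (0, 0), (0, n // 2 + 1 - k_max)))
    return np.fft.irfft(out, n=n, axis=2, norm="ortho")

def init_linear(rng, c_in, c_out):
    # Illustrative initialisation only (not the library's scheme)
    return rng.normal(size=(c_out, c_in)) / np.sqrt(c_in), np.zeros(c_out)

def init_params(rng, input_dim, output_dim, width, n_modes, n_layers):
    layers = []
    for _ in range(n_layers):
        shape = (width, width, n_modes)
        spec = (rng.normal(size=shape) + 1j * rng.normal(size=shape)) / width
        layers.append((spec, init_linear(rng, width, width)))
    return {
        "lift": init_linear(rng, input_dim, width),
        "fourier_layers": layers,
        "proj1": init_linear(rng, width, width),
        "proj2": init_linear(rng, width, output_dim),
    }

def fno1d_forward(x, params):
    x = pointwise_linear(x, *params["lift"])          # -> (batch, width, n_points)
    for spec, (w, b) in params["fourier_layers"]:
        # Spectral branch plus linear branch, then activation
        x = gelu(spectral_conv(x, spec) + pointwise_linear(x, w, b))
    x = gelu(pointwise_linear(x, *params["proj1"]))   # -> (batch, hidden, n_points)
    return pointwise_linear(x, *params["proj2"])      # -> (batch, output_dim, n_points)

rng = np.random.default_rng(0)
params = init_params(rng, input_dim=2, output_dim=1, width=8, n_modes=4, n_layers=2)
x = rng.normal(size=(3, 2, 32))
y = fno1d_forward(x, params)
print(y.shape)  # (3, 1, 32)
```

Only the channel axis changes between stages; the grid axis `n_points` is carried through untouched, so the same parameters in this sketch accept inputs with a different number of grid points.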


Layers

jax_fno.learn.FourierLayer1D

Bases: Module

A Fourier layer for a 1D Fourier Neural Operator as described by Li et al. "Fourier Neural Operator for Parametric Partial Differential Equations" (2020).

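The layer computes

$$ x_{\mathrm{out}} = W * x_{\mathrm{in}} + F^{-1}[R * F(x_{\mathrm{in}})], $$

where \(W\) is a trainable linear transformation, \(F\) is a fast Fourier transform, and \(R\) is a tensor of trainable complex weights for the Fourier modes. Here \(*\) denotes applying the weights: in the spectral term, each retained mode of \(F(x_{\mathrm{in}})\) is multiplied by its own slice of \(R\), which mixes the input channels into the output channels (the einsum "oim,bim->bom" in SpectralConv1D).

Fourier modes with index k ≥ k_max, where k_max = min(n_modes, n // 2 + 1) for n spatial points, are filtered out during the convolution \(F^{-1}[R * F(x_{\mathrm{in}})]\).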
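The effect of the mode filter on its own can be seen in a small sketch, assuming NumPy in place of jax.numpy and a hypothetical helper `truncate_modes` that applies only the truncation step (equivalent to setting every retained weight to 1). With `n_modes = 16`, mode indices 0 to 15 survive and index 16 and above are zeroed.

```python
import numpy as np

def truncate_modes(x, n_modes):
    """Keep the lowest `n_modes` real-FFT modes along the last axis, zero the rest."""
    n = x.shape[-1]
    Fx = np.fft.rfft(x, n=n, axis=-1, norm="ortho")
    k_max = min(n_modes, Fx.shape[-1])
    Fx_kept = Fx[..., :k_max]
    # Zero-pad the discarded modes so irfft sees n//2 + 1 coefficients again
    pad = [(0, 0)] * (x.ndim - 1) + [(0, Fx.shape[-1] - k_max)]
    return np.fft.irfft(np.pad(Fx_kept, pad), n=n, axis=-1, norm="ortho")

n = 64
t = np.arange(n) / n
low = np.sin(2 * np.pi * 3 * t)     # mode index 3: kept when n_modes = 16
high = np.sin(2 * np.pi * 20 * t)   # mode index 20: removed when n_modes = 16
x = (low + high)[None, None, :]     # shape (batch=1, channels=1, n_points=64)

y = truncate_modes(x, n_modes=16)

# The high-frequency component is gone; the low-frequency one is unchanged
print(np.allclose(y[0, 0], low))  # True
```

In this sketch the truncation acts as a sharp low-pass filter, which is why `n_modes` bounds the finest spatial scale the spectral branch can represent; finer detail can only pass through the linear branch.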

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| key | Array | Key for pseudo-random initialisation | key(0) |
| channels_in | int | Number of channels in input | 64 |
| channels_out | int | Number of channels in output | 64 |
| n_modes | int | Maximum number of Fourier modes | 16 |

Source code in src/jax_fno/learn/fno1d.py
class FourierLayer1D(nnx.Module):
    """
    A Fourier layer for a 1D Fourier Neural Operator as described by
    Li et al. "Fourier Neural Operator for Parametric Partial 
    Differential Equations" (2020).

    The layer computes
    $$ 
    x_{\\mathrm{out}} = W * x_{\\mathrm{in}} + F^{-1}[R * F(x_{\\mathrm{in}})], 
    $$
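    where $W$ is a trainable linear transformation, $F$ is a fast Fourier
    transform, and $R$ is a tensor of trainable complex weights for Fourier
    modes. Here $*$ denotes applying the weights: each retained mode of
    $F(x_{\\mathrm{in}})$ is multiplied by its own slice of $R$, which mixes
    input channels into output channels.

    Fourier modes with index k >= k_max are filtered out during the
    convolution $F^{-1}[R * F(x_{\\mathrm{in}})]$.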

    Args:
        key: Key for pseudo-random initialisation
        channels_in: Number of channels in input
        channels_out: Number of channels in output
        n_modes: Maximum number of Fourier modes
    """

    def __init__(
        self,
        key: jax.Array = jax.random.key(0),
        channels_in: int = 64,
        channels_out: int = 64,
        n_modes: int = 16,
    ):
        self.channels_in = channels_in
        self.channels_out = channels_out
        self.n_modes = n_modes

        key1, key2 = jax.random.split(key)
        self.spectral = SpectralConv1D(
            key1, channels_in, channels_out, n_modes
        )
        self.linear = Linear1D(key2, channels_in, channels_out)

    def __call__(self, x: jax.Array) -> jax.Array:
        # Spectral branch: F^{-1}[R * F(x)]
        x1 = self.spectral(x)
        # Linear branch: W * x
        x2 = self.linear(x)
        return x1 + x2

jax_fno.learn.SpectralConv1D

Bases: Module

A spectral convolution layer used in the 1D Fourier Neural Operator as described by Li et al. "Fourier Neural Operator for Parametric Partial Differential Equations" (2020).
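The weight multiplication in this layer is the contraction `"oim,bim->bom"`: for every mode index `m`, the input channels are mixed into output channels by that mode's own complex matrix. The following sketch, which uses NumPy and random hypothetical stand-ins for the weights and the truncated spectrum, checks that the einsum agrees with an explicit matrix product per mode.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, c_in, c_out, n_modes = 2, 3, 4, 5

# Hypothetical stand-ins for the complex weights R and the truncated spectrum F(x)
R = rng.normal(size=(c_out, c_in, n_modes)) + 1j * rng.normal(size=(c_out, c_in, n_modes))
Fx = rng.normal(size=(batch, c_in, n_modes)) + 1j * rng.normal(size=(batch, c_in, n_modes))

# Contraction as written in the layer: sum over input channels, separately per mode
out_einsum = np.einsum("oim,bim->bom", R, Fx)

# The same result as one (c_out x c_in) matrix product for each mode
out_loop = np.empty((batch, c_out, n_modes), dtype=complex)
for m in range(n_modes):
    # (batch, c_in) @ (c_in, c_out) -> (batch, c_out); plain transpose, no conjugation
    out_loop[:, :, m] = Fx[:, :, m] @ R[:, :, m].T

print(np.allclose(out_einsum, out_loop))  # True
```

Modes never interact with each other in this step, so the weight count grows linearly with `n_modes` and does not depend on the number of grid points.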

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| key | Array | Key for pseudo-random initialisation | required |
| channels_in | int | Number of channels in input | required |
| channels_out | int | Number of channels in output | required |
| n_modes | int | Maximum number of Fourier modes | required |

Source code in src/jax_fno/learn/fno1d.py
class SpectralConv1D(nnx.Module):
    """
    A spectral convolution layer used in the 1D Fourier Neural Operator as described by
    Li et al. "Fourier Neural Operator for Parametric Partial Differential Equations" (2020).

    Args:
        key: Key for pseudo-random initialisation
        channels_in: Number of channels in input
        channels_out: Number of channels in output
        n_modes: Maximum number of Fourier modes
    """

    def __init__(
        self, key: jax.Array, channels_in: int, channels_out: int, n_modes: int
    ):
        self.channels_in = channels_in
        self.channels_out = channels_out
        self.n_modes = n_modes

        # Initialise complex weights
        key1, key2 = jax.random.split(key)
        initializer = jax.nn.initializers.glorot_uniform(in_axis=1, out_axis=0)
        real_part = initializer(
            key1, (self.channels_out, self.channels_in, self.n_modes)
        )
        imag_part = initializer(
            key2, (self.channels_out, self.channels_in, self.n_modes)
        )
        self.weights = nnx.Param((real_part + 1j * imag_part))

    def __call__(self, x: jax.Array) -> jax.Array:
        """
        Perform 1D spectral convolution using FFT.

        Args:
            x: Input tensor of shape (batch_size, channels_in, spatial_points)

        Returns:
            Output tensor of shape (batch_size, channels_out, spatial_points)
        """
        n = x.shape[-1]

        # Transform to spectral space
        Fx = jnp.fft.rfft(
            x, n=n, axis=2, norm="ortho"
        )  # Shape: (batch, channels_in, n//2+1)

        # Filter out high-frequency modes
        k_max = min(self.n_modes, Fx.shape[2])
        Fx_filtered = Fx[:, :, :k_max]  # Shape: (batch, channels_in, k_max)

        # Multiply with weights
        # Einsum indices:
        #   'o': output channel
        #   'i': input channel
        #   'm': mode index
        #   'b': batch index
        # Slice the weights to k_max as well, so the mode axes still match
        # when the input has fewer than n_modes modes (n_modes > n//2+1)
        RFv_filtered = jnp.einsum(
            "oim,bim->bom", self.weights[:, :, :k_max], Fx_filtered
        )

        # Pad result to get correct shape (batch, channels_out, n//2+1)
        padding = (n // 2 + 1) - k_max
        RFv = jnp.pad(RFv_filtered, ((0, 0), (0, 0), (0, padding)))

        # Transform back to physical space
        out = jnp.fft.irfft(RFv, n=n, axis=2, norm="ortho")

        return out

__call__(x)

Perform 1D spectral convolution using FFT.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Array | Input tensor of shape (batch_size, channels_in, spatial_points) | required |

Returns:

| Type | Description |
| --- | --- |
| Array | Output tensor of shape (batch_size, channels_out, spatial_points) |
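Because the weights are indexed by mode rather than by grid point, the same weights can be applied to inputs sampled at different resolutions. The sketch below illustrates this under stated assumptions: NumPy replaces jax.numpy, `spectral_conv` is a standalone re-implementation of the steps in the source rather than the library class, and the input is band-limited with all frequencies below `n_modes`.

```python
import numpy as np

def spectral_conv(x, weights):
    """NumPy sketch of a truncated spectral convolution.

    x: real array of shape (batch, channels_in, n_points)
    weights: complex array of shape (channels_out, channels_in, n_modes)
    """
    n = x.shape[-1]
    Fx = np.fft.rfft(x, n=n, axis=2, norm="ortho")
    k_max = min(weights.shape[2], Fx.shape[2])
    out = np.einsum("oim,bim->bom", weights[:, :, :k_max], Fx[:, :, :k_max])
    out = np.pad(out, ((0, 0), (0, 0), (0, n // 2 + 1 - k_max)))
    return np.fft.irfft(out, n=n, axis=2, norm="ortho")

def sample(n):
    # The same band-limited function on a periodic grid of n points
    t = np.arange(n) / n
    signal = np.sin(2 * np.pi * 2 * t) + 0.5 * np.cos(2 * np.pi * 5 * t)
    return signal[None, None, :]  # (batch=1, channels_in=1, n)

rng = np.random.default_rng(0)
shape = (3, 1, 8)  # (channels_out, channels_in, n_modes)
weights = rng.normal(size=shape) + 1j * rng.normal(size=shape)

coarse = spectral_conv(sample(64), weights)   # (1, 3, 64)
fine = spectral_conv(sample(128), weights)    # (1, 3, 128)

# On the grid points the two resolutions share, the outputs agree
print(np.allclose(fine[:, :, ::2], coarse))  # True
```

The agreement relies on `norm="ortho"` being used in both the forward and inverse transforms, so the grid-size factors cancel. For inputs with content at or above `n_modes`, the truncation discards that content at either resolution.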
