
API reference

norax.models

norax.FNO

FNO(
    key: PRNGKeyArray,
    channels_in: int,
    channels_out: int,
    n_modes: tuple[int, ...],
    width: int = 64,
    depth: int = 4,
    activation: Callable = gelu,
    dtype: dtype = result_type(float),
    lift: Optional[Callable] = None,
    project: Optional[Callable] = None,
)

Bases: Module

Fourier neural operator.

The input is expected to have shape (*coords, channels_in).

Reference

Li et al. "Fourier Neural Operator for Parametric Partial Differential Equations" (2020).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKeyArray | PRNG key for parameter initialisation | required |
| channels_in | int | Number of input channels | required |
| channels_out | int | Number of output channels | required |
| n_modes | tuple[int, ...] | Number of modes to retain per coordinate axis. For non-last axes, n_modes[i] modes are kept at each end of the spectrum (positive and negative frequencies). For the last axis, n_modes[-1] modes are kept from the one-sided RFFT spectrum. | required |
| width | int | Hidden channel width | 64 |
| depth | int | Number of Fourier layers | 4 |
| activation | Callable | Non-linear activation function | gelu |
| dtype | dtype | Floating-point dtype of the parameters | result_type(float) |
| lift | Optional[Callable] | Custom lifting layer (any callable). If None, an MLP with depth=2 and width=2*width is used. | None |
| project | Optional[Callable] | Custom projection layer (any callable). If None, an MLP with depth=2 and width=2*width is used. | None |
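The n_modes rule can be made concrete by listing which FFT indices survive truncation on each axis. The sketch below assumes a full FFT over the non-last axes and a one-sided RFFT over the last axis, as the n_modes description states. The helper name `retained_indices` is hypothetical and not part of norax.

```python
import jax.numpy as jnp


def retained_indices(grid_shape, n_modes):
    """Hypothetical helper: FFT indices kept on each coordinate axis.

    Non-last axes use a full FFT, so n_modes[i] indices are kept at each end
    of the spectrum (lowest positive and lowest negative frequencies).
    The last axis uses a one-sided RFFT, so its first n_modes[-1] bins are kept.
    """
    if len(grid_shape) != len(n_modes):
        raise ValueError("need one n_modes entry per coordinate axis")
    kept = []
    for size, m in zip(grid_shape[:-1], n_modes[:-1]):
        # Positive frequencies 0..m-1 plus negative frequencies -m..-1.
        kept.append(jnp.concatenate([jnp.arange(m), jnp.arange(size - m, size)]))
    # An RFFT of length N has N // 2 + 1 bins; keep the lowest n_modes[-1].
    kept.append(jnp.arange(n_modes[-1]))
    return kept


# 8 x 8 grid: 2 modes at each end of axis 0, 3 RFFT modes on axis 1.
rows, cols = retained_indices((8, 8), (2, 3))
print(rows)  # [0 1 6 7]
print(cols)  # [0 1 2]
```

On every non-last axis, the number of retained indices is therefore 2 * n_modes[i]. On the last axis it is n_modes[-1].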

__call__

__call__(
    x: Float[Array, "*coords channels_in"],
) -> Float[Array, "*coords channels_out"]

Perform a forward pass.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Float[Array, '*coords channels_in'] | Input tensor of shape (*coords, channels_in) | required |

Returns:

| Type | Description |
|---|---|
| Float[Array, '*coords channels_out'] | Output tensor of shape (*coords, channels_out) |
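The sketch below shows the FNO data flow in plain JAX on a 1-D grid. It follows the structure described above:

- a depth-2 lifting MLP with hidden width 2*width,
- depth Fourier layers of the form documented for norax.layers.Fourier,
- a depth-2 projection MLP.

Several details are illustrative assumptions, not norax's implementation:

- the function names (`fno1d_init`, `fno1d`, `mlp2`, `dense`);
- the dict-based parameters and the initialisation;
- applying the activation after every Fourier layer.

Because the sketch has a single coordinate axis, `n_modes` is a plain int here.

```python
import jax
import jax.numpy as jnp


def dense_init(key, n_in, n_out):
    # Glorot-uniform weights and zero bias (an assumed initialisation).
    w = jax.nn.initializers.glorot_uniform()(key, (n_in, n_out))
    return {"w": w, "b": jnp.zeros(n_out)}


def dense(p, x):
    return x @ p["w"] + p["b"]


def mlp2_init(key, n_in, hidden, n_out):
    # Depth-2 MLP: one hidden layer of size `hidden`, then the output layer.
    k1, k2 = jax.random.split(key)
    return [dense_init(k1, n_in, hidden), dense_init(k2, hidden, n_out)]


def mlp2(layers, x):
    return dense(layers[1], jax.nn.gelu(dense(layers[0], x)))


def fno1d_init(key, channels_in, channels_out, n_modes, width=64, depth=4):
    keys = jax.random.split(key, 2 + 2 * depth)
    layers = []
    for i in range(depth):
        k_r, k_w = keys[2 + 2 * i], keys[3 + 2 * i]
        # One complex (width x width) channel-mixing matrix per retained mode.
        re, im = jax.random.normal(k_r, (2, n_modes, width, width)) / width
        layers.append({"R": jax.lax.complex(re, im), "W": dense_init(k_w, width, width)})
    return {
        "lift": mlp2_init(keys[0], channels_in, 2 * width, width),
        "layers": layers,
        "project": mlp2_init(keys[1], width, 2 * width, channels_out),
    }


def fno1d(params, x):
    """x: (n_points, channels_in) -> (n_points, channels_out)."""
    n = x.shape[0]
    v = mlp2(params["lift"], x)  # lift to (n_points, width)
    for layer in params["layers"]:
        m = layer["R"].shape[0]
        v_hat = jnp.fft.rfft(v, axis=0)[:m]  # keep the lowest m RFFT modes
        v_hat = jnp.einsum("ki,kio->ko", v_hat, layer["R"])  # mix channels per mode
        spectral = jnp.fft.irfft(v_hat, n=n, axis=0)  # truncated modes are zero-padded
        v = jax.nn.gelu(dense(layer["W"], v) + spectral)
    return mlp2(params["project"], v)  # project to (n_points, channels_out)


params = fno1d_init(jax.random.PRNGKey(0), channels_in=1, channels_out=1,
                    n_modes=8, width=16, depth=2)
coarse = fno1d(params, jnp.ones((64, 1)))   # shape (64, 1)
fine = fno1d(params, jnp.ones((256, 1)))    # same parameters, shape (256, 1)
```

Each R holds one matrix per retained mode rather than one per grid point. The same parameters therefore apply to any grid fine enough for the RFFT to have at least n_modes bins.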

norax.MLP

MLP(
    key: PRNGKeyArray,
    input_dim: int,
    output_dim: int,
    depth: int,
    width: int,
    activation: Callable = gelu,
    dtype: dtype = result_type(float),
)

Bases: Module

Multi-layer perceptron.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKeyArray | PRNG key for parameter initialisation | required |
| input_dim | int | Dimensionality of the input features | required |
| output_dim | int | Dimensionality of the output | required |
| depth | int | Number of linear layers, including the output layer | required |
| width | int | Width (number of units) of each hidden layer | required |
| activation | Callable | Non-linear activation applied between hidden layers. Not applied after the final layer. | gelu |
| dtype | dtype | Floating-point dtype of the parameters | result_type(float) |

__call__

__call__(
    x: Float[Array, "... input_dim"],
) -> Float[Array, "... output_dim"]

Perform a forward pass.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Float[Array, '... input_dim'] | Input array of shape (..., input_dim) | required |

Returns:

| Type | Description |
|---|---|
| Float[Array, '... output_dim'] | Output array of shape (..., output_dim) |
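The depth convention is easy to get wrong by one. The sketch below makes these assumptions:

- depth counts every linear layer, the output layer included, as the table states;
- weights are drawn Glorot-uniform and biases start at zero;
- the activation is applied only between layers.

`init_mlp` and `apply_mlp` are hypothetical names, not norax functions.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, input_dim, output_dim, depth, width):
    # depth counts every linear layer, including the output layer, so
    # depth=1 is a single linear map and depth=2 has one hidden layer.
    sizes = [input_dim] + [width] * (depth - 1) + [output_dim]
    keys = jax.random.split(key, depth)
    init = jax.nn.initializers.glorot_uniform()
    return [
        {"w": init(k, (n_in, n_out)), "b": jnp.zeros(n_out)}
        for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
    ]


def apply_mlp(layers, x, activation=jax.nn.gelu):
    # x has shape (..., input_dim); all leading axes are treated as batch axes.
    for layer in layers[:-1]:
        x = activation(x @ layer["w"] + layer["b"])
    last = layers[-1]
    return x @ last["w"] + last["b"]  # no activation after the final layer


layers = init_mlp(jax.random.PRNGKey(0), input_dim=3, output_dim=2, depth=3, width=8)
y = apply_mlp(layers, jnp.ones((4, 5, 3)))  # shape (4, 5, 2)
```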

norax.layers

norax.layers.Fourier

Fourier(
    key: PRNGKeyArray,
    channels_in: int,
    channels_out: int,
    n_modes: tuple[int, ...],
    activation: Callable = gelu,
    dtype: dtype = result_type(float),
)

Bases: Module

A Fourier layer for a Fourier neural operator.

Reference

Li et al. "Fourier Neural Operator for Parametric Partial Differential Equations" (2020).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKeyArray | PRNG key for parameter initialisation | required |
| channels_in | int | Number of input channels | required |
| channels_out | int | Number of output channels | required |
| n_modes | tuple[int, ...] | Maximum number of Fourier modes to retain per coordinate axis | required |
| activation | Callable | Non-linear activation function | gelu |
| dtype | dtype | Floating-point dtype of the parameters | result_type(float) |

__call__

__call__(
    x: Float[Array, "*coords channels_in"],
) -> Float[Array, "*coords channels_out"]

Apply the Fourier layer.

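The layer computes v_{t+1} = sigma(W v_t + F^{-1}[R · F(v_t)]), where:

- W is a trainable pointwise linear transformation;
- F is the Fourier transform over the coordinate axes, computed with the FFT;
- R contains trainable complex weights for the retained Fourier modes;
- sigma is a non-linear activation function.

The product R · F(v_t) is taken mode by mode. At each retained mode, R acts as a channels_out × channels_in matrix on the vector of input channels. Modes that are not retained are zero before the inverse transform.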

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Float[Array, '*coords channels_in'] | Input tensor of shape (*coords, channels_in) | required |

Returns:

| Type | Description |
|---|---|
| Float[Array, '*coords channels_out'] | Array with shape (*coords, channels_out) |
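The update v_{t+1} = sigma(W v_t + F^{-1}[R · F(v_t)]) can be checked on a 1-D grid. The sketch assumes R has shape (modes, channels_in, channels_out) and is applied as a per-mode matrix product. With W = 0, an identity activation and R equal to the identity at every retained mode, the spectral term reduces to a low-pass filter, which makes the role of the mode truncation visible. `fourier_layer_1d` is a hypothetical name, not norax's implementation.

```python
import jax
import jax.numpy as jnp


def fourier_layer_1d(v, W, R, activation=jax.nn.gelu):
    """activation(v @ W + F^{-1}[R . F(v)]) on a 1-D grid.

    v: (n, c_in) real; W: (c_in, c_out) real pointwise weights;
    R: (m, c_in, c_out) complex, one channel-mixing matrix per retained mode.
    """
    n, m = v.shape[0], R.shape[0]
    v_hat = jnp.fft.rfft(v, axis=0)[:m]           # keep the lowest m modes
    mixed = jnp.einsum("ki,kio->ko", v_hat, R)    # matrix product at each mode
    spectral = jnp.fft.irfft(mixed, n=n, axis=0)  # dropped modes become zero
    return activation(v @ W + spectral)


x = jnp.arange(16) * 2 * jnp.pi / 16
v = (jnp.sin(x) + jnp.sin(6 * x))[:, None]        # modes 1 and 6, one channel
W = jnp.zeros((1, 1))
identity_R = lambda m: jnp.ones((m, 1, 1), dtype=jnp.complex64)
identity = lambda z: z

full = fourier_layer_1d(v, W, identity_R(9), identity)  # all 16 // 2 + 1 modes: matches v
low = fourier_layer_1d(v, W, identity_R(4), identity)   # modes 0..3 only: matches sin(x)
```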

norax.layers.SpectralConv

SpectralConv(
    key: PRNGKeyArray,
    channels_in: int,
    channels_out: int,
    n_modes: tuple[int, ...],
    init: Callable = complex_glorot,
    dtype: dtype = result_type(float),
)

Bases: Module

A spectral convolution layer used in the Fourier neural operator.

Reference

Li et al. "Fourier Neural Operator for Parametric Partial Differential Equations" (2020).

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKeyArray | PRNG key for parameter initialisation | required |
| channels_in | int | Number of input channels | required |
| channels_out | int | Number of output channels | required |
| n_modes | tuple[int, ...] | Number of modes to retain per axis. For non-last axes, n_modes[i] modes are kept at each end of the spectrum (positive and negative), so the weights have size 2 * n_modes[i] along that axis. For the last axis the one-sided RFFT spectrum is used, so the weights have size n_modes[-1] along it. | required |
| init | Callable | Complex weight initialiser used to draw the initial real and imaginary components | complex_glorot |
| dtype | dtype | Floating-point dtype of the parameters | result_type(float) |

__call__

__call__(
    x: Float[Array, "*coords channels_in"],
) -> Float[Array, "*coords channels_out"]

Perform a spectral convolution using an FFT.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Float[Array, '*coords channels_in'] | Input array with shape (*coords, channels_in) | required |

Returns:

| Type | Description |
|---|---|
| Float[Array, '*coords channels_out'] | Array with shape (*coords, channels_out) |
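The following is a minimal 2-D sketch of the FFT-based spectral convolution using the retention rule above: a full FFT on axis 0 keeps n_modes[0] modes at each end, and an RFFT on the last coordinate axis keeps its first n_modes[1] bins. The weight layout (channels_out, channels_in, 2 * n_modes[0], n_modes[1]) is an assumption, chosen to match complex_glorot's default out_axis=0, in_axis=1. `spectral_conv_2d` is a hypothetical name, not norax's code.

```python
import jax
import jax.numpy as jnp


def spectral_conv_2d(x, weights, n_modes):
    """x: (h, w, c_in) real; weights: (c_out, c_in, 2 * m0, m1) complex.

    Returns a real array of shape (h, w, c_out).
    """
    h, w, _ = x.shape
    m0, m1 = n_modes
    c_out = weights.shape[0]
    # Full FFT over axis 0, one-sided RFFT over the last coordinate axis.
    x_hat = jnp.fft.rfft2(x, axes=(0, 1))  # (h, w // 2 + 1, c_in)
    # Keep m0 modes at each end of axis 0 and the first m1 RFFT bins of axis 1.
    kept = jnp.concatenate([x_hat[:m0, :m1], x_hat[h - m0:, :m1]], axis=0)
    # Mix channels independently at each retained mode.
    mixed = jnp.einsum("abi,oiab->abo", kept, weights)
    # Scatter back into an otherwise zero spectrum and invert.
    out_hat = jnp.zeros((h, w // 2 + 1, c_out), dtype=mixed.dtype)
    out_hat = out_hat.at[:m0, :m1].set(mixed[:m0])
    out_hat = out_hat.at[h - m0:, :m1].set(mixed[m0:])
    return jnp.fft.irfft2(out_hat, s=(h, w), axes=(0, 1))


x = jax.random.normal(jax.random.PRNGKey(0), (16, 12, 3))
weights = jnp.ones((4, 3, 8, 5), dtype=jnp.complex64)  # c_out=4, c_in=3, 2*4, 5
y = spectral_conv_2d(x, weights, n_modes=(4, 5))       # shape (16, 12, 4)
```

If every mode is retained and the weights act as the identity, the layer returns its input unchanged. This makes a convenient sanity check for the index bookkeeping.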

norax.layers.Linear

Linear(
    key: PRNGKeyArray,
    input_dim: int,
    output_dim: int,
    init: Callable = glorot_uniform(),
    dtype: dtype = result_type(float),
)

Bases: Module

A pointwise linear transformation applied along the last axis.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKeyArray | PRNG key for parameter initialisation | required |
| input_dim | int | Number of input channels | required |
| output_dim | int | Number of output channels | required |
| init | Callable | Weight initialiser | glorot_uniform() |
| dtype | dtype | Floating-point dtype of the parameters | result_type(float) |

__call__

__call__(
    x: Float[Array, "... input_dim"],
) -> Float[Array, "... output_dim"]

Apply a linear transformation along the last axis.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Float[Array, '... input_dim'] | Array with shape (..., input_dim) | required |

Returns:

| Type | Description |
|---|---|
| Float[Array, '... output_dim'] | Array with shape (..., output_dim) |
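A pointwise linear map along the last axis is a single contraction over the channel axis, with every leading (coordinate) axis left untouched. The sketch makes the following assumptions:

- the weight has shape (input_dim, output_dim);
- it is drawn with jax.nn.initializers.glorot_uniform(), which the glorot_uniform() default plausibly, but not verifiably, refers to;
- there is no bias, since this reference does not say whether one is used.

`init_linear` and `apply_linear` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def init_linear(key, input_dim, output_dim, init=jax.nn.initializers.glorot_uniform()):
    # Weight of shape (input_dim, output_dim); no bias in this sketch.
    return init(key, (input_dim, output_dim))


def apply_linear(w, x):
    # Contract only the last axis; all leading axes (e.g. grid coordinates) are kept.
    return jnp.einsum("...i,io->...o", x, w)


w = init_linear(jax.random.PRNGKey(0), input_dim=3, output_dim=5)
grid = jnp.ones((8, 8, 3))   # (*coords, input_dim)
out = apply_linear(w, grid)  # shape (8, 8, 5)
```

Because the same matrix is applied at every grid point, the result at one point depends only on the input channels at that point.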

Utilities

norax.initialisers.complex_glorot

complex_glorot(
    key: PRNGKeyArray,
    shape: tuple[int, ...],
    in_axis: int = 1,
    out_axis: int = 0,
    dtype: dtype = result_type(float),
) -> Complex[Array, "*shape"]

Glorot-scaled complex weight initialisation.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKeyArray | PRNG key | required |
| shape | tuple[int, ...] | Weight tensor shape | required |
| in_axis | int | Axis corresponding to input channels | 1 |
| out_axis | int | Axis corresponding to output channels | 0 |
| dtype | dtype | Real dtype for the components (float32 or float64) | result_type(float) |

Returns:

| Type | Description |
|---|---|
| Complex[Array, '*shape'] | Complex array of the given shape (complex64 or complex128) |
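This reference does not give the exact scaling used by complex_glorot. The sketch below is therefore one common way to build a Glorot-scaled complex initialiser, offered as an assumption rather than norax's definition:

- the real and imaginary parts are drawn independently and uniformly;
- each part gets variance 1 / (fan_in + fan_out), so the complex weight has total variance E|w|^2 = 2 / (fan_in + fan_out);
- fans are computed as in jax.nn.initializers.variance_scaling, where axes other than in_axis and out_axis multiply both fans.

`complex_glorot_sketch` is a hypothetical name.

```python
import math

import jax
import jax.numpy as jnp


def complex_glorot_sketch(key, shape, in_axis=1, out_axis=0, dtype=jnp.float32):
    """Hypothetical Glorot-scaled complex initialiser (one possible scheme)."""
    channel_axes = {in_axis % len(shape), out_axis % len(shape)}
    # Remaining axes (e.g. Fourier modes) scale both fans.
    receptive = math.prod(s for i, s in enumerate(shape) if i not in channel_axes)
    fan_in = shape[in_axis] * receptive
    fan_out = shape[out_axis] * receptive
    # Var(U(-a, a)) = a**2 / 3, so each part gets variance 1 / (fan_in + fan_out).
    limit = math.sqrt(3.0 / (fan_in + fan_out))
    k_re, k_im = jax.random.split(key)
    real = jax.random.uniform(k_re, shape, dtype, -limit, limit)
    imag = jax.random.uniform(k_im, shape, dtype, -limit, limit)
    return jax.lax.complex(real, imag)  # complex64 for float32 components


# Spectral-style weight (channels_out, channels_in, 2 * 4, 5): in_axis=1, out_axis=0.
w = complex_glorot_sketch(jax.random.PRNGKey(0), (4, 3, 8, 5))
```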