API reference

norax.models

norax.FNO

FNO(
    key: PRNGKeyArray,
    channels_in: int,
    channels_out: int,
    n_modes: tuple[int, ...],
    width: int = 64,
    depth: int = 4,
    activation: Callable = gelu,
    dtype: dtype = result_type(float),
    lift: Optional[Callable] = None,
    project: Optional[Callable] = None,
)

Bases: Module

Fourier neural operator.

The input is expected to have shape (*coords, channels_in).

Reference

Li et al. "Fourier Neural Operator for Parametric Partial Differential Equations" (2020).

Parameters:

Name Type Description Default
key PRNGKeyArray

PRNG key for parameter initialisation

required
channels_in int

Number of input channels

required
channels_out int

Number of output channels

required
n_modes tuple[int, ...]

Maximum number of Fourier modes to retain along each coordinate axis

required
width int

Hidden channel width

64
depth int

Number of Fourier layers

4
activation Callable

Non-linear activation function

gelu
dtype dtype

Floating-point dtype of the parameters

result_type(float)
lift Optional[Callable]

Optional custom lifting layer (any callable). If None, an MLP with depth=2 and width=2*width is used.

None
project Optional[Callable]

Optional custom projection layer (any callable). If None, an MLP with depth=2 and width=2*width is used.

None

__call__

__call__(
    x: Float[Array, "*coords channels_in"],
) -> Float[Array, "*coords channels_out"]

Perform a forward pass.

Parameters:

Name Type Description Default
x Float[Array, '*coords channels_in']

Input tensor of shape (*coords, channels_in)

required

Returns:

Type Description
Float[Array, '*coords channels_out']

Output tensor of shape (*coords, channels_out)
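
The call signature above takes a single sample of shape (*coords, channels_in), with no leading batch axis. A minimal sketch of one way to process a batch is to map the call over the leading axis with jax.vmap. The function stand_in_model below is a hypothetical placeholder with the same input and output layout, not part of norax; a constructed FNO would take its place.

```python
import jax
import jax.numpy as jnp


def stand_in_model(x):
    """Hypothetical placeholder with the documented call layout.

    Maps (*coords, channels_in) -> (*coords, channels_out) with channels_out = 1.
    """
    return jnp.mean(x, axis=-1, keepdims=True)


# A batch of 8 samples on a 32 x 32 grid with 3 input channels.
batch = jnp.ones((8, 32, 32, 3))

# vmap applies the single-sample function to every entry of the leading axis.
batched_out = jax.vmap(stand_in_model)(batch)
print(batched_out.shape)  # (8, 32, 32, 1)
```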

norax.MLP

MLP(
    key: PRNGKeyArray,
    input_dim: int,
    output_dim: int,
    depth: int,
    width: int,
    activation: Callable = gelu,
    dtype: dtype = result_type(float),
)

Bases: Module

Multi-layer perceptron.

Parameters:

Name Type Description Default
key PRNGKeyArray

PRNG key for parameter initialisation

required
input_dim int

Dimensionality of the input features

required
output_dim int

Dimensionality of the output

required
depth int

Number of layers (including the output layer)

required
width int

Number of units in each hidden layer

required
activation Callable

Non-linear activation applied between hidden layers. Not applied after the final layer.

gelu
dtype dtype

Floating-point dtype of the parameters

result_type(float)

__call__

__call__(
    x: Float[Array, "... input_dim"],
) -> Float[Array, "... output_dim"]

Perform a forward pass.

Parameters:

Name Type Description Default
x Float[Array, '... input_dim']

Input array of shape (..., input_dim).

required

Returns:

Type Description
Float[Array, '... output_dim']

Output array of shape (..., output_dim).
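
The parameter descriptions above (depth counts the output layer, and the activation is skipped after the final layer) can be illustrated with a small stand-alone sketch in plain JAX. This is not the norax implementation: the helper names init_mlp and mlp_apply, the bias terms, and the simple scaled-normal initialisation are assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, input_dim, output_dim, depth, width):
    """Create `depth` (weight, bias) pairs; the last pair is the output layer."""
    dims = [input_dim] + [width] * (depth - 1) + [output_dim]
    keys = jax.random.split(key, depth)
    params = []
    for k, d_in, d_out in zip(keys, dims[:-1], dims[1:]):
        w = jax.random.normal(k, (d_in, d_out)) / jnp.sqrt(d_in)
        params.append((w, jnp.zeros((d_out,))))  # bias is an assumption
    return params


def mlp_apply(params, x, activation=jax.nn.gelu):
    """Apply the layers along the last axis; no activation after the final layer."""
    for w, b in params[:-1]:
        x = activation(x @ w + b)
    w, b = params[-1]
    return x @ w + b


params = init_mlp(jax.random.PRNGKey(0), input_dim=3, output_dim=2, depth=3, width=16)
x = jnp.ones((10, 10, 3))  # any number of leading axes
y = mlp_apply(params, x)
print(len(params), y.shape)  # 3 (10, 10, 2)
```

With depth=1 the loop over hidden layers is empty, so the sketch reduces to a single affine map.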

norax.layers

norax.layers.Fourier

Fourier(
    key: PRNGKeyArray,
    channels_in: int,
    channels_out: int,
    n_modes: tuple[int, ...],
    activation: Callable = gelu,
    dtype: dtype = result_type(float),
)

Bases: Module

A Fourier layer for a Fourier neural operator.

Reference

Li et al. "Fourier Neural Operator for Parametric Partial Differential Equations" (2020).

Parameters:

Name Type Description Default
key PRNGKeyArray

PRNG key for parameter initialisation

required
channels_in int

Number of input channels

required
channels_out int

Number of output channels

required
n_modes tuple[int, ...]

Maximum number of Fourier modes to retain along each coordinate axis

required
activation Callable

Non-linear activation function

gelu
dtype dtype

Floating-point dtype of the parameters

result_type(float)

__call__

__call__(
    x: Float[Array, "*coords channels_in"],
) -> Float[Array, "*coords channels_out"]

Apply the Fourier layer.

The layer computes v_{t+1} = sigma(W v_t + F^{-1}[R * F(v_t)]), where W is a trainable pointwise linear transformation, F is the Fourier transform (evaluated with an FFT) and F^{-1} is its inverse, R holds the trainable complex weights for the retained Fourier modes, sigma is a non-linear activation function, and * denotes element-wise multiplication.

Parameters:

Name Type Description Default
x Float[Array, '*coords channels_in']

Input tensor of shape (*coords, channels_in)

required

Returns:

Type Description
Float[Array, '*coords channels_out']

Array with shape (*coords, channels_out)
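
The update rule above can be sketched for a single coordinate axis in plain JAX. This is an illustration under stated assumptions, not the norax implementation: the function name fourier_layer_1d and the weight layouts are hypothetical, and R is assumed to mix input channels separately for each retained mode, as in the referenced paper.

```python
import jax
import jax.numpy as jnp


def fourier_layer_1d(v, W, R, activation=jax.nn.gelu):
    """v: (n, c_in) real, W: (c_in, c_out) real, R: (modes, c_in, c_out) complex.

    Requires modes <= n // 2 + 1.
    """
    n = v.shape[0]
    modes, _, c_out = R.shape
    v_hat = jnp.fft.rfft(v, axis=0)                 # F(v): (n // 2 + 1, c_in)
    mixed = jnp.einsum("ki,kio->ko", v_hat[:modes], R)  # weights on retained modes
    out_hat = jnp.zeros((v_hat.shape[0], c_out), dtype=mixed.dtype)
    out_hat = out_hat.at[:modes].set(mixed)         # discarded modes stay zero
    spectral = jnp.fft.irfft(out_hat, n=n, axis=0)  # F^{-1}[...]: (n, c_out)
    return activation(v @ W + spectral)             # sigma(W v + F^{-1}[R * F(v)])


key_w, key_re, key_im, key_v = jax.random.split(jax.random.PRNGKey(0), 4)
n, c_in, c_out, modes = 64, 3, 5, 8
W = jax.random.normal(key_w, (c_in, c_out))
R = (jax.random.normal(key_re, (modes, c_in, c_out))
     + 1j * jax.random.normal(key_im, (modes, c_in, c_out)))
v = jax.random.normal(key_v, (n, c_in))

out = fourier_layer_1d(v, W, R)
print(out.shape)  # (64, 5)
```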

norax.layers.SpectralConv

SpectralConv(
    key: PRNGKeyArray,
    channels_in: int,
    channels_out: int,
    n_modes: tuple[int, ...],
    init: Callable = complex_glorot,
    dtype: dtype = result_type(float),
)

Bases: Module

A spectral convolution layer used in the Fourier neural operator.

Reference

Li et al. "Fourier Neural Operator for Parametric Partial Differential Equations" (2020).

Parameters:

Name Type Description Default
key PRNGKeyArray

PRNG key for parameter initialisation

required
channels_in int

Number of input channels

required
channels_out int

Number of output channels

required
n_modes tuple[int, ...]

Maximum number of modes to keep per coordinate axis

required
init Callable

Complex weight initialiser used to draw the initial real and imaginary components

complex_glorot
dtype dtype

Floating-point dtype of the parameters

result_type(float)

__call__

__call__(
    x: Float[Array, "*coords channels_in"],
) -> Float[Array, "*coords channels_out"]

Perform a spectral convolution using an FFT.

Parameters:

Name Type Description Default
x Float[Array, '*coords channels_in']

Input array with shape (*coords, channels_in)

required

Returns:

Type Description
Float[Array, '*coords channels_out']

Array with shape (*coords, channels_out)
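
For two coordinate axes, keeping the lowest n_modes along each axis means selecting two corners of the real FFT output, because the low negative frequencies of the first axis sit at the end of that axis. The sketch below shows one common way to do this in plain JAX. It is an assumption about the general technique, not a description of how norax stores or applies its weights; spectral_conv_2d, complex_normal, R_pos and R_neg are hypothetical names.

```python
import jax
import jax.numpy as jnp


def spectral_conv_2d(x, R_pos, R_neg):
    """x: (h, w, c_in) real; R_pos, R_neg: (m0, m1, c_in, c_out) complex.

    Requires 2 * m0 <= h and m1 <= w // 2 + 1.
    """
    h, w, _ = x.shape
    m0, m1, _, c_out = R_pos.shape
    x_hat = jnp.fft.rfft2(jnp.moveaxis(x, -1, 0))  # (c_in, h, w // 2 + 1)
    out_hat = jnp.zeros((c_out, h, w // 2 + 1), dtype=x_hat.dtype)
    # Low non-negative frequencies along the first coordinate axis.
    out_hat = out_hat.at[:, :m0, :m1].set(
        jnp.einsum("iab,abio->oab", x_hat[:, :m0, :m1], R_pos))
    # Low negative frequencies along the first coordinate axis.
    out_hat = out_hat.at[:, -m0:, :m1].set(
        jnp.einsum("iab,abio->oab", x_hat[:, -m0:, :m1], R_neg))
    out = jnp.fft.irfft2(out_hat, s=(h, w))        # (c_out, h, w)
    return jnp.moveaxis(out, 0, -1)                # (h, w, c_out)


def complex_normal(key, shape):
    """Hypothetical helper: independent normal real and imaginary parts."""
    k_re, k_im = jax.random.split(key)
    return jax.random.normal(k_re, shape) + 1j * jax.random.normal(k_im, shape)


k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
R_pos = complex_normal(k1, (4, 4, 2, 3))
R_neg = complex_normal(k2, (4, 4, 2, 3))

y = spectral_conv_2d(jax.random.normal(k3, (16, 16, 2)), R_pos, R_neg)
y_fine = spectral_conv_2d(jax.random.normal(k4, (32, 32, 2)), R_pos, R_neg)
print(y.shape, y_fine.shape)  # (16, 16, 3) (32, 32, 3)
```

In this sketch the weight shapes depend only on the number of retained modes and the channel counts, so the same weights apply to both grid sizes.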

norax.layers.Linear

Linear(
    key: PRNGKeyArray,
    input_dim: int,
    output_dim: int,
    init: Callable = glorot_uniform(),
    dtype: dtype = result_type(float),
)

Bases: Module

A pointwise linear transformation applied along the last axis.

Parameters:

Name Type Description Default
key PRNGKeyArray

PRNG key for parameter initialisation

required
input_dim int

Number of input channels

required
output_dim int

Number of output channels

required
init Callable

Weight initialiser

glorot_uniform()
dtype dtype

Floating-point dtype of the parameters

result_type(float)

__call__

__call__(
    x: Float[Array, "... input_dim"],
) -> Float[Array, "... output_dim"]

Apply a linear transformation along the last axis.

Parameters:

Name Type Description Default
x Float[Array, '... input_dim']

Array with shape (..., input_dim)

required

Returns:

Type Description
Float[Array, '... output_dim']

Array with shape (..., output_dim)
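
"Pointwise" here means that every position along the leading axes is transformed by the same weights, independently of its neighbours. A minimal sketch in plain JAX, assuming a weight matrix of shape (input_dim, output_dim) and no bias term (whether the norax layer has a bias is not stated above); the function name linear is hypothetical.

```python
import jax
import jax.numpy as jnp


def linear(x, W):
    """Contract the last axis of x with W; leading axes pass through unchanged."""
    return jnp.einsum("...i,io->...o", x, W)


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
input_dim, output_dim = 4, 6

# jax.nn.initializers.glorot_uniform() returns an initialiser function.
W = jax.nn.initializers.glorot_uniform()(key_w, (input_dim, output_dim))
x = jax.random.normal(key_x, (8, 8, input_dim))

y = linear(x, W)
y_single_point = x[2, 5] @ W  # the same transformation applied to one grid point
print(y.shape, y_single_point.shape)  # (8, 8, 6) (6,)
```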

Utilities

norax.data.DataLoader

DataLoader(
    dataset: dict[str, Shaped[Array, " n_samples *shape"]],
    batch_size: int,
    shuffle: bool = True,
    seed: int = 0,
)

Bases: Iterator

An iterator that yields batches from a dataset.

Parameters:

Name Type Description Default
dataset dict[str, Shaped[Array, ' n_samples *shape']]

Dictionary mapping string keys to arrays. All arrays must share the same size along their first axis.

required
batch_size int

Number of samples per batch. The final batch of an epoch may be smaller if the number of samples is not divisible by batch_size.

required
shuffle bool

Whether to shuffle the samples at the start of each epoch

True
seed int

Integer seed for the random number generator used to shuffle

0

reset

reset() -> None

Reset the iterator to the beginning of the dataset.

__next__

__next__() -> dict[str, Shaped[Array, ' batch *shape']]

Return the next batch.

Returns:

Type Description
dict[str, Shaped[Array, ' batch *shape']]

Dictionary with the same keys as the dataset, each mapped to the corresponding batch of samples

Raises:

Type Description
StopIteration

When all samples have been yielded

__len__

__len__() -> int

Return the number of batches.
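
The batching behaviour described above (one shared sample order for all keys, a possibly smaller final batch, and a batch count of ceil(n_samples / batch_size)) can be sketched as a plain generator covering a single epoch. This is an illustration, not the norax implementation; iterate_batches is a hypothetical name, and the use of jax.random.permutation for shuffling is an assumption.

```python
import math

import jax
import jax.numpy as jnp


def iterate_batches(dataset, batch_size, shuffle=True, seed=0):
    """Yield dictionaries of batches for one pass over the dataset."""
    n_samples = len(next(iter(dataset.values())))
    order = jnp.arange(n_samples)
    if shuffle:
        order = jax.random.permutation(jax.random.PRNGKey(seed), n_samples)
    for start in range(0, n_samples, batch_size):
        idx = order[start:start + batch_size]
        # The same indices are used for every key, so samples stay aligned.
        yield {name: array[idx] for name, array in dataset.items()}


dataset = {"x": jnp.arange(10).reshape(10, 1), "y": jnp.arange(10) * 2}
batches = list(iterate_batches(dataset, batch_size=4))

print(len(batches), math.ceil(10 / 4))   # 3 3
print([len(b["x"]) for b in batches])    # [4, 4, 2]
```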

norax.initialisers.complex_glorot

complex_glorot(
    key: PRNGKeyArray,
    shape: tuple[int, ...],
    in_axis: int = 1,
    out_axis: int = 0,
    dtype: dtype = result_type(float),
) -> Complex[Array, "*shape"]

Glorot-scaled complex weight initialisation.

Parameters:

Name Type Description Default
key PRNGKeyArray

PRNG key

required
shape tuple[int, ...]

Weight tensor shape

required
in_axis int

Axis corresponding to input channels

1
out_axis int

Axis corresponding to output channels

0
dtype dtype

Real dtype for the components (float32 or float64)

result_type(float)

Returns:

Type Description
Complex[Array, '*shape']

Complex array of the given shape (complex64 or complex128)
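
As an illustration of what a Glorot-scaled complex initialiser can look like, the sketch below draws the real and imaginary components independently so that the total variance of each complex weight is 2 / (fan_in + fan_out). The choice of a normal distribution, the fan computation from in_axis and out_axis alone, and the name complex_glorot_sketch are assumptions; the actual norax function may differ in all of these.

```python
import math

import jax
import jax.numpy as jnp


def complex_glorot_sketch(key, shape, in_axis=1, out_axis=0, dtype=jnp.float32):
    """Draw complex weights with E[|w|^2] = 2 / (fan_in + fan_out)."""
    fan_in, fan_out = shape[in_axis], shape[out_axis]
    # Each component carries half of the Glorot variance.
    std = math.sqrt(1.0 / (fan_in + fan_out))
    k_re, k_im = jax.random.split(key)
    real = std * jax.random.normal(k_re, shape, dtype)
    imag = std * jax.random.normal(k_im, shape, dtype)
    return jax.lax.complex(real, imag)  # float32 components give complex64


w = complex_glorot_sketch(jax.random.PRNGKey(0), (64, 32, 12))
empirical = float(jnp.mean(jnp.abs(w) ** 2))
target = 2.0 / (32 + 64)
print(w.shape, w.dtype)  # (64, 32, 12) complex64
print(empirical, target)  # the two values should be close
```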