API reference
norax.models
norax.FNO
FNO(
key: PRNGKeyArray,
channels_in: int,
channels_out: int,
n_modes: tuple[int, ...],
width: int = 64,
depth: int = 4,
activation: Callable = gelu,
dtype: dtype = result_type(float),
lift: Optional[Callable] = None,
project: Optional[Callable] = None,
)
Bases: Module
Fourier neural operator.
The input is expected to have shape (*coords, channels_in).
Reference
Li et al. "Fourier Neural Operator for Parametric Partial Differential Equations" (2020).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKeyArray | PRNG key for parameter initialisation | required |
| channels_in | int | Number of input channels | required |
| channels_out | int | Number of output channels | required |
| n_modes | tuple[int, ...] | Maximum number of Fourier modes to keep along each coordinate axis | required |
| width | int | Hidden channel width | 64 |
| depth | int | Number of Fourier layers | 4 |
| activation | Callable | Non-linear activation function | gelu |
| dtype | dtype | Floating-point dtype of the parameters | result_type(float) |
| lift | Optional[Callable] | Optional custom lifting layer. Must be callable. Default: MLP with depth=2 and width=2*width. | None |
| project | Optional[Callable] | Optional custom projection layer. Must be callable. Default: MLP with depth=2 and width=2*width. | None |
__call__
__call__(
x: Float[Array, "*coords channels_in"],
) -> Float[Array, "*coords channels_out"]
Perform a forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Float[Array, '*coords channels_in'] | Input tensor of shape (*coords, channels_in) | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, '*coords channels_out'] | Output tensor of shape (*coords, channels_out) |
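The shape flow through the model can be traced with a small sketch. It is written in plain NumPy and does not use norax; the stand-in `lift`, `project` and `fourier_layers` callables, and the helper name `fno_forward`, are hypothetical and only illustrate the lift, Fourier layers, project composition described above.

```python
import numpy as np


def fno_forward(x, lift, fourier_layers, project):
    """Compose lift -> Fourier layers -> project, all acting on the last axis."""
    v = lift(x)                    # (*coords, channels_in) -> (*coords, width)
    for layer in fourier_layers:   # each layer keeps the shape (*coords, width)
        v = layer(v)
    return project(v)              # (*coords, width) -> (*coords, channels_out)


# Hypothetical stand-ins, used only to trace shapes (not the norax layers).
rng = np.random.default_rng(0)
channels_in, channels_out, width, depth = 3, 1, 8, 4
w_lift = rng.normal(size=(channels_in, width))
w_proj = rng.normal(size=(width, channels_out))


def lift(x):
    return x @ w_lift


def project(v):
    return v @ w_proj


fourier_layers = [np.tanh] * depth  # placeholder for `depth` Fourier layers

x = rng.normal(size=(32, 32, channels_in))  # a 2D grid with 3 input channels
y = fno_forward(x, lift, fourier_layers, project)
print(y.shape)  # (32, 32, 1)
```

Only the channel axis changes size; the coordinate axes pass through untouched, which is why the input is documented as (*coords, channels_in).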
norax.MLP
MLP(
key: PRNGKeyArray,
input_dim: int,
output_dim: int,
depth: int,
width: int,
activation: Callable = gelu,
dtype: dtype = result_type(float),
)
Bases: Module
Multi-layer perceptron.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKeyArray | PRNG key for parameter initialisation | required |
| input_dim | int | Dimensionality of the input features | required |
| output_dim | int | Dimensionality of the output | required |
| depth | int | Number of layers (including the output layer) | required |
| width | int | Number of units in each hidden layer | required |
| activation | Callable | Non-linear activation applied between hidden layers. Not applied after the final layer. | gelu |
| dtype | dtype | Floating-point dtype of the parameters | result_type(float) |
__call__
__call__(
x: Float[Array, "... input_dim"],
) -> Float[Array, "... output_dim"]
Perform a forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Float[Array, '... input_dim'] | Input array of shape (..., input_dim) | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, '... output_dim'] | Output array of shape (..., output_dim) |
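To make the meaning of `depth` and the placement of the activation concrete, here is a minimal NumPy sketch. It is not the norax implementation: the helper names `init_mlp` and `mlp_forward`, the Glorot-uniform weights and the zero biases are assumptions chosen for illustration.

```python
import numpy as np


def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def init_mlp(rng, input_dim, output_dim, depth, width):
    """Return (weight, bias) pairs; `depth` counts the output layer."""
    dims = [input_dim] + [width] * (depth - 1) + [output_dim]
    params = []
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        limit = np.sqrt(6.0 / (d_in + d_out))  # Glorot-uniform bound (assumed)
        weight = rng.uniform(-limit, limit, size=(d_in, d_out))
        params.append((weight, np.zeros(d_out)))  # zero bias (assumed)
    return params


def mlp_forward(params, x, activation=gelu):
    # Activation between layers, but not after the final layer.
    for weight, bias in params[:-1]:
        x = activation(x @ weight + bias)
    weight, bias = params[-1]
    return x @ weight + bias


rng = np.random.default_rng(0)
params = init_mlp(rng, input_dim=3, output_dim=1, depth=2, width=16)
x = rng.normal(size=(10, 20, 3))   # any leading axes, features last
y = mlp_forward(params, x)
print(len(params), y.shape)        # 2 (10, 20, 1)
```

With `depth=1` the list of layer sizes collapses to a single linear map from `input_dim` to `output_dim`, and `width` is unused.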
norax.layers
norax.layers.Fourier
Fourier(
key: PRNGKeyArray,
channels_in: int,
channels_out: int,
n_modes: tuple[int, ...],
activation: Callable = gelu,
dtype: dtype = result_type(float),
)
Bases: Module
A Fourier layer for a Fourier neural operator.
Reference
Li et al. "Fourier Neural Operator for Parametric Partial Differential Equations" (2020).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKeyArray | PRNG key for parameter initialisation | required |
| channels_in | int | Number of input channels | required |
| channels_out | int | Number of output channels | required |
| n_modes | tuple[int, ...] | Maximum number of Fourier modes to keep along each coordinate axis | required |
| activation | Callable | Non-linear activation function | gelu |
| dtype | dtype | Floating-point dtype of the parameters | result_type(float) |
__call__
__call__(
x: Float[Array, "*coords channels_in"],
) -> Float[Array, "*coords channels_out"]
Apply the Fourier layer.
The layer computes v_{t+1} = sigma(W v_t + F^{-1}[R * F(v_t)]), where W is a trainable pointwise linear transformation, F is the fast Fourier transform, R contains trainable complex weights for the retained Fourier modes, sigma is a non-linear activation function, and * denotes element-wise multiplication.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Float[Array, '*coords channels_in'] | Input tensor of shape (*coords, channels_in) | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, '*coords channels_out'] | Array with shape (*coords, channels_out) |
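The update rule v_{t+1} = sigma(W v_t + F^{-1}[R * F(v_t)]) can be sketched for a single coordinate axis in NumPy. This is an illustration under stated assumptions, not the norax code: the function name `fourier_layer_1d` is hypothetical, the real FFT is used, the first `n_modes` frequencies are the retained ones, and R mixes channels separately for each retained mode, as in the cited paper.

```python
import numpy as np


def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def fourier_layer_1d(v, w, r, n_modes, activation=gelu):
    """v: (n, c_in) real; w: (c_in, c_out) real; r: (n_modes, c_in, c_out) complex."""
    n = v.shape[0]
    v_hat = np.fft.rfft(v, axis=0)                            # F(v): (n//2 + 1, c_in)
    out_hat = np.zeros((v_hat.shape[0], r.shape[-1]), dtype=complex)
    # R * F(v) on the retained modes; all higher modes stay zero.
    out_hat[:n_modes] = np.einsum("ki,kio->ko", v_hat[:n_modes], r)
    spectral = np.fft.irfft(out_hat, n=n, axis=0)             # F^{-1}[...]: (n, c_out)
    return activation(v @ w + spectral)                       # sigma(W v + ...)


rng = np.random.default_rng(0)
n, channels, n_modes = 64, 4, 8
v = rng.normal(size=(n, channels))
w = rng.normal(size=(channels, channels)) / channels
r = (rng.normal(size=(n_modes, channels, channels))
     + 1j * rng.normal(size=(n_modes, channels, channels))) / channels
out = fourier_layer_1d(v, w, r, n_modes)
print(out.shape)  # (64, 4)
```

Setting `r` to zero removes the spectral branch and leaves only the pointwise path sigma(W v_t), which is a convenient sanity check.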
norax.layers.SpectralConv
SpectralConv(
key: PRNGKeyArray,
channels_in: int,
channels_out: int,
n_modes: tuple[int, ...],
init: Callable = complex_glorot,
dtype: dtype = result_type(float),
)
Bases: Module
A spectral convolution layer used in the Fourier neural operator.
Reference
Li et al. "Fourier Neural Operator for Parametric Partial Differential Equations" (2020).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKeyArray | PRNG key for parameter initialisation | required |
| channels_in | int | Number of input channels | required |
| channels_out | int | Number of output channels | required |
| n_modes | tuple[int, ...] | Maximum number of modes to keep per coordinate axis | required |
| init | Callable | Complex weight initialiser used to draw the initial real and imaginary components | complex_glorot |
| dtype | dtype | Floating-point dtype of the parameters | result_type(float) |
__call__
__call__(
x: Float[Array, "*coords channels_in"],
) -> Float[Array, "*coords channels_out"]
Perform a spectral convolution using an FFT.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Float[Array, '*coords channels_in'] | Input array with shape (*coords, channels_in) | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, '*coords channels_out'] | Array with shape (*coords, channels_out) |
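With more than one coordinate axis, "modes to keep per coordinate axis" needs a convention for negative frequencies. The NumPy sketch below shows one common choice for two axes: a real FFT over the last coordinate axis, and two weight blocks covering the low non-negative and low negative frequencies of the first axis. How norax stores and indexes its weights is not stated above, so the function `spectral_conv_2d` and its `w_pos`/`w_neg` arguments are hypothetical.

```python
import numpy as np


def spectral_conv_2d(x, w_pos, w_neg, n_modes):
    """x: (nx, ny, c_in) real; w_pos, w_neg: (m0, m1, c_in, c_out) complex."""
    nx, ny, _ = x.shape
    m0, m1 = n_modes
    x_hat = np.fft.rfftn(x, axes=(0, 1))               # (nx, ny//2 + 1, c_in)
    out_hat = np.zeros(x_hat.shape[:2] + (w_pos.shape[-1],), dtype=complex)
    # Low non-negative frequencies along axis 0.
    out_hat[:m0, :m1] = np.einsum("xyi,xyio->xyo", x_hat[:m0, :m1], w_pos)
    # Low negative frequencies along axis 0 (stored at the end of the axis).
    out_hat[-m0:, :m1] = np.einsum("xyi,xyio->xyo", x_hat[-m0:, :m1], w_neg)
    return np.fft.irfftn(out_hat, s=(nx, ny), axes=(0, 1))


# Two low-frequency waves plus one high-frequency wave on a 32 x 32 grid.
n = 32
gx, gy = np.meshgrid(np.arange(n) / n, np.arange(n) / n, indexing="ij")
low = np.cos(2 * np.pi * (gx + gy)) + np.cos(2 * np.pi * (gy - gx))
high = np.cos(2 * np.pi * 10 * gy)
x = (low + high)[..., None]                            # one input channel

# All-ones weights turn the layer into a pure low-pass filter.
ones = np.ones((4, 4, 1, 1), dtype=complex)
y = spectral_conv_2d(x, ones, ones, n_modes=(4, 4))
print(y.shape)  # (32, 32, 1)
```

In this example the frequency-10 wave lies outside the retained modes, so `y` matches the low-frequency part of the input up to floating-point error.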
norax.layers.Linear
Linear(
key: PRNGKeyArray,
input_dim: int,
output_dim: int,
init: Callable = glorot_uniform(),
dtype: dtype = result_type(float),
)
Bases: Module
A pointwise linear transformation applied along the last axis.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKeyArray | PRNG key for parameter initialisation | required |
| input_dim | int | Number of input channels | required |
| output_dim | int | Number of output channels | required |
| init | Callable | Weight initialiser | glorot_uniform() |
| dtype | dtype | Floating-point dtype of the parameters | result_type(float) |
__call__
__call__(
x: Float[Array, "... input_dim"],
) -> Float[Array, "... output_dim"]
Apply a linear transformation along the last axis.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Float[Array, '... input_dim'] | Array with shape (..., input_dim) | required |
Returns:
| Type | Description |
|---|---|
| Float[Array, '... output_dim'] | Array with shape (..., output_dim) |
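"Pointwise" here means that the same matrix acts on the last axis at every grid point, whatever the leading axes are. A short NumPy sketch of that behaviour follows; the helper `linear`, the random weights and the zero bias are assumptions for illustration (whether the norax layer carries a bias is not stated above).

```python
import numpy as np

rng = np.random.default_rng(0)
input_dim, output_dim = 3, 2
weight = rng.normal(size=(input_dim, output_dim))
bias = np.zeros(output_dim)  # assumed; the reference above does not mention a bias


def linear(x):
    # Contract only the last axis; every leading axis is carried through.
    return np.einsum("...i,io->...o", x, weight) + bias


x = rng.normal(size=(4, 5, input_dim))
y = linear(x)
print(y.shape)  # (4, 5, 2)
```

Evaluating `linear` on a single point such as `x[1, 2]` gives the same values as `y[1, 2]`, since no information is shared between grid points.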
Utilities
norax.data.DataLoader
DataLoader(
dataset: dict[str, Shaped[Array, " n_samples *shape"]],
batch_size: int,
shuffle: bool = True,
seed: int = 0,
)
Bases: Iterator
An iterator that yields batches from a dataset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dataset | dict[str, Shaped[Array, ' n_samples *shape']] | Dictionary mapping string keys to arrays. All arrays must share the same size along their first axis. | required |
| batch_size | int | Number of samples per batch. The final batch of an epoch may be smaller if the dataset size is not divisible by batch_size. | required |
| shuffle | bool | Whether to shuffle the samples at the start of each epoch | True |
| seed | int | Integer seed for the random number generator used to shuffle | 0 |
__next__
__next__() -> dict[str, Shaped[Array, ' batch *shape']]
Return the next batch.
Returns:
| Type | Description |
|---|---|
| dict[str, Shaped[Array, ' batch *shape']] | Dictionary with the same keys as the dataset |
Raises:
| Type | Description |
|---|---|
| StopIteration | When all samples have been yielded |
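The batching behaviour described above (a smaller final batch, a reshuffle at the start of each epoch, StopIteration once every sample has been yielded) can be mimicked with a small stand-in class. `SketchLoader` is hypothetical and written in NumPy; in particular, resetting the epoch inside `__iter__` is an assumption about how such a loader might be reused, not a statement about norax.

```python
import numpy as np


class SketchLoader:
    """Illustrative stand-in for a batch iterator, not the norax class."""

    def __init__(self, dataset, batch_size, shuffle=True, seed=0):
        sizes = {len(values) for values in dataset.values()}
        if len(sizes) != 1:
            raise ValueError("all arrays must share the same first-axis size")
        self.dataset = dataset
        self.n_samples = sizes.pop()
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.rng = np.random.default_rng(seed)

    def __iter__(self):
        # Start a new epoch: reshuffle if requested, then rewind.
        if self.shuffle:
            self.order = self.rng.permutation(self.n_samples)
        else:
            self.order = np.arange(self.n_samples)
        self.position = 0
        return self

    def __next__(self):
        if self.position >= self.n_samples:
            raise StopIteration
        index = self.order[self.position:self.position + self.batch_size]
        self.position += self.batch_size
        return {key: values[index] for key, values in self.dataset.items()}


data = {"x": np.arange(10).reshape(10, 1), "y": np.arange(10)}
batches = list(SketchLoader(data, batch_size=4))
sizes = [len(batch["x"]) for batch in batches]
print(sizes)  # [4, 4, 2]
```

Because every array is indexed with the same permutation, the keys of a batch stay aligned sample by sample.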
norax.initialisers.complex_glorot
complex_glorot(
key: PRNGKeyArray,
shape: tuple[int, ...],
in_axis: int = 1,
out_axis: int = 0,
dtype: dtype = result_type(float),
) -> Complex[Array, "*shape"]
Glorot-scaled complex weight initialisation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| key | PRNGKeyArray | PRNG key | required |
| shape | tuple[int, ...] | Weight tensor shape | required |
| in_axis | int | Axis corresponding to input channels | 1 |
| out_axis | int | Axis corresponding to output channels | 0 |
| dtype | dtype | Real dtype for the components (float32 or float64) | result_type(float) |
Returns:
| Type | Description |
|---|---|
| Complex[Array, '*shape'] | Complex array of the given shape (complex64 or complex128) |
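The relationship between the real component dtype and the complex result dtype can be seen in a short NumPy sketch. The scaling used here is an assumption: the Glorot variance 2 / (fan_in + fan_out) is split evenly between the real and imaginary parts, and the fans are read directly from `in_axis` and `out_axis`. The exact scaling in norax is not given above, and `complex_glorot_sketch` is a hypothetical name.

```python
import numpy as np


def complex_glorot_sketch(rng, shape, in_axis=1, out_axis=0, dtype=np.float32):
    fan_in, fan_out = shape[in_axis], shape[out_axis]
    # Assumption: each component gets half of the Glorot variance
    # 2 / (fan_in + fan_out), so that E|w|^2 = 2 / (fan_in + fan_out).
    std = np.sqrt(1.0 / (fan_in + fan_out))
    # float32 components -> complex64, float64 components -> complex128
    out = np.empty(shape, dtype=np.result_type(dtype, np.complex64))
    out.real = rng.normal(scale=std, size=shape)
    out.imag = rng.normal(scale=std, size=shape)
    return out


rng = np.random.default_rng(0)
w32 = complex_glorot_sketch(rng, (8, 4, 16))
w64 = complex_glorot_sketch(rng, (8, 4, 16), dtype=np.float64)
print(w32.dtype, w64.dtype)  # complex64 complex128
```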