
norax

Neural operators in JAX.

norax provides implementations of neural operators built on top of JAX and Equinox.

Currently, only Fourier neural operators (FNOs; Li et al., 2020) are provided.
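To illustrate what a Fourier layer computes, here is a minimal sketch of a 1D spectral convolution in plain JAX, following the description in Li et al. (2020). The layer transforms the input to Fourier space, keeps the lowest `n_modes` frequencies, mixes channels with a learned complex weight per mode, and transforms back. This is not norax's implementation. The helper name `spectral_conv_1d`, the channel-last `(grid, channels)` layout and the weight initialization are assumptions, chosen to match the shapes in the quick start. A full FNO layer in the paper also adds a pointwise linear term and a nonlinearity; both are omitted here.

```python
import jax
import jax.numpy as jnp


def spectral_conv_1d(x, weights, n_modes):
    """Spectral convolution for one sample (hypothetical helper, not norax API).

    x:       real array, shape (grid, channels_in), channel-last
    weights: complex array, shape (n_modes, channels_in, channels_out)
    returns: real array, shape (grid, channels_out)
    """
    grid = x.shape[0]
    # Real FFT along the spatial axis -> (grid // 2 + 1, channels_in)
    x_hat = jnp.fft.rfft(x, axis=0)
    # Mix channels independently for each of the lowest n_modes frequencies
    low = jnp.einsum("mi,mio->mo", x_hat[:n_modes], weights)
    # All higher frequencies are set to zero (mode truncation)
    out_hat = jnp.zeros((grid // 2 + 1, weights.shape[-1]), dtype=low.dtype)
    out_hat = out_hat.at[:n_modes].set(low)
    # Inverse real FFT back to the original grid size
    return jnp.fft.irfft(out_hat, n=grid, axis=0)


# Example: random complex weights for 16 modes, 2 input and 3 output channels
n_modes, c_in, c_out = 16, 2, 3
k_re, k_im = jax.random.split(jax.random.key(0))
scale = 1.0 / (c_in * c_out)
weights = scale * (
    jax.random.normal(k_re, (n_modes, c_in, c_out))
    + 1j * jax.random.normal(k_im, (n_modes, c_in, c_out))
)

y = spectral_conv_1d(jnp.ones((256, c_in)), weights, n_modes)
print(y.shape)  # (256, 3)

# The same weights apply to a finer grid, since they only index Fourier modes
y_fine = spectral_conv_1d(jnp.ones((512, c_in)), weights, n_modes)
print(y_fine.shape)  # (512, 3)

# A sinusoid at frequency 50 lies above the 16 retained modes and is filtered out
t = jnp.arange(256) / 256
x_high = jnp.stack([jnp.sin(2 * jnp.pi * 50 * t)] * c_in, axis=-1)
y_high = spectral_conv_1d(x_high, weights, n_modes)
```

Because the learned weights are tied to Fourier modes rather than grid points, `n_modes` only needs to be at most `grid // 2 + 1` (129 for a 256-point grid), and the same parameters can be evaluated on different resolutions.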

Installation

pip install norax

Or with uv:

uv add norax

Quick start

import jax
import jax.numpy as jnp
import norax as nrx

key = jax.random.key(0)

model = nrx.FNO(
    key,
    channels_in=2,
    channels_out=1,
    n_modes=(16,),  # Number of Fourier modes to retain along the spatial axis
    width=64,
    depth=4,
)

# The model acts on a single sample; jax.vmap maps it over the batch axis
x = jnp.ones((10, 256, 2))  # (batch, *grid_shape, channels_in)
y = jax.vmap(model)(x)      # (batch, *grid_shape, channels_out)
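The quick start calls `jax.vmap(model)`, which suggests the model is written for a single sample of shape `(*grid_shape, channels_in)` and `vmap` adds the batch axis. The sketch below illustrates this with a hypothetical pointwise `toy_model` standing in for `nrx.FNO`, so it runs with JAX alone. The stand-in and its weights `W` are assumptions, not part of norax.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a single-sample model such as the FNO above:
# it maps (grid, channels_in) -> (grid, channels_out). Here it is only a
# pointwise linear map from 2 channels to 1, used to illustrate batching.
W = jnp.array([[1.0], [2.0]])  # (channels_in=2, channels_out=1)


def toy_model(u):
    return u @ W


x = jax.random.normal(jax.random.key(0), (10, 256, 2))  # (batch, grid, channels_in)

# vmap maps toy_model over the leading batch axis...
y_vmap = jax.vmap(toy_model)(x)  # (batch, grid, channels_out)

# ...which gives the same result as an explicit Python loop over samples
y_loop = jnp.stack([toy_model(x[i]) for i in range(x.shape[0])])

print(y_vmap.shape)  # (10, 256, 1)
print(jnp.allclose(y_vmap, y_loop))
```

Writing the model per sample and batching with `vmap` keeps batch dimensions out of the model code; the loop is shown only as a reference for what `vmap` computes.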

References

Li, Zongyi, et al. "Fourier neural operator for parametric partial differential equations." arXiv preprint arXiv:2010.08895 (2020).