
norax

Neural operators in JAX.

norax provides implementations of neural operators built on top of JAX and Equinox.

Currently, only Fourier neural operators (FNOs) are provided.

Installation

uv is recommended for installation.

Using uv

With SSH:

uv init my-project
cd my-project
uv add git+ssh://git@github.com/christianfenton/norax.git

With HTTPS (run from inside an existing uv project):

uv add git+https://github.com/christianfenton/norax.git

Using pip

With SSH:

pip install git+ssh://git@github.com/christianfenton/norax.git

With HTTPS:

pip install git+https://github.com/christianfenton/norax.git
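To confirm that the install landed in the active environment, one option is to check that the package can be found by the import system. The sketch below uses only the standard library; the helper name `is_installed` is illustrative and not part of norax.

```python
import importlib.util


def is_installed(package: str) -> bool:
    """Return True if `package` can be located on the current Python path."""
    # find_spec returns None when a top-level package cannot be found.
    return importlib.util.find_spec(package) is not None


print("norax importable:", is_installed("norax"))
```

If this reports `False`, the package was probably installed into a different environment than the one running the interpreter.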

Quick start

import jax
import jax.numpy as jnp
from norax.models import FNO

key = jax.random.key(0)

# 1D Fourier neural operator
model = FNO(
    key,
    channels_in=2,   # e.g. [grid coordinate, initial condition]
    channels_out=1,
    n_modes=(16,),   # Fourier modes to retain along the spatial axis
    width=64,
    depth=4,
)

# The model acts on a single sample; vmap maps it over the leading batch axis
x = jnp.ones((10, 256, 2))  # (batch, resolution, channels_in)
y = jax.vmap(model)(x)      # (batch, resolution, channels_out)
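To give some intuition for what `n_modes` controls, here is a minimal sketch of the spectral convolution at the heart of a 1D Fourier neural operator layer, written in plain JAX. It is not norax's implementation: the function `spectral_conv_1d`, the weight layout, and the random initialisation are all assumptions made for illustration. The idea is to transform the input to Fourier space, keep only the lowest `n_modes` frequencies, mix channels per mode with learned complex weights, and transform back.

```python
import jax
import jax.numpy as jnp


def spectral_conv_1d(weights, x):
    """Truncated Fourier-space linear map for a single sample (illustrative).

    weights: complex array, shape (n_modes, channels_in, channels_out)
    x:       real array,    shape (resolution, channels_in)
    returns: real array,    shape (resolution, channels_out)
    """
    resolution = x.shape[0]
    n_modes = weights.shape[0]
    # Real FFT along the spatial axis: (resolution // 2 + 1, channels_in)
    x_hat = jnp.fft.rfft(x, axis=0)
    # Keep only the lowest n_modes frequencies
    low = x_hat[:n_modes]
    # Mix channels independently for each retained mode
    mixed = jnp.einsum("kc,kcd->kd", low, weights)
    # Discarded (high-frequency) modes stay zero
    out_hat = jnp.zeros((x_hat.shape[0], weights.shape[2]), dtype=x_hat.dtype)
    out_hat = out_hat.at[:n_modes].set(mixed)
    # Back to physical space at the original resolution
    return jnp.fft.irfft(out_hat, n=resolution, axis=0)


# Hypothetical random complex weights, sized to match the quick start example
key_re, key_im = jax.random.split(jax.random.key(0))
shape = (16, 2, 1)  # (n_modes, channels_in, channels_out)
weights = jax.random.normal(key_re, shape) + 1j * jax.random.normal(key_im, shape)

x = jnp.ones((10, 256, 2))  # (batch, resolution, channels_in)
# Share the weights across the batch and map over the leading axis of x
y = jax.vmap(spectral_conv_1d, in_axes=(None, 0))(weights, x)
print(y.shape)
```

Because the weights are indexed by mode rather than by grid point, the same weights can be applied to inputs sampled at a different resolution, provided the resolution is high enough to contain `n_modes` frequencies.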