norax¶
Neural operators in JAX.
norax provides implementations of neural operators built on top of JAX and Equinox.
Currently, only Fourier neural operators (FNOs) are provided.
Installation¶
uv is recommended for installation.
Using uv¶
With SSH:
uv init my-project
cd my-project
uv add git+ssh://git@github.com/christianfenton/norax.git
With HTTPS (run inside a uv project, such as the one created above):
uv add git+https://github.com/christianfenton/norax.git
Using pip¶
With SSH:
pip install git+ssh://git@github.com/christianfenton/norax.git
With HTTPS:
pip install git+https://github.com/christianfenton/norax.git
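After either installation route, a quick way to confirm that the package landed in the active environment is to look it up from Python. The snippet below is a minimal sketch using only the standard library; `describe_install` is a hypothetical helper written for this example, and it assumes the installed distribution is named `norax`, matching the import name used in the quick start.

```python
import importlib.util
from importlib import metadata


def describe_install(package="norax"):
    """Report whether `package` can be imported and, if known, its version."""
    # find_spec returns None when no top-level module of that name is found.
    if importlib.util.find_spec(package) is None:
        return f"{package} is not importable in this environment"
    try:
        # Assumes the distribution name matches the import name.
        return f"{package} {metadata.version(package)} is installed"
    except metadata.PackageNotFoundError:
        return f"{package} is importable, but no distribution metadata was found"


print(describe_install())
```

In a uv project, save this to a file and run it with `uv run python` so that it executes inside the project's environment.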
Quick start¶
import jax
import jax.numpy as jnp
from norax.models import FNO
key = jax.random.key(0)
# 1D Fourier neural operator
model = FNO(
    key,
    channels_in=2,  # e.g. [grid coordinate, initial condition]
    channels_out=1,
    n_modes=(16,),  # Fourier modes to retain along the spatial axis
    width=64,
    depth=4,
)
# Forward pass over a batch using vmap
x = jnp.ones((10, 256, 2))  # (batch, resolution, channels_in)
y = jax.vmap(model)(x)      # (batch, resolution, channels_out)
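To make the `n_modes` argument more concrete, here is a rough sketch of the kind of truncated spectral convolution a 1D Fourier neural operator layer is built around, written in plain JAX. It is an illustration under stated assumptions and not norax's actual implementation: `spectral_conv_1d` and `weights` are hypothetical names, the weights are random rather than trained, and a full FNO layer typically adds a pointwise linear path and a nonlinearity on top of this.

```python
import jax
import jax.numpy as jnp


def spectral_conv_1d(x, weights, n_modes):
    """Mix channels in Fourier space, keeping only the lowest n_modes.

    x:       (resolution, channels_in), real-valued
    weights: (n_modes, channels_in, channels_out), complex-valued
    """
    resolution = x.shape[0]
    x_hat = jnp.fft.rfft(x, axis=0)   # (resolution // 2 + 1, channels_in)
    low = x_hat[:n_modes]             # retain the n_modes lowest frequencies
    out_low = jnp.einsum("mi,mio->mo", low, weights)  # per-mode channel mixing
    out_hat = jnp.zeros((x_hat.shape[0], weights.shape[-1]), dtype=x_hat.dtype)
    out_hat = out_hat.at[:n_modes].set(out_low)       # dropped modes stay zero
    return jnp.fft.irfft(out_hat, n=resolution, axis=0)  # (resolution, channels_out)


# Hypothetical random weights, only for illustration.
key_re, key_im = jax.random.split(jax.random.key(0))
n_modes, channels_in, channels_out = 16, 2, 1
shape = (n_modes, channels_in, channels_out)
weights = jax.random.normal(key_re, shape) + 1j * jax.random.normal(key_im, shape)

# Same batching pattern as the quick start: vmap over the leading axis.
x = jnp.ones((10, 256, 2))  # (batch, resolution, channels_in)
y = jax.vmap(lambda sample: spectral_conv_1d(sample, weights, n_modes))(x)
print(y.shape)  # (batch, resolution, channels_out)
```

Because the weights are indexed by Fourier mode rather than by grid point, the same `weights` can be applied to inputs sampled at a different resolution, provided the input has at least `n_modes` frequencies available.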
Related projects¶
- neuraloperator: PyTorch implementations of neural operators