Training an FNO on Burgers' equation in 1D¶
This tutorial demonstrates how to train a Fourier neural operator to
predict solutions of Burgers' equation in 1D. Users who just want to
run a training script without reading this tutorial
can look at examples/burgers1d/train.py on
GitHub.
For details on how the training data is generated, see Generating Training Data on Burgers' Equation in 1D.
Note: This tutorial requires users to have h5py, optax, matplotlib
installed.
Problem statement¶
Burgers' equation in 1D on a periodic domain \(x \in [0, 1)\) reads
$$ \partial_t u + u\, \partial_x u = \nu\, \partial_{xx} u, $$
where \(\nu > 0\) is the kinematic viscosity. The first term on the right-hand side is a non-linear advective term and the second is a diffusive term.
The learning task is to approximate the solution operator $$ \mathcal{G}^\dagger : u_0 \mapsto u(\cdot,\, t_\text{end}), $$ mapping an initial condition \(u_0\) to the solution at time \(t_\text{end}\).
Loading the data¶
import json
import equinox as eqx
import h5py
import jax
import jax.numpy as jnp
import numpy as np
import optax
from functools import partial
from norax.data import DataLoader
from norax.models import FNO
def load_data(path):
    """Read a Burgers' dataset from an HDF5 file.

    Args:
        path: Location of the HDF5 file to read.

    Returns:
        Tuple ``(inputs, outputs, metadata)``: NumPy copies of the
        ``"inputs"`` and ``"outputs"`` datasets, plus a plain dict of the
        file-level attributes.
    """
    with h5py.File(path, "r") as handle:
        data_in = np.asarray(handle["inputs"])
        data_out = np.asarray(handle["outputs"])
        attrs = dict(handle.attrs)
    return data_in, data_out, attrs
# Load the dataset produced by the data-generation tutorial and inspect it.
inputs, outputs, metadata = load_data("myburgersdataset.h5")
print("metadata:", metadata)
print("inputs.shape:", inputs.shape)
print("outputs.shape:", outputs.shape)
metadata: {'L': 1.0, 'dt': 0.0001, 'nu': 0.02, 'num_samples': 1280, 'resolution': 256, 't_end': 1.0}
inputs.shape: (1280, 256, 2)
outputs.shape: (1280, 256, 1)
Downsample the data to speed up training:
def downsample(inputs, outputs, target_res):
    """Subsample both arrays along the spatial axis by strided slicing.

    Args:
        inputs: Array of shape ``(N, res, 2)``
        outputs: Array of shape ``(N, res, 1)``
        target_res: Desired number of grid points after subsampling; must
            evenly divide ``res``

    Returns:
        Downsampled inputs and outputs

    Raises:
        ValueError: If ``res`` is not an integer multiple of ``target_res``.
    """
    full_res = inputs.shape[1]
    step, remainder = divmod(full_res, target_res)
    if remainder:
        raise ValueError(
            f"Original resolution {full_res} is not evenly divisible by "
            f"target resolution {target_res}."
        )
    return inputs[:, ::step, :], outputs[:, ::step, :]
# Coarsen from 256 to 64 spatial grid points to speed up training.
inputs, outputs = downsample(inputs, outputs, 64)
print("inputs.shape:", inputs.shape)
print("outputs.shape:", outputs.shape)
inputs.shape: (1280, 64, 2)
outputs.shape: (1280, 64, 1)
Split the data into train and test sets and build data loaders:
# Train/test split: first 1024 samples for training, the next 256 for testing.
n_train = 1024
n_test = 256
inputs_jax = jnp.asarray(inputs[: n_train + n_test])
outputs_jax = jnp.asarray(outputs[: n_train + n_test])
train_data = {"input": inputs_jax[:n_train], "output": outputs_jax[:n_train]}
test_data = {"input": inputs_jax[n_train:], "output": outputs_jax[n_train:]}
batch_size = 32
# Shuffle only the training stream; fixed seeds keep runs reproducible.
train_loader = DataLoader(train_data, batch_size, shuffle=True, seed=0)
test_loader = DataLoader(test_data, batch_size, shuffle=False, seed=0)
Model¶
# Split the PRNG key so `key` can be reused for later randomness.
key = jax.random.key(0)
key, subkey = jax.random.split(key)
# FNO mapping the two input channels (grid coordinate and u_0) to the single
# output channel u(x, t_end): 16 Fourier modes, width-64 latent channels,
# 4 Fourier layers.
model = FNO(
    key=subkey,
    channels_in=2,
    channels_out=1,
    n_modes=(16,),
    width=64,
    depth=4,
)
print(model)
FNO(
lift=MLP(
activation=<function gelu>,
depth=2,
width=128,
hidden_layers=(Linear(weight=f32[128,2], bias=f32[128]),),
output_layer=Linear(weight=f32[64,128], bias=f32[64])
),
project=MLP(
activation=<function gelu>,
depth=2,
width=128,
hidden_layers=(Linear(weight=f32[128,64], bias=f32[128]),),
output_layer=Linear(weight=f32[1,128], bias=f32[1])
),
fourier_layers=(
Fourier(
spectral=SpectralConv(
channels_in=64,
channels_out=64,
n_modes=(16,),
n_dims=1,
weights_re=f32[64,64,16],
weights_im=f32[64,64,16]
),
linear=Linear(weight=f32[64,64], bias=f32[64]),
activation=<function gelu>
),
...
)
)
Verify that a forward pass works:
# Smoke test: a batch of 2 inputs should yield outputs of shape (2, 64, 1).
jax.vmap(model)(train_data["input"][:2]).shape
(2, 64, 1)
Evaluation¶
def relative_l2(prediction: jax.Array, target: jax.Array) -> jax.Array:
    """Relative L2 error ``||prediction - target|| / ||target||``."""
    diff_norm = jnp.linalg.norm(prediction - target)
    target_norm = jnp.linalg.norm(target)
    return diff_norm / target_norm
def _sample_loss(model: FNO, x: jax.Array, y: jax.Array) -> jax.Array:
    """Relative L2 error of the model's prediction on a single sample."""
    return relative_l2(model(x), y)


# Vectorise over the leading (batch) axis of `x` and `y`; the model is shared.
loss_fn = jax.vmap(_sample_loss, in_axes=(None, 0, 0))
@jax.jit
def eval_step(model: FNO, x: jax.Array, y: jax.Array) -> jax.Array:
    """Summed (not averaged) per-sample loss over one batch, jit-compiled."""
    per_sample = loss_fn(model, x, y)
    return per_sample.sum()
def evaluate(model: FNO, loader: DataLoader) -> float:
    """Mean per-sample relative-L2 error of ``model`` over all of ``loader``."""
    loss_sum = jnp.zeros(())
    seen = 0
    for batch in loader:
        batch_x, batch_y = batch["input"], batch["output"]
        loss_sum = loss_sum + eval_step(model, batch_x, batch_y)
        seen += batch_x.shape[0]
    return float(loss_sum / seen)
Training¶
Create an Adam optimiser with a learning rate that halves every 100 epochs:
learning_rate = 1e-3
steps_per_epoch = len(train_loader)
# Staircase exponential decay: the learning rate is multiplied by 0.5 once
# every 100 epochs (the transition is expressed in optimiser steps).
schedule = optax.schedules.exponential_decay(
    learning_rate,
    transition_steps=steps_per_epoch * 100,
    decay_rate=0.5,
    staircase=True,
)
optimiser = optax.adam(schedule)
# The optimiser state tracks only the array leaves of the model pytree.
opt_state = optimiser.init(eqx.filter(model, eqx.is_array))
Define a single training step (batch update):
@eqx.filter_jit
def train_step(
    model: FNO,
    opt_state: optax.OptState,
    optimiser: optax.GradientTransformation,
    x: jax.Array,
    y: jax.Array,
):
    """Apply one optimiser update computed from a single batch.

    Returns the updated model, the new optimiser state, and the mean batch
    loss scaled by the batch size (so callers can average over samples).
    """

    def batch_loss(m: FNO) -> jax.Array:
        return jnp.mean(loss_fn(m, x, y))

    loss, grads = eqx.filter_value_and_grad(batch_loss)(model)
    params = eqx.filter(model, eqx.is_array)
    updates, opt_state = optimiser.update(grads, opt_state, params)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss * x.shape[0]
Define a training epoch:
def run_epoch(
    model: FNO,
    opt_state: optax.OptState,
    optimiser: optax.GradientTransformation,
    loader: DataLoader,
):
    """Train for one full pass over ``loader``.

    Returns the updated model, the updated optimiser state, and the mean
    per-sample training loss for the epoch.
    """
    loss_sum = jnp.zeros(())
    seen = 0
    for batch in loader:
        batch_x, batch_y = batch["input"], batch["output"]
        model, opt_state, batch_loss = train_step(
            model, opt_state, optimiser, batch_x, batch_y
        )
        loss_sum = loss_sum + batch_loss
        seen += batch_x.shape[0]
    return model, opt_state, float(loss_sum / seen)
Run the training loop:
n_epochs = 500
for epoch in range(1, n_epochs + 1):
    model, opt_state, train_loss = run_epoch(
        model, opt_state, optimiser, train_loader
    )
    # Rewind the loader so the next epoch sees a fresh pass over the data.
    train_loader.reset()
    # Report the test error every 10 epochs and on the final epoch.
    if epoch % 10 == 0 or epoch == n_epochs:
        test_loss = evaluate(model, test_loader)
        test_loader.reset()
        print(
            f"Epoch {epoch:4d}/{n_epochs} "
            f"train={train_loss:.4e} test={test_loss:.4e}"
        )
Epoch 10/500 train=1.0623e-01 test=9.8970e-02
Epoch 20/500 train=6.6438e-02 test=6.6549e-02
Epoch 30/500 train=3.9810e-02 test=3.8570e-02
...
Predictions¶
import matplotlib.pyplot as plt

num_examples = 3
n_test = test_data["input"].shape[0]  # rebinds the earlier `n_test` (same value)
key, subkey = jax.random.split(key)
# NOTE(review): `jax.random.choice` samples with replacement by default, so a
# test sample can appear more than once — harmless for visualisation.
random_inds = jax.random.choice(subkey, n_test, (num_examples,))
example_inputs = test_data["input"][random_inds]
example_outputs = test_data["output"][random_inds]
predictions = jax.vmap(model)(example_inputs)
fig, axes = plt.subplots(
    2, num_examples, figsize=(16, 9), layout="tight", sharex=True
)
for i in range(num_examples):
    idx = random_inds[i]
    x = example_inputs[i, :, 0]   # channel 0: spatial grid coordinate
    u0 = example_inputs[i, :, 1]  # channel 1: initial condition u_0(x)
    v_true = example_outputs[i, ...]
    v_pred = predictions[i, ...]
    # Top row: initial condition; bottom row: prediction vs ground truth.
    axes[0, i].plot(x, u0)
    axes[0, i].set_ylabel("$u_0(x)$", fontsize=10)
    axes[0, i].set_title(f"Sample {idx}", fontsize=11)
    axes[1, i].plot(x, v_pred, "-", label="Prediction")
    axes[1, i].plot(x, v_true, "--", label="Ground truth")
    axes[1, i].set_xlabel("$x$", fontsize=10)
    axes[1, i].set_ylabel("$u(x, t=1)$", fontsize=10)
    axes[1, i].legend(fontsize=10)
plt.show()

Saving and loading models¶
def save_model(path: str, model: FNO, hyperparams: dict) -> None:
    """Serialise ``model`` and its hyperparameters to a binary file.

    The file format follows the Equinox recommended pattern:
    a JSON-encoded hyperparameter line followed by the serialised leaves.
    """
    header = json.dumps(hyperparams) + "\n"
    with open(path, "wb") as f:
        f.write(header.encode())
        eqx.tree_serialise_leaves(f, model)
def load_model(path: str):
    """Load a model previously saved with `save_model`.

    Returns the deserialised model together with its hyperparameter dict,
    with ``n_modes`` restored to a tuple (JSON stores it as a list).
    """
    with open(path, "rb") as f:
        header = f.readline().decode()
        hyperparams = json.loads(header)
        hyperparams["n_modes"] = tuple(hyperparams["n_modes"])
        # Build a skeleton with the same structure; its random weights are
        # then overwritten by the serialised leaves.
        skeleton = FNO(key=jax.random.key(0), **hyperparams)
        model = eqx.tree_deserialise_leaves(f, skeleton)
    return model, hyperparams
# Hyperparameters needed to rebuild an identical FNO skeleton at load time.
# `n_modes` is stored as a list because JSON has no tuple type.
hyperparams = {
    "channels_in": 2,
    "channels_out": 1,
    "n_modes": list((16,)),
    "width": 64,
    "depth": 4,
}
save_model("mymodel.eqx", model, hyperparams)
print("Model saved.")
# Verify that saving and loading works
model_loaded, hyperparams_loaded = load_model("mymodel.eqx")
print("Model loaded. hyperparams:", hyperparams_loaded)
Model saved.
Model loaded. hyperparams: {'channels_in': 2, 'channels_out': 1, 'n_modes': (16,), 'width': 64, 'depth': 4}