Skip to content

Learning a derivative operator

This tutorial demonstrates how to train a Fourier neural operator (FNO) to learn a derivative operator.

Problem statement

We'll train an FNO to learn the mapping \(G : u \rightarrow v\), where $$ v(x) = \frac{du}{dx} $$ for $$ u(x) = \sin(\alpha x), \quad x \in [0, 2\pi], \quad \alpha \in [0.5, 1] $$ and where \(u\), \(v\) have periodic boundary conditions.

Spatial discretisation

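A uniform finite difference discretisation of the domain leads to grid points $$ x_i = i h, \quad i = 0, \ldots, n-1 $$ where \(n\) is the number of grid points, \(L = 2\pi\) is the length of the domain and \(h = L / n\) is the grid spacing. Because of the periodic boundary conditions, the endpoint \(x_n = L\) is identified with \(x_0 = 0\), so it is not stored as a separate grid point.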

Implementation

Data generation

Create synthetic training data:

import jax
import jax.numpy as jnp

# Root PRNG key. Split it immediately so that sampling alpha below and any
# later consumer of `key` (e.g. model initialisation) do not draw from the
# same random stream.
key = jax.random.key(0)
key, alpha_key = jax.random.split(key)

# Configuration
resolution = 32
n_train = 100
n_test = 20
n_samples = n_train + n_test  # derived so the two splits always cover the data

# Domain: uniform grid on [0, L) with spacing h; the endpoint x = L is omitted
L = 2 * jnp.pi
h = L / resolution
grid = jnp.linspace(0, L - h, resolution)

# Generate u and v
# alpha is drawn uniformly from [0.5, 1); broadcasting (n_samples, 1) against
# (resolution,) evaluates every sample on the grid in one shot.
alpha = 0.5 * (1 + jax.random.uniform(alpha_key, (n_samples,)))
phase = alpha[:, None] * grid
u = jnp.sin(phase)                   # u(x) = sin(alpha x)
v = alpha[:, None] * jnp.cos(phase)  # v(x) = du/dx = alpha cos(alpha x)

# The same grid coordinates are attached to every sample
x_coords = jnp.tile(grid, (n_samples, 1))

# Create input with shape (batch, features, grid)
# Input features: [u(x), x]
input_data = jnp.stack([u, x_coords], axis=1)

# Output with shape (batch, features, grid)
output_data = v[:, None, :]

# Split into train and test sets (first n_train / last n_test samples)
train_ds = {
    'input': input_data[:n_train],
    'output': output_data[:n_train],
}
test_ds = {
    'input': input_data[-n_test:],
    'output': output_data[-n_test:],
}

Create an iterator for batching the data:

from typing import Dict

class DatasetIterator:
    """Mini-batch iterator over a dataset dict with 'input' and 'output' arrays.

    Samples are indexed along the leading axis of ``dataset['input']``. The
    iterator is single-pass: once exhausted it keeps raising ``StopIteration``
    until ``reset`` is called.
    """

    def __init__(
        self,
        dataset: Dict[str, jnp.ndarray],
        batch_size: int,
        shuffle: bool = True
    ):
        """Store the dataset and prepare the first pass.

        Args:
            dataset: Dict holding 'input' and 'output' arrays.
            batch_size: Maximum number of samples per batch; must be positive.
            shuffle: If True, visit samples in a random order.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        # A non-positive batch size would never advance current_idx in
        # __next__ (and divides by zero in __len__), so reject it up front.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.n_samples = dataset['input'].shape[0]
        self.reset()

    def reset(self, key=None):
        """Rewind to the first batch, reshuffling when ``shuffle`` is set.

        Args:
            key: JAX PRNG key used for the permutation. When omitted,
                ``jax.random.key(0)`` is used, so calls without a key always
                produce the same order.
        """
        # The default key is built here rather than in the signature so that
        # defining the class does not execute any JAX code.
        if key is None:
            key = jax.random.key(0)
        self.indices = jnp.arange(self.n_samples)
        if self.shuffle:
            self.indices = jax.random.permutation(key, self.indices)
        self.current_idx = 0

    def __iter__(self):
        """Return the iterator itself; no implicit reset is performed."""
        return self

    def __next__(self) -> Dict[str, jnp.ndarray]:
        """Return the next batch; the final batch may hold fewer samples."""
        if self.current_idx >= self.n_samples:
            raise StopIteration

        start_idx = self.current_idx
        end_idx = min(start_idx + self.batch_size, self.n_samples)
        batch_indices = self.indices[start_idx:end_idx]

        batch = {
            'input': self.dataset['input'][batch_indices],
            'output': self.dataset['output'][batch_indices]
        }

        self.current_idx = end_idx
        return batch

    def __len__(self) -> int:
        """Number of batches per pass (ceiling of n_samples / batch_size)."""
        return (self.n_samples + self.batch_size - 1) // self.batch_size

Model creation

Create an FNO1D model instance:

from flax import nnx
import jax_fno

# Instantiate the FNO: the PRNG key goes first, all sizes are keyword arguments
model = jax_fno.FNO1D(
    key,
    n_layers=4,             # Fourier layers in the network
    n_modes=12,             # Learnable Fourier modes
    width=64,               # Hidden dimension
    projection_hidden=128,  # Hidden size of the projection layer
    input_dim=2,            # Input features: [u(x), x]
    output_dim=1,           # Output features: [v(x)]
)

# Show the model structure
nnx.display(model)

Verify forward pass works:

# Smoke test: run the first two training samples through the untrained model
x = train_ds['input'][0:2]
y = model(x)
print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")

Training

Create an optax optimiser with an exponentially decaying learning rate:

import optax

# Optimiser hyper-parameters
batch_size = 16
learning_rate = 1e-3
# Ceiling division: a final partial batch still counts as one optimiser step
steps_per_epoch = -(-n_train // batch_size)

# Staircase schedule: the learning rate is halved once every 100 epochs
schedule = optax.schedules.exponential_decay(
    init_value=learning_rate,
    transition_steps=100 * steps_per_epoch,
    decay_rate=0.5,
    staircase=True,
)

# Adam over the model's nnx.Param variables, plus a running average of the loss
optimizer = nnx.Optimizer(model, optax.adam(schedule), wrt=nnx.Param)
metrics = nnx.MultiMetric(loss=nnx.metrics.Average('loss'))

Define loss and training functions:

def l2_loss(pred, target, eps=1e-6):
    """Mean squared relative L2 error between ``pred`` and ``target``.

    Squared errors and squared target magnitudes are each summed over the
    last axis; their ratio (no square root is taken) is then averaged over
    all remaining axes. ``eps`` is added to the denominator to avoid
    division by zero for an all-zero target.
    """
    residual = pred - target
    squared_error = jnp.square(residual).sum(axis=-1)
    squared_norm = jnp.square(target).sum(axis=-1) + eps
    return jnp.mean(squared_error / squared_norm)

def loss_fn(model, inputs, targets):
    """Evaluate ``model`` on ``inputs`` and score the result with ``l2_loss``."""
    return l2_loss(model(inputs), targets)

@nnx.jit
def train_step(model, optimizer, metrics, inputs, targets):
    """Run one optimisation step on a batch and record its loss in ``metrics``."""
    # Loss and its gradient with respect to the first argument (the model)
    loss, grads = nnx.value_and_grad(loss_fn)(model, inputs, targets)
    optimizer.update(model, grads)
    metrics.update(loss=loss)

@nnx.jit
def eval_step(model, metrics, inputs, targets):
    """Record the loss of ``model`` on a batch in ``metrics``; no parameter update."""
    metrics.update(loss=loss_fn(model, inputs, targets))

Train the model:

import matplotlib.pyplot as plt
from IPython.display import clear_output

n_epochs = 500
metrics_history = {'train_loss': [], 'test_loss': []}

train_iter = DatasetIterator(train_ds, batch_size=batch_size)
test_iter = DatasetIterator(
    test_ds, batch_size=batch_size, shuffle=False
)

# Separate key stream for reshuffling the training set each epoch
shuffle_key = jax.random.key(0)

for epoch in range(1, n_epochs + 1):
    # Training
    model.train()
    for batch in train_iter:
        train_step(model, optimizer, metrics, batch['input'], batch['output'])

    metrics_history['train_loss'].append(metrics.compute()["loss"])
    # Clear the running average before evaluation; otherwise the test loss
    # recorded below would also average in this epoch's training batches.
    metrics.reset()

    # Rewind the training iterator with a fresh permutation for the next epoch
    shuffle_key, subkey = jax.random.split(shuffle_key)
    train_iter.reset(subkey)

    # Evaluation
    model.eval()
    for batch in test_iter:
        eval_step(model, metrics, batch["input"], batch["output"])

    test_iter.reset()
    metrics_history["test_loss"].append(metrics.compute()["loss"])
    metrics.reset()

    # Plot progress
    if epoch % 10 == 0:
        clear_output(wait=True)
        fig, ax = plt.subplots(figsize=(8, 4.5))
        ax.set_xlabel('Epoch', fontsize=10)
        ax.set_ylabel('Loss', fontsize=10)
        ax.set_yscale('log')
        ax.plot(metrics_history['train_loss'], label='Train')
        ax.plot(metrics_history['test_loss'], label='Test')
        ax.legend()
        ax.set_title(f'Epoch {epoch}/{n_epochs}', fontsize=11)
        plt.show()
        # Release the figure so repeated redraws do not accumulate open figures
        plt.close(fig)

Training loss

Evaluation

Make predictions on test set:

# Evaluation mode for inference
model.eval()

# Pick the first, middle and last test samples
n_test = test_ds['input'].shape[0]
n_examples = 3
example_indices = jnp.array([0, n_test//2, n_test-1])

selected_inputs = test_ds['input'][example_indices]
selected_outputs = test_ds['output'][example_indices]
predictions = model(selected_inputs)

# One panel per selected sample
fig, axes = plt.subplots(1, n_examples, figsize=(12, 4), layout='tight')

for i, (ax, idx) in enumerate(zip(axes, example_indices)):
    # Input features are ordered [u(x), x]; outputs have a single feature
    u0, x = selected_inputs[i]
    v_true = selected_outputs[i, 0]
    v_pred = predictions[i, 0]

    # Relative L2 error of the prediction for this sample
    l2_error = jnp.linalg.norm(v_pred - v_true) / jnp.linalg.norm(v_true)

    # Input, prediction and ground truth on the same axes
    ax.plot(x, u0, '-', label='Input $u$', linewidth=2, alpha=0.8)
    ax.plot(x, v_pred, '-', label='Predicted $v$', linewidth=2, alpha=0.8)
    ax.plot(x, v_true, ':', label='Ground truth $v$', linewidth=2, alpha=0.8)

    ax.set_xlabel('$x$', fontsize=10)
    ax.set_title(f'Sample {idx}, L2 Error: {l2_error:.2e}', fontsize=11)
    ax.legend(fontsize=10)

plt.show()

FNO predictions