
JAX-FNO: Fourier Neural Operators in JAX

This project provides Fourier neural operators (FNOs) for solving partial differential equations (PDEs) and an ordinary differential equation (ODE) integrator, all written in JAX/Flax.

Overview

JAX-FNO is organised into two main modules:

  • Learning (jax_fno.learn): FNO architectures used to learn PDE solution operators
  • Integration (jax_fno.integrate): ODE integration methods
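To make the learning module's idea concrete, here is a minimal sketch of one 1D Fourier layer in plain jax.numpy, following the update v_{t+1}(x) = σ(W v_t(x) + (K v_t)(x)) from Li et al. (2020). The layer applies an FFT, multiplies the lowest `n_modes` frequencies by learned complex weights, discards the rest, transforms back, and adds a pointwise linear term. The function names, argument order and weight layouts below are hypothetical illustrations. They are not the `jax_fno.learn` API, which is built on Flax.

```python
import jax
import jax.numpy as jnp


def spectral_conv_1d(x, weights, n_modes):
    """Spectral convolution (K v)(x) on a uniform 1D grid.

    x:       real array, shape (n_points, in_channels)
    weights: complex array, shape (n_modes, in_channels, out_channels)
    """
    x_hat = jnp.fft.rfft(x, axis=0)  # shape (n_points // 2 + 1, in_channels)
    # Mix channels independently for each retained low-frequency mode.
    low = jnp.einsum("ki,kio->ko", x_hat[:n_modes], weights)
    # Higher modes are truncated (set to zero).
    out_hat = jnp.zeros((x_hat.shape[0], weights.shape[-1]), dtype=low.dtype)
    out_hat = out_hat.at[:n_modes].set(low)
    return jnp.fft.irfft(out_hat, n=x.shape[0], axis=0)


def fourier_layer(x, spectral_w, local_w, bias, n_modes):
    """One hypothetical FNO layer: GELU(W v + K v + b)."""
    return jax.nn.gelu(spectral_conv_1d(x, spectral_w, n_modes) + x @ local_w + bias)


# Usage: random weights, 64 grid points, 4 input and 8 output channels.
key = jax.random.PRNGKey(0)
k_x, k_re, k_im, k_w = jax.random.split(key, 4)
n_points, c_in, c_out, n_modes = 64, 4, 8, 12
x = jax.random.normal(k_x, (n_points, c_in))
spectral_w = 0.1 * (
    jax.random.normal(k_re, (n_modes, c_in, c_out))
    + 1j * jax.random.normal(k_im, (n_modes, c_in, c_out))
)
local_w = 0.1 * jax.random.normal(k_w, (c_in, c_out))
bias = jnp.zeros(c_out)
y = fourier_layer(x, spectral_w, local_w, bias, n_modes)
print(y.shape)  # (64, 8)
```

Because the learned weights live in Fourier space and only a fixed number of modes is kept, the same parameters can be applied to inputs sampled at a different resolution, which is the main motivation for the FNO design.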

Installation

Poetry is recommended for installation.

Using Poetry

Create a new Poetry project and add the package:

With SSH:

poetry new my-project
cd my-project
poetry add git+ssh://git@github.com/christianfenton/jax-fno.git

With HTTPS:

poetry add git+https://github.com/christianfenton/jax-fno.git

Using pip

Alternatively, you can install directly with pip:

With SSH:

pip install git+ssh://git@github.com/christianfenton/jax-fno.git

With HTTPS:

pip install git+https://github.com/christianfenton/jax-fno.git

Getting started

To get started using the project, check out the tutorials:

Citations

The work in this project is based on:

Li, Zongyi, et al. "Fourier neural operator for parametric partial differential equations." arXiv preprint arXiv:2010.08895 (2020). https://arxiv.org/pdf/2010.08895

Check out the source code on GitHub.

Future work

In the future, jax_fno.integrate.solve_ivp should be adapted to match the interface of scipy.integrate.solve_ivp and added to jax.scipy.integrate.
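For reference, scipy.integrate.solve_ivp expects a right-hand side with signature `fun(t, y)` and an interval `t_span = (t0, t1)`. The sketch below shows what a minimal JAX integrator following that calling convention might look like: a fixed-step classical fourth-order Runge-Kutta scheme written with `jax.lax.scan` so that it can be traced and compiled. The name `rk4_solve` and its `num_steps` parameter are hypothetical. This is not the current `jax_fno.integrate.solve_ivp` implementation, and unlike SciPy's default `RK45` method it performs no adaptive step-size control.

```python
import jax
import jax.numpy as jnp


def rk4_solve(fun, t_span, y0, num_steps):
    """Integrate dy/dt = fun(t, y) over t_span with fixed-step RK4.

    num_steps must be a Python int (it sets the static scan length).
    Returns (ts, ys) including the initial point, so ys has
    num_steps + 1 rows.
    """
    t0, t1 = t_span
    h = (t1 - t0) / num_steps
    y0 = jnp.asarray(y0)

    def step(carry, _):
        t, y = carry
        k1 = fun(t, y)
        k2 = fun(t + h / 2, y + h / 2 * k1)
        k3 = fun(t + h / 2, y + h / 2 * k2)
        k4 = fun(t + h, y + h * k3)
        y_next = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        t_next = t + h
        # Carry the new state forward and also record it as an output.
        return (t_next, y_next), (t_next, y_next)

    init = (jnp.asarray(t0, dtype=jnp.float32), y0)
    _, (ts, ys) = jax.lax.scan(step, init, None, length=num_steps)
    # Prepend the initial condition to the recorded trajectory.
    ts = jnp.concatenate([jnp.array([t0], dtype=ts.dtype), ts])
    ys = jnp.concatenate([y0[None], ys])
    return ts, ys


# Usage: exponential decay dy/dt = -y with y(0) = 1 on [0, 1].
ts, ys = rk4_solve(lambda t, y: -y, (0.0, 1.0), jnp.array([1.0]), num_steps=100)
print(ys[-1], jnp.exp(-1.0))
```

Keeping SciPy's `fun(t, y)` ordering means existing right-hand-side functions could be reused unchanged; matching the rest of the SciPy interface (adaptive methods, `t_eval`, a result object) would require more work than this sketch shows.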