JAX-FNO: Fourier Neural Operators in JAX
This project provides Fourier neural operators (FNOs) for learning the solution operators of partial differential equations (PDEs), together with ordinary differential equation (ODE) integration methods, all written in JAX/Flax.
Overview
JAX-FNO is organised into two main modules:
- Learning (jax_fno.learn): FNO architectures used to learn PDE solution operators
- Integration (jax_fno.integrate): ODE integration methods
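As background for the learning module, the layer at the heart of an FNO (Li et al., 2020) transforms its input to Fourier space, keeps only a fixed number of low-frequency modes, mixes channels on those modes with learned complex weights, transforms back, and adds a pointwise linear term before the nonlinearity. The following is a minimal sketch of one such 1D layer in plain JAX. The function names spectral_conv_1d and fourier_layer_1d, their argument layout, and the choice of GELU are assumptions for illustration only; they are not the jax_fno.learn API, which may be structured quite differently (for example as Flax modules).

```python
import jax
import jax.numpy as jnp


def spectral_conv_1d(u, weights, n_modes):
    """Hypothetical 1D spectral convolution (the core of an FNO layer).

    u:       real input of shape (n_points, in_channels)
    weights: complex array of shape (n_modes, in_channels, out_channels)
    n_modes: number of low-frequency Fourier modes kept
    """
    n_points = u.shape[0]
    # Real FFT along the spatial axis: shape (n_points // 2 + 1, in_channels)
    u_hat = jnp.fft.rfft(u, axis=0)
    # Mix channels mode by mode, using only the lowest n_modes frequencies
    out_low = jnp.einsum("ki,kio->ko", u_hat[:n_modes], weights)
    # Higher modes are truncated: zero-pad back to the full spectrum size
    out_hat = jnp.zeros((u_hat.shape[0], weights.shape[-1]), dtype=u_hat.dtype)
    out_hat = out_hat.at[:n_modes].set(out_low)
    # Inverse real FFT back to physical space on the original grid
    return jnp.fft.irfft(out_hat, n=n_points, axis=0)


def fourier_layer_1d(u, weights, w_local, b_local, n_modes):
    """Hypothetical Fourier layer: gelu(K u + W u + b), K = spectral conv."""
    return jax.nn.gelu(spectral_conv_1d(u, weights, n_modes) + u @ w_local + b_local)


# Usage: random parameters on a 64-point grid with 3 channels and 8 modes
k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(0), 4)
n_points, c_in, c_out, n_modes = 64, 3, 3, 8
u = jax.random.normal(k1, (n_points, c_in))
weights = (
    jax.random.normal(k2, (n_modes, c_in, c_out))
    + 1j * jax.random.normal(k3, (n_modes, c_in, c_out))
) / (c_in * c_out)
w_local = jax.random.normal(k4, (c_in, c_out)) / c_in
b_local = jnp.zeros(c_out)
v = fourier_layer_1d(u, weights, w_local, b_local, n_modes)
print(v.shape)  # (64, 3)
```

In Li et al. (2020), truncating to a fixed number of modes is what keeps the learned weights independent of the grid size, so the same parameters can be applied to inputs sampled at different resolutions.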
Installation
Poetry is recommended for installation.
Using Poetry
Create a new Poetry project and add the package as a dependency:
With SSH:
poetry new my-project
cd my-project
poetry add git+ssh://git@github.com/christianfenton/jax-fno.git
With HTTPS:
poetry add git+https://github.com/christianfenton/jax-fno.git
Using pip
Alternatively, you can install directly with pip:
With SSH:
pip install git+ssh://git@github.com/christianfenton/jax-fno.git
With HTTPS:
pip install git+https://github.com/christianfenton/jax-fno.git
Getting started
To get started using the project, check out the tutorials:
Citations
The work in this project is based on:
Li, Zongyi, et al. "Fourier neural operator for parametric partial differential equations." arXiv preprint arXiv:2010.08895 (2020). https://arxiv.org/pdf/2010.08895
Links
Check out the source code on GitHub: https://github.com/christianfenton/jax-fno
Future work
In the future, jax_fno.integrate.solve_ivp should be adapted to match the interface of scipy.integrate.solve_ivp and added to jax.scipy.integrate.
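For reference, scipy.integrate.solve_ivp takes a right-hand side fun(t, y), an interval t_span = (t0, t1) and an initial state y0, and chooses its step sizes adaptively (RK45 by default). As a rough illustration of what a JAX-native solver with that calling convention could look like, here is a minimal sketch of a fixed-step classical fourth-order Runge-Kutta (RK4) integrator built on jax.lax.scan. The name rk4_solve and its n_steps argument are hypothetical; this is not the current jax_fno.integrate.solve_ivp signature, and unlike SciPy it uses a fixed step and returns the states with time along the first axis.

```python
import jax
import jax.numpy as jnp


def rk4_solve(fun, t_span, y0, n_steps):
    """Hypothetical fixed-step classical RK4 integrator.

    Uses scipy.integrate.solve_ivp's conventions for fun(t, y) and
    t_span = (t0, t1), but takes n_steps equal steps and returns
    (ts, ys) with ys of shape (n_steps + 1, *y0.shape).
    """
    t0, t1 = t_span
    h = (t1 - t0) / n_steps
    # Explicit dtype keeps the scan carry types stable across steps
    y0 = jnp.asarray(y0, dtype=float)

    def step(carry, _):
        t, y = carry
        # Four stage evaluations of the classical RK4 scheme
        k1 = fun(t, y)
        k2 = fun(t + h / 2, y + h / 2 * k1)
        k3 = fun(t + h / 2, y + h / 2 * k2)
        k4 = fun(t + h, y + h * k3)
        y_next = y + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
        return (t + h, y_next), y_next

    t_start = jnp.asarray(t0, dtype=y0.dtype)
    _, ys = jax.lax.scan(step, (t_start, y0), None, length=n_steps)
    ts = t0 + h * jnp.arange(n_steps + 1)
    # Prepend the initial state so ys[i] corresponds to ts[i]
    return ts, jnp.concatenate([y0[None], ys], axis=0)


# Usage: exponential decay dy/dt = -y on [0, 1], exact solution exp(-t)
ts, ys = rk4_solve(lambda t, y: -y, (0.0, 1.0), jnp.array([1.0]), n_steps=100)
print(ys.shape)  # (101, 1)
```

Because the loop is expressed with jax.lax.scan, the whole solve can be wrapped in jax.jit and differentiated with jax.grad, which is one motivation for a JAX-native counterpart to the SciPy interface.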