Master Data‑Driven Neural Operators: DeepONet, POD‑DeepONet, FNO‑2D and Time‑Stepping FNO‑1D in One Tutorial
This tutorial presents a comprehensive JAX implementation and analysis of four neural‑operator methods—standard DeepONet, POD‑DeepONet, 2‑D Fourier Neural Operator, and time‑stepping 1‑D FNO—applied to the 1‑D advection equation, comparing their mathematical foundations, parameter counts, training efficiency, and prediction accuracy.
Problem definition
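The benchmark is the one‑dimensional advection equation with periodic boundary conditions and a random square‑wave initial condition. The learning task is to approximate the solution operator that maps the initial condition (the input function, sampled at 40 spatial points) to the full space‑time solution (the output function, a 40 × 40 space‑time grid of 1,600 values). Once trained, the operator can be evaluated for new initial conditions without re‑solving the PDE, which gives a “train once, infer many times” workflow.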
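The source does not give the domain, the advection speed, or how the square waves are sampled. As a minimal sketch under stated assumptions, the NumPy code below generates a data set of this shape. It assumes the domain [0, 1) × [0, 1], speed c = 1, and uniform ranges for the wave's left edge and width, and it uses the exact solution u(x, t) = u0(x − c t) of constant‑speed advection, wrapped periodically. `square_wave` and `make_advection_dataset` are hypothetical helper names, not part of the tutorial's code.

```python
import numpy as np

def square_wave(x, left, width):
    """Periodic indicator of [left, left + width) on the unit interval."""
    return (np.mod(x - left, 1.0) < width).astype(np.float64)

def make_advection_dataset(n_samples, nx=40, nt=40, c=1.0, seed=0):
    """Random square-wave ICs and the exact solution of u_t + c u_x = 0 (periodic BCs)."""
    rng = np.random.default_rng(seed)
    x = np.linspace(0.0, 1.0, nx, endpoint=False)   # periodic grid: x = 1 coincides with x = 0
    t = np.linspace(0.0, 1.0, nt)
    X, T = np.meshgrid(x, t, indexing="ij")          # both (nx, nt)
    lefts = rng.uniform(0.0, 1.0, n_samples)         # assumed range for the wave's left edge
    widths = rng.uniform(0.1, 0.5, n_samples)        # assumed range for the wave's width
    u0 = np.stack([square_wave(x, a, w) for a, w in zip(lefts, widths)])
    # Exact solution: the initial profile translated by c * t and wrapped periodically
    u = np.stack([square_wave(X - c * T, a, w).ravel() for a, w in zip(lefts, widths)])
    return x, t, u0, u                               # u0: (n, nx), u: (n, nx * nt)

x, t, u0, u = make_advection_dataset(1000)
print(u0.shape, u.shape)  # (1000, 40) (1000, 1600)
```

Each output row is flattened in (x, t) order, so `u[i].reshape(40, 40)[:, 0]` recovers the initial condition of sample i.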
Method 1 – DeepONet
DeepONet builds on the universal approximation theorem for operators (Chen & Chen, 1995). It decomposes the operator into a branch network, which encodes the input function sampled at fixed sensor points, and a trunk network, which encodes the query location (x, t). The output is the inner product of the two latent vectors plus a scalar bias:
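```python
import jax.numpy as jnp
import flax.linen as nn

# BranchNet and TrunkNet are the MLP modules from the full tutorial code (not shown here).
class DeepONetCartesianProd(nn.Module):
    @nn.compact
    def __call__(self, u0, xt):
        # u0: (batch, n_sensors) initial conditions; xt: (n_points, 2) query coordinates
        branch_out = BranchNet(512, 512)(u0)  # (batch, 512)
        trunk_out = TrunkNet(512, 512)(xt)    # (n_points, 512)
        # Inner product over the latent dimension for every (sample, point) pair
        out = jnp.einsum('bp,np->bn', branch_out, trunk_out)  # (batch, n_points)
        bias = self.param('bias', nn.initializers.zeros, (1,))
        return out + bias
```

The einsum pairs every sample with every query point (a Cartesian product), so the full output field for the whole batch is produced in a single forward pass.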
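To see why the einsum yields the full field, note that `'bp,np->bn'` is the matrix product `branch_out @ trunk_out.T`. Entry (b, n) is the sum over k of b_k(u_b) t_k(y_n), which is the DeepONet approximation G(u)(y) ≈ Σ_k b_k(u) t_k(y) + b_0 evaluated at every query point. The NumPy check below uses random arrays in place of the network outputs (the sizes 8, 1600 and 512 are illustrative); `jax.numpy.einsum` accepts the same subscript string.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_points, p = 8, 1600, 512
branch_out = rng.standard_normal((batch, p))    # stands in for BranchNet(u0)
trunk_out = rng.standard_normal((n_points, p))  # stands in for TrunkNet(xt)
bias = 0.1                                      # stands in for the learned scalar bias

# Cartesian product: every sample b is paired with every query point n
out = np.einsum('bp,np->bn', branch_out, trunk_out) + bias

# Same as an explicit matrix product, and each entry is one branch-trunk inner product
assert np.allclose(out, branch_out @ trunk_out.T + bias)
assert np.isclose(out[3, 42], branch_out[3] @ trunk_out[42] + bias)
print(out.shape)  # (8, 1600)
```

Because the trunk input is just a coordinate array, `xt` can hold any set of points, which is where DeepONet's query flexibility comes from.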
Method 2 – POD‑DeepONet
POD‑DeepONet first applies Proper Orthogonal Decomposition (POD, equivalent to PCA of the training outputs) and keeps the 39 leading modes, which together explain 99.99 % of the variance. Each high‑dimensional output field is thus reduced to a 39‑dimensional coefficient vector. The branch network learns the mapping from the input function to these coefficients, and the full solution is reconstructed by combining the learned coefficients with the precomputed basis functions and adding the mean field:
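```python
import jax.numpy as jnp
import flax.linen as nn
from sklearn.decomposition import PCA

# PCA preprocessing (run once before training); y_train: (1000, 1600) flattened fields
pca = PCA(n_components=0.9999)              # keep enough modes for 99.99 % of the variance
pca.fit(y_train)
pod_basis = jnp.asarray(pca.components_.T)  # (1600, 39)
output_mean = jnp.asarray(pca.mean_)        # (1600,)
n_components = pod_basis.shape[1]           # 39

class PODDeepONet(nn.Module):
    pod_basis: jnp.ndarray    # (n_points, n_components), fixed, not trained
    output_mean: jnp.ndarray  # (n_points,)
    n_components: int

    @nn.compact
    def __call__(self, u0):
        branch_out = BranchNet(512, self.n_components)(u0)             # (batch, 39)
        out = jnp.einsum('bi,ni->bn', branch_out, self.pod_basis)      # (batch, 1600)
        # Dividing by n_components rescales the branch outputs (a normalisation choice)
        return out / self.n_components + self.output_mean[None, :]
```

Because the basis already captures the dominant variation of the solution space, the network only has to predict 39 coefficients per sample instead of 1,600 output values, which makes the learning task substantially easier.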
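scikit‑learn's PCA obtains this basis from a singular value decomposition of the mean‑centred snapshot matrix. The sketch below repeats the steps in plain NumPy so that the link between modes, explained variance and reconstruction is explicit. The mode‑selection rule imitates how PCA treats a fractional `n_components` (the smallest number of modes whose cumulative explained variance exceeds the threshold); the synthetic rank‑5 data and the helper name `pod_basis_from_snapshots` are illustrative assumptions.

```python
import numpy as np

def pod_basis_from_snapshots(y, energy=0.9999):
    """Return (basis, mean) with the fewest modes whose explained variance exceeds energy."""
    mean = y.mean(axis=0)
    # Rows of vt are the POD modes (principal directions) of the centred snapshots
    _, s, vt = np.linalg.svd(y - mean, full_matrices=False)
    explained = s**2 / np.sum(s**2)
    k = int(np.searchsorted(np.cumsum(explained), energy, side="right")) + 1
    k = min(k, len(s))
    return vt[:k].T, mean  # basis: (n_points, k), mean: (n_points,)

# Synthetic snapshots: 200 fields of 100 points lying in a 5-dimensional affine subspace
rng = np.random.default_rng(0)
y = rng.standard_normal((200, 5)) @ rng.standard_normal((5, 100)) + 3.0

basis, mean = pod_basis_from_snapshots(y)
coeffs = (y - mean) @ basis        # the targets a POD-DeepONet branch net would learn
y_rec = coeffs @ basis.T + mean    # reconstruction: coefficients times modes plus mean
print(basis.shape)                       # (100, 5)
print(np.max(np.abs(y - y_rec)) < 1e-8)  # True
```

In the tutorial's model the branch output is additionally divided by `n_components`, so the branch effectively learns `n_components * coeffs` rather than the raw coefficients.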
Method 3 – Fourier Neural Operator (FNO‑2D)
FNO lifts the input to a higher‑dimensional channel space, applies four Fourier layers, and finally projects back to the physical output. Each layer transforms the field with an FFT, keeps only the lowest modes1 × modes2 Fourier coefficients, multiplies them by learned complex weights, applies an inverse FFT, and adds a pointwise 1×1 bypass convolution. Because every retained mode couples all grid points, each layer has a global receptive field. The number of weights depends on the channel width and the number of retained modes, not on the grid size, so the parameter count stays moderate only for modest channel widths.
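Written out in the notation of Li et al. (2021), one Fourier layer updates the hidden field v_l as shown below. Here F is the FFT, R_l holds one learned complex C × C matrix per retained mode, W_l is the 1×1 bypass convolution and σ is the activation (ReLU here). The last line counts the spectral weights under the assumption, as in the reference FNO code, that two m1 × m2 corner blocks (positive and negative first‑axis frequencies) are kept in each of the L layers with C channels in and out.

```latex
v_{l+1}(x) = \sigma\!\left( W_l\, v_l(x) + \mathcal{F}^{-1}\!\left( R_l \cdot \mathcal{F} v_l \right)(x) \right),
\qquad l = 0, \dots, L-1

\left( R_l \cdot \mathcal{F} v_l \right)_k =
\begin{cases}
  R_{l,k}\, (\mathcal{F} v_l)_k, \quad R_{l,k} \in \mathbb{C}^{C \times C}, & k \in K \ (\text{two } m_1 \times m_2 \text{ corner blocks}) \\
  0, & \text{otherwise}
\end{cases}

N_{\text{spectral}} = L \cdot |K| \cdot C^2 = L \cdot 2\, m_1 m_2 \cdot C^2 \quad \text{complex weights}
```

For purely illustrative values L = 4, C = 64 and m1 = m2 = 16 this already gives 8,388,608 complex weights (16,777,216 real numbers), while each 1×1 bypass needs only C² + C = 4,160 parameters. The quadratic dependence on C is what drives the parameter growth.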
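```python
import jax.numpy as jnp
import flax.linen as nn

class SpectralConv2d(nn.Module):
    in_channels: int
    out_channels: int
    modes1: int  # retained frequencies in x direction
    modes2: int  # retained frequencies in t direction

    @nn.compact
    def __call__(self, x):
        # x: (batch, channels, height, width), real-valued
        batch, _, height, width = x.shape
        shape = (self.in_channels, self.out_channels, self.modes1, self.modes2)
        init = nn.initializers.uniform(scale=1.0 / (self.in_channels * self.out_channels))
        # Complex weights stored as separate real and imaginary parameter arrays
        weights1 = self.param('w1_re', init, shape) + 1j * self.param('w1_im', init, shape)
        weights2 = self.param('w2_re', init, shape) + 1j * self.param('w2_im', init, shape)
        x_ft = jnp.fft.rfft2(x)  # 2-D FFT -> (batch, in, height, width // 2 + 1)
        out_ft = jnp.zeros((batch, self.out_channels, height, width // 2 + 1), dtype=x_ft.dtype)
        # Low-frequency modes; both x-frequency corners are kept, as in the reference FNO code
        out_ft = out_ft.at[:, :, :self.modes1, :self.modes2].set(jnp.einsum(
            'bixy,ioxy->boxy', x_ft[:, :, :self.modes1, :self.modes2], weights1))
        out_ft = out_ft.at[:, :, -self.modes1:, :self.modes2].set(jnp.einsum(
            'bixy,ioxy->boxy', x_ft[:, :, -self.modes1:, :self.modes2], weights2))
        return jnp.fft.irfft2(out_ft, s=(height, width))  # back to physical space
```

The full FNO‑2D architecture follows the pattern Lifting → 4×(SpectralConv2d + 1×1 bypass convolution + ReLU) → Projection.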
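The layer pattern above can be sketched in NumPy to check the resolution‑invariance claim. The learned arrays have shapes that depend only on the channel count and the retained modes, so the same weights apply to a 40 × 40 and an 80 × 80 grid, and for a band‑limited input the two outputs coincide at the shared grid points. This is a forward pass only, not the tutorial's Flax implementation; the channel count, mode numbers, random weights and the names `fourier_layer` and `band_limited_input` are illustrative assumptions. The NumPy calls used here (`np.fft.rfft2`, `np.fft.irfft2`, `np.einsum`) have `jax.numpy` counterparts with the same names.

```python
import numpy as np

def fourier_layer(v, w1, w2, w_bypass, b_bypass):
    """One Fourier layer: ReLU(spectral conv + 1x1 bypass). v: (batch, C, H, W), real."""
    _, _, h, w = v.shape
    m1, m2 = w1.shape[2], w1.shape[3]
    v_ft = np.fft.rfft2(v)                                     # (batch, C, H, W // 2 + 1)
    out_ft = np.zeros((v.shape[0], w1.shape[1]) + v_ft.shape[2:], dtype=complex)
    # Learned complex weights on the positive and negative x-frequency corners
    out_ft[:, :, :m1, :m2] = np.einsum('bixy,ioxy->boxy', v_ft[:, :, :m1, :m2], w1)
    out_ft[:, :, -m1:, :m2] = np.einsum('bixy,ioxy->boxy', v_ft[:, :, -m1:, :m2], w2)
    spectral = np.fft.irfft2(out_ft, s=(h, w))                 # all other modes are zero
    bypass = np.einsum('io,bixy->boxy', w_bypass, v) + b_bypass[None, :, None, None]
    return np.maximum(spectral + bypass, 0.0)                  # ReLU

rng = np.random.default_rng(0)
C, m1, m2 = 4, 6, 6                                            # illustrative sizes

def random_complex(shape):
    return rng.standard_normal(shape) + 1j * rng.standard_normal(shape)

w1, w2 = random_complex((C, C, m1, m2)), random_complex((C, C, m1, m2))
w_bypass, b_bypass = rng.standard_normal((C, C)), rng.standard_normal(C)

def band_limited_input(n):
    """Smooth C-channel field on an n x n periodic grid; all frequencies are below 6."""
    s = np.arange(n) / n
    X, T = np.meshgrid(s, s, indexing="ij")
    base = np.sin(2 * np.pi * X) * np.cos(4 * np.pi * T) + 0.5 * np.cos(2 * np.pi * (3 * X + T))
    return np.stack([(c + 1) * base for c in range(C)])[None]  # (1, C, n, n)

coarse = fourier_layer(band_limited_input(40), w1, w2, w_bypass, b_bypass)
fine = fourier_layer(band_limited_input(80), w1, w2, w_bypass, b_bypass)
print(coarse.shape, fine.shape)                   # (1, 4, 40, 40) (1, 4, 80, 80)
print(np.allclose(fine[:, :, ::2, ::2], coarse))  # True
```

The exact agreement relies on the input's content lying inside the retained modes; for general inputs, aliasing on the coarser grid makes the agreement only approximate.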
Method 4 – Time‑stepping FNO (FNO‑1D)
FNO‑1D treats the time dimension autoregressively. A single‑step 1‑D FNO acts as a propagator that maps the solution at one time level to the next; during training the 39 time steps are unrolled, and during inference the model repeatedly feeds its own prediction back as the next input.
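```python
import jax.numpy as jnp

# `model` is the single-step FNO1d Flax module from the full tutorial code; its output is
# assumed to have shape (batch, 40), i.e. the solution at the next time level.
def autoregressive_predict(params, x_input):
    u_current = x_input[:, :, 0]  # initial condition (batch, 40)
    x_c = x_input[:, :, 1]        # spatial coordinate (batch, 40)
    predictions = [u_current]
    for t_step in range(39):      # 39 forward steps -> 40 time levels in total
        inp = jnp.stack([u_current, x_c], axis=-1)     # (batch, 40, 2)
        u_next = model.apply({'params': params}, inp)  # single-step FNO
        predictions.append(u_next)
        u_current = u_next        # feed the prediction back as the next input
    return jnp.stack(predictions, axis=-1)             # (batch, 40, 40)
```

Because the Python loop is unrolled during tracing, the computation graph contains 39 chained copies of the FNO step, and gradients must be backpropagated through all of them. This explains the much longer training time: 450.3 s versus 173.3 s for FNO‑2D, even though FNO‑1D has roughly 120× fewer parameters.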
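A loop‑free alternative is to express the rollout with `jax.lax.scan`, which traces the step function once instead of unrolling 39 copies. To keep this sketch runnable without JAX, it uses a small NumPy stand‑in, `scan`, that follows the reference semantics documented for `jax.lax.scan`: the step maps (carry, x) to (new carry, output), and the outputs are stacked along a new leading axis. In JAX code the call would be `jax.lax.scan(step, u0, None, length=39)` with the same step function. The step here is a toy exact propagator (a one‑cell periodic shift, i.e. c·Δt = Δx) standing in for the trained FNO; `scan`, `shift_step` and `rollout` are hypothetical names.

```python
import numpy as np

def scan(f, init, xs=None, length=None):
    """NumPy stand-in mirroring the documented reference semantics of jax.lax.scan."""
    if xs is None:
        xs = [None] * length
    carry, ys = init, []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)

def shift_step(u, _):
    """Toy propagator: exact advection by one grid cell per step (stands in for the FNO)."""
    u_next = np.roll(u, 1, axis=-1)  # (batch, 40)
    return u_next, u_next            # (new carry, per-step output)

def rollout(u0, n_steps=39):
    # In JAX: _, traj = jax.lax.scan(step, u0, None, length=n_steps)
    _, traj = scan(shift_step, u0, None, length=n_steps)  # (n_steps, batch, 40)
    traj = np.concatenate([u0[None], traj], axis=0)       # prepend the initial condition
    return np.moveaxis(traj, 0, -1)                       # (batch, 40, n_steps + 1)

u0 = (np.arange(40) < 10).astype(float)[None].repeat(3, axis=0)  # (3, 40) square waves
pred = rollout(u0)
print(pred.shape)  # (3, 40, 40)
```

The output layout matches `autoregressive_predict` (time on the last axis). Note that `scan` reduces tracing and compilation cost, but the backward pass still runs sequentially through every step.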
Core characteristics comparison
Core idea: DeepONet – branch‑trunk inner product; POD‑DeepONet – PCA basis + coefficient learning; FNO‑2D – 2‑D spectral convolution; FNO‑1D – 1‑D spectral convolution + autoregressive rollout.
Output form: DeepONet – one scalar per query point; POD‑DeepONet – full‑field vector; FNO‑2D – full 2‑D field; FNO‑1D – sequence of 1‑D fields.
Training scheme: DeepONet & POD‑DeepONet – full‑batch, 100 k epochs (one gradient step per epoch); FNO‑2D & FNO‑1D – mini‑batch (size 20), 500 epochs, which with 1,000 training samples amounts to 25,000 gradient steps.
Query flexibility: DeepONet – arbitrary points; all others – fixed spatial grid.
Resolution invariance: only the FNO family.
Parameter and efficiency analysis
DeepONet: 810,497 parameters, 46.1 s training (100 k epochs).
POD‑DeepONet: 40,999 parameters, 25.7 s training (100 k epochs) – 5.1 % of DeepONet's size.
FNO‑2D: 16,835,585 parameters, 173.3 s training (500 epochs) – about 410× the size of POD‑DeepONet.
FNO‑1D: 139,745 parameters, 450.3 s training (500 epochs) – moderate size, but the longest training time because of the autoregressive rollout.
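The reported counts and ratios can be checked with the dense‑layer formula (n_in + 1) × n_out summed over layers. The layer widths below are an inference, not stated in the source: a branch net with 40 sensor inputs and one hidden layer of 512 reproduces POD‑DeepONet's 40,999 exactly, and a 40→512→512 branch plus a 2→512→512→512 trunk plus the scalar bias reproduces DeepONet's 810,497 (swapping which network has the extra layer gives the same total). `dense_params` is a hypothetical helper.

```python
def dense_params(widths):
    """Weights plus biases of a fully connected net with the given layer widths."""
    return sum((n_in + 1) * n_out for n_in, n_out in zip(widths[:-1], widths[1:]))

pod_deeponet = dense_params([40, 512, 39])                                      # branch only
deeponet = dense_params([40, 512, 512]) + dense_params([2, 512, 512, 512]) + 1  # + bias
fno2d, fno1d = 16_835_585, 139_745                                              # as reported

print(pod_deeponet, deeponet)            # 40999 810497
print(f"{pod_deeponet / deeponet:.1%}")  # 5.1%
print(f"{fno2d / pod_deeponet:.1f}")     # 410.6
print(f"{fno2d / fno1d:.1f}")            # 120.5
```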
Experimental results
Accuracy vs. cost
POD‑DeepONet achieves the lowest test relative error (about 28× lower than standard DeepONet) with the shortest training time. FNO‑2D attains good accuracy, but at a massive parameter cost. FNO‑1D shows reasonable accuracy but accumulates error over the time steps.
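The source does not spell out its error metric. A common choice in the neural‑operator literature, and the one assumed in this sketch, is the relative L2 error of each test sample averaged over the test set; a "28× lower" comparison is then the ratio of two such averages. `mean_relative_l2` is a hypothetical helper name.

```python
import numpy as np

def mean_relative_l2(pred, true):
    """Mean over samples of ||pred - true||_2 / ||true||_2; arrays shaped (n_samples, n_points)."""
    num = np.linalg.norm(pred - true, axis=1)
    den = np.linalg.norm(true, axis=1)
    return float(np.mean(num / den))

true = np.array([[3.0, 4.0], [1.0, 0.0]])
pred = np.array([[3.0, 4.5], [1.0, 0.1]])
print(mean_relative_l2(pred, true))  # both samples have a relative error of 0.1
```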
Convergence behaviour
POD‑DeepONet stabilises after ~20 k training steps, while DeepONet converges very slowly even after 100 k steps. Both FNO variants descend quickly within the 500‑epoch mini‑batch regime.
Prediction visualisation
POD‑DeepONet reproduces the reference solution almost perfectly, with uniformly tiny error.
FNO‑2D reconstructs the overall shape well but exhibits slight Gibbs‑type oscillations near the square‑wave discontinuity.
FNO‑1D’s error grows toward later time steps because of autoregressive error propagation.
DeepONet shows noticeable diffusion around the wave front, reflecting limited expressive power for sharp discontinuities.
Efficiency‑accuracy trade‑off
If error is normalised by parameter count, FNO‑2D appears the most efficient, but only because its huge parameter budget shrinks the ratio; the metric rewards size rather than economy. POD‑DeepONet delivers the best overall cost‑performance ratio, achieving the lowest error with the fewest parameters and the shortest training time. DeepONet has the poorest time efficiency.
Conclusions
Data‑driven priors (PCA) are crucial for problems with low‑rank solution manifolds – POD‑DeepONet is the preferred choice.
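Global modelling via Fourier transforms (FNO) is the natural choice when resolution invariance is required, but the parameter count grows quickly with the channel width and the number of retained modes.
Autoregressive schemes (FNO‑1D) add significant training overhead because gradients must flow through all 39 rollout steps; replacing the Python loop with jax.lax.scan traces the step once instead of unrolling it and can reduce compilation cost.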
Method selection should consider the underlying solution structure, desired query flexibility, and hardware constraints.
References
[1] Lu, L., Meng, X., Cai, S., Mao, Z., Goswami, S., Zhang, Z., & Karniadakis, G. E. (2022). A comprehensive and fair comparison of two neural operators (with practical extensions) based on FAIR data. Computer Methods in Applied Mechanics and Engineering, 393, 114778.
[2] Lu, L., Jin, P., Pang, G., Zhang, Z., & Karniadakis, G. E. (2021). Learning nonlinear operators via DeepONet based on the universal approximation theorem of operators. Nature Machine Intelligence, 3(3), 218–229.
[3] Li, Z., Kovachki, N., Azizzadenesheli, K., Liu, B., Bhattacharya, K., Stuart, A., & Anandkumar, A. (2021). Fourier Neural Operator for Parametric Partial Differential Equations. International Conference on Learning Representations (ICLR).
[4] Chen, T., & Chen, H. (1995). Universal approximation to nonlinear operators by neural networks with arbitrary activation functions and its application to dynamical systems. IEEE Transactions on Neural Networks, 6(4), 911–917.
[5] Bradbury, J., Frostig, R., Hawkins, P., Johnson, M. J., Leary, C., Maclaurin, D., … & Zhang, Q. (2018). JAX: composable transformations of Python+NumPy programs. http://github.com/jax-ml/jax
This article has been distilled and summarized from source material, then republished for learning and reference.
AI Agent Research Hub
Sharing AI, intelligent agents, and cutting-edge scientific computing
