Master Data‑Driven Neural Operators: DeepONet, POD‑DeepONet, FNO‑2D and Time‑Stepping FNO‑1D in One Tutorial

This tutorial presents a JAX implementation and analysis of four neural‑operator methods (standard DeepONet, POD‑DeepONet, a 2‑D Fourier Neural Operator, and a time‑stepping 1‑D FNO) applied to the 1‑D advection equation. It compares their mathematical foundations, parameter counts, training efficiency, and prediction accuracy.

Problem definition

The benchmark is the one‑dimensional advection equation with periodic boundary conditions and a random square‑wave initial condition. The learning task is to approximate the solution operator that maps the initial condition (input function) to the full space‑time solution (output function), enabling a “train once, infer many times” workflow.
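To make the operator concrete, here is a minimal NumPy sketch of a single input/output pair. The unit domain, unit advection speed, and the square‑wave parameter ranges are assumptions chosen for illustration, and the 40 × 40 grid is inferred from the array shapes quoted later in this tutorial. The helper names square_wave and advect are hypothetical.

```python
import numpy as np

def square_wave(x, center, width, height):
    """Periodic square wave on [0, 1): `height` inside the pulse, 0 elsewhere."""
    # Signed periodic distance from the pulse centre, in [-0.5, 0.5)
    dist = (x - center + 0.5) % 1.0 - 0.5
    return np.where(np.abs(dist) <= width / 2, height, 0.0)

def advect(u0_fn, x, t, speed=1.0):
    """Exact solution of u_t + speed * u_x = 0: u(x, t) = u0(x - speed * t)."""
    X, T = np.meshgrid(x, t, indexing="ij")      # both (nx, nt)
    return u0_fn((X - speed * T) % 1.0)

rng = np.random.default_rng(0)
x = np.linspace(0.0, 1.0, 40, endpoint=False)    # periodic grid, right end excluded
t = np.linspace(0.0, 1.0, 40)
center = rng.uniform(0.3, 0.7)
width = rng.uniform(0.3, 0.6)
height = rng.uniform(1.0, 2.0)

u0_fn = lambda s: square_wave(s, center, width, height)
u0 = u0_fn(x)              # input function at 40 sensor points, shape (40,)
u = advect(u0_fn, x, t)    # full space-time solution, shape (40, 40)
```

Flattening u gives one 1600‑entry output row; the solution operator is the map from u0 to u.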

Method 1 – DeepONet

DeepONet builds on the universal approximation theorem for operators (Chen & Chen, 1995). It decomposes the operator into a branch network that encodes the input function sampled at sensor points and a trunk network that encodes the query location. The output is the inner product of the two latent vectors plus a bias:

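import jax.numpy as jnp
import flax.linen as nn

# BranchNet and TrunkNet are assumed to be MLP modules defined elsewhere (not shown here).
class DeepONetCartesianProd(nn.Module):
    @nn.compact
    def __call__(self, u0, xt):
        branch_out = BranchNet(512, 512)(u0)   # (batch, 512)
        trunk_out  = TrunkNet(512, 512)(xt)    # (n_points, 512)
        # Inner product of every branch vector with every trunk vector
        out = jnp.einsum('bp,np->bn', branch_out, trunk_out)   # (batch, n_points)
        bias = self.param('bias', nn.initializers.zeros, (1,))
        return out + bias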

The einsum evaluates the inner product for every (sample, query point) pair, i.e. over the Cartesian product of input samples and query points, producing the full output field in a single forward pass.
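The shape bookkeeping of that contraction can be checked in isolation. The sketch below uses NumPy with random stand‑ins for the two network outputs (np.einsum takes the same subscript string as jnp.einsum); the batch size of 4 is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, n_points, p = 4, 1600, 512

branch_out = rng.standard_normal((batch, p))      # one latent vector per input function
trunk_out = rng.standard_normal((n_points, p))    # one latent vector per (x, t) query

# Contract the shared latent axis p; keep the batch (b) and query (n) axes.
out = np.einsum('bp,np->bn', branch_out, trunk_out)    # (batch, n_points)

# The same numbers come out of a plain matrix product ...
same_as_matmul = np.allclose(out, branch_out @ trunk_out.T)
# ... and each entry is the dot product of one branch row with one trunk row.
same_as_dot = np.isclose(out[2, 7], branch_out[2] @ trunk_out[7])
```

Because the trunk output does not depend on the batch, it is evaluated once for all query points and reused for every sample.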

Method 2 – POD‑DeepONet

POD‑DeepONet first applies Proper Orthogonal Decomposition (POD, computed here via PCA) to the training outputs, retaining the 39 leading modes, which capture 99.99 % of the variance. The high‑dimensional output is thus reduced to a low‑dimensional coefficient vector. The branch network learns the mapping from the input function to these coefficients, and the full solution is reconstructed by multiplying the learned coefficients with the pre‑computed basis functions and adding back the training mean:

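# PCA preprocessing (run once before training)
from sklearn.decomposition import PCA

pca = PCA(n_components=0.9999)   # keep enough modes for 99.99 % of the variance
pca.fit(y_train)                 # y_train: (1000, 1600)
pod_basis = pca.components_.T    # (1600, 39)
output_mean = pca.mean_          # (1600,)

class PODDeepONet(nn.Module):
    # Declaring these as module fields is an assumption; the basis is fixed, not trained.
    pod_basis: jnp.ndarray       # (1600, 39)
    output_mean: jnp.ndarray     # (1600,)
    n_components: int            # 39 for this dataset

    @nn.compact
    def __call__(self, u0):
        branch_out = BranchNet(512, self.n_components)(u0)   # (batch, 39)
        # Expand the predicted coefficients in the POD basis
        out = jnp.einsum('bi,ni->bn', branch_out, self.pod_basis)   # (batch, 1600)
        return out / self.n_components + self.output_mean[None, :]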

Because the basis already captures the dominant variation of the solution space, the learning task becomes substantially easier.
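How a variance threshold turns into a mode count can be sketched with a plain SVD. The data below is a synthetic rank‑5 stand‑in, not the tutorial's dataset, and the SVD route is used only to keep the example self‑contained; it is not a reproduction of the scikit‑learn call above.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for y_train: 200 snapshots of a rank-5 field plus small noise
latent = rng.standard_normal((200, 5))
modes = rng.standard_normal((5, 1600))
y_train = latent @ modes + 1e-4 * rng.standard_normal((200, 1600))

mean = y_train.mean(axis=0)
_, s, vt = np.linalg.svd(y_train - mean, full_matrices=False)

# Fraction of the total variance captured by the leading k modes
explained = np.cumsum(s**2) / np.sum(s**2)
k = int(np.searchsorted(explained, 0.9999) + 1)   # smallest k reaching 99.99 %

pod_basis = vt[:k].T                      # (1600, k), orthonormal columns
coeffs = (y_train - mean) @ pod_basis     # (200, k): regression targets for the branch net
y_rec = coeffs @ pod_basis.T + mean       # reconstruction from k coefficients

rel_err = np.linalg.norm(y_rec - y_train) / np.linalg.norm(y_train)
```

On this toy data the threshold selects exactly the five planted modes, and the branch network would only need to predict five numbers per sample instead of 1600.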

Method 3 – Fourier Neural Operator (FNO‑2D)

FNO lifts the input to a higher‑dimensional channel space, applies four Fourier layers, and finally projects back to physical space. Each Fourier layer combines a spectral convolution (an FFT, a learned complex weight tensor applied to the retained low‑frequency modes, and an inverse FFT) with a bypass 1×1 convolution. Weights are learned only for the lowest modes1 × modes2 Fourier coefficients, which yields a global receptive field. The spectral weight count still scales as in_channels × out_channels × modes1 × modes2 per layer, so it stays moderate only for modest channel widths and mode counts.

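import jax

class SpectralConv2d(nn.Module):
    in_channels: int
    out_channels: int
    modes1: int   # retained frequencies in x direction
    modes2: int   # retained frequencies in t direction

    @nn.compact
    def __call__(self, x):
        # x: (batch, channels, height, width)
        batch, _, height, width = x.shape
        w_shape = (self.in_channels, self.out_channels, self.modes1, self.modes2)
        scale = 1.0 / (self.in_channels * self.out_channels)
        # Complex weights for the retained block (this initialiser is an assumption)
        def init(key, shape):
            re, im = jax.random.normal(key, (2,) + shape)
            return scale * (re + 1j * im)
        weights1 = self.param('weights1', init, w_shape)
        x_ft = jnp.fft.rfft2(x)                     # 2‑D FFT → frequency domain
        out_ft_low = jnp.einsum('bixy,ioxy->boxy',
                                 x_ft[:, :, :self.modes1, :self.modes2],
                                 weights1)   # linear transform of low‑freq modes
        # Zero-fill the discarded modes before transforming back
        out_ft = jnp.zeros((batch, self.out_channels, height, width // 2 + 1), x_ft.dtype)
        out_ft = out_ft.at[:, :, :self.modes1, :self.modes2].set(out_ft_low)
        return jnp.fft.irfft2(out_ft, s=(height, width))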

The full FNO‑2D architecture follows the pattern Lifting → 4×(SpectralConv2d + Bypass + ReLU) → Projection.
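The mode truncation inside a spectral convolution can be seen without any training. The NumPy sketch below follows the snippet above and keeps only the corner block of non‑negative frequencies (the reference FNO implementation additionally keeps the matching negative‑frequency block along the first transformed axis). The all‑ones weight tensor and the test signal are assumptions chosen so the effect is easy to read; in JAX the in‑place assignment would become out_ft.at[...].set(...).

```python
import numpy as np

def spectral_conv2d(x, weights, modes1, modes2):
    """x: (batch, c_in, H, W) real; weights: (c_in, c_out, modes1, modes2) complex."""
    batch, _, H, W = x.shape
    c_out = weights.shape[1]
    x_ft = np.fft.rfft2(x)                                   # (batch, c_in, H, W//2 + 1)
    low = np.einsum('bixy,ioxy->boxy', x_ft[:, :, :modes1, :modes2], weights)
    # Embed the transformed block in an otherwise empty spectrum
    out_ft = np.zeros((batch, c_out, H, W // 2 + 1), dtype=complex)
    out_ft[:, :, :modes1, :modes2] = low
    return np.fft.irfft2(out_ft, s=(H, W))

H = W = 40
idx = np.arange(W)
low_wave = np.cos(2 * np.pi * 2 * idx / W)      # wavenumber 2: inside the retained block
high_wave = np.cos(2 * np.pi * 15 * idx / W)    # wavenumber 15: outside it
x = np.broadcast_to(low_wave + high_wave, (1, 1, H, W))   # constant along the other axis

identity = np.ones((1, 1, 8, 8), dtype=complex)  # pass retained modes through unchanged
y = spectral_conv2d(x, identity, modes1=8, modes2=8)
```

With these weights the layer acts as a pure low‑pass filter: y contains the wavenumber‑2 component and nothing of the wavenumber‑15 component. The same truncation is consistent with the oscillations near sharp fronts that a band‑limited representation of a square wave tends to show.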

Method 4 – Time‑stepping FNO (FNO‑1D)

FNO‑1D treats the time dimension autoregressively. A single‑step 1‑D FNO acts as a propagator; during training the 39 time steps are unrolled, and during inference the model repeatedly feeds its own prediction as the next input.

def autoregressive_predict(params, x_input):
    u_current = x_input[:, :, 0]   # initial condition (batch, 40)
    x_c = x_input[:, :, 1]         # spatial coordinate (batch, 40)
    predictions = [u_current]
    for t_step in range(39):      # 39 forward steps
        inp = jnp.stack([u_current, x_c], axis=-1)   # (batch, 40, 2)
        u_next = model.apply({'params': params}, inp)  # single‑step FNO
        predictions.append(u_next)
        u_current = u_next
    return jnp.stack(predictions, axis=-1)  # (batch, 40, 40)

The full unrolled graph makes gradient propagation expensive, which explains the longer training time.
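Why errors compound in such a rollout can be illustrated with a stand‑in propagator instead of a trained network. In the NumPy sketch below, exact_step shifts the field by one grid cell per step, and noisy_step adds a hypothetical 1 % gain error; both are assumptions for illustration and not the tutorial's FNO.

```python
import numpy as np

def rollout(step_fn, u0, n_steps):
    """Apply step_fn repeatedly, feeding each prediction back in as the next input."""
    states = [u0]
    for _ in range(n_steps):
        states.append(step_fn(states[-1]))
    return np.stack(states, axis=-1)             # (batch, nx, n_steps + 1)

exact_step = lambda u: np.roll(u, 1, axis=1)             # shift by one cell per step
noisy_step = lambda u: 1.01 * np.roll(u, 1, axis=1)      # same shift, 1 % gain error

u0 = np.zeros((1, 40))
u0[0, 10:20] = 1.0                               # square pulse

truth = rollout(exact_step, u0, 39)              # (1, 40, 40)
pred = rollout(noisy_step, u0, 39)

# Relative error per time step; for this toy propagator it equals 1.01**k - 1
err = np.linalg.norm(pred - truth, axis=1) / np.linalg.norm(truth, axis=1)
```

A per‑step error of 1 % grows to roughly 47 % after 39 steps here, which is the mechanism behind the error growth reported for FNO‑1D at later time steps.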

Core characteristics comparison

Core idea: DeepONet – branch‑trunk inner product; POD‑DeepONet – PCA basis + coefficient learning; FNO‑2D – 2‑D spectral convolution; FNO‑1D – 1‑D spectral convolution + autoregressive rollout.

Output form: DeepONet – pointwise scalar; POD‑DeepONet – full‑field vector; FNO‑2D – full 2‑D field; FNO‑1D – sequential 1‑D fields.

Training scheme: DeepONet & POD‑DeepONet – full‑batch, 100 k epochs; FNO‑2D & FNO‑1D – mini‑batch (size 20), 500 epochs.

Query flexibility: DeepONet – arbitrary points; all others – fixed spatial grid.

Resolution invariance: only the FNO family.

Parameter and efficiency analysis

DeepONet: 810,497 parameters, 46.1 s training (100 k epochs).

POD‑DeepONet: 40,999 parameters, 25.7 s training (100 k epochs) – 5.1 % of DeepONet size.

FNO‑2D: 16,835,585 parameters, 173.3 s training (500 epochs) – about 410× POD‑DeepONet.

FNO‑1D: 139,745 parameters, 450.3 s training (500 epochs) – moderate size but high training time due to autoregression.
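The ratios quoted above follow directly from the reported counts and timings; a few lines of plain Python reproduce them and also give the wall‑clock cost per epoch. Only the numbers listed above are used.

```python
# Parameter counts, training times and epoch counts as reported above
params = {"DeepONet": 810_497, "POD-DeepONet": 40_999,
          "FNO-2D": 16_835_585, "FNO-1D": 139_745}
train_s = {"DeepONet": 46.1, "POD-DeepONet": 25.7, "FNO-2D": 173.3, "FNO-1D": 450.3}
epochs = {"DeepONet": 100_000, "POD-DeepONet": 100_000, "FNO-2D": 500, "FNO-1D": 500}

pod_vs_deeponet = params["POD-DeepONet"] / params["DeepONet"]   # about 0.051
fno2d_vs_pod = params["FNO-2D"] / params["POD-DeepONet"]        # between 410 and 411

# Milliseconds of training per epoch for each method
ms_per_epoch = {name: 1e3 * train_s[name] / epochs[name] for name in params}
```

The per‑epoch figures are not directly comparable across families: a full‑batch epoch is a single gradient step, whereas a mini‑batch epoch with batch size 20 consists of many steps.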

Experimental results

Accuracy vs. cost

POD‑DeepONet achieves the lowest test relative error (roughly 28× lower than standard DeepONet) with the shortest training time. FNO‑2D attains good accuracy but at a very large parameter cost. FNO‑1D is reasonably accurate but accumulates error over the time steps.
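The exact error definition is not spelled out here. A common choice in neural‑operator benchmarks is the relative L2 error averaged over test samples, and the NumPy sketch below assumes that definition; the function name relative_l2 is hypothetical.

```python
import numpy as np

def relative_l2(pred, true):
    """Mean over samples of ||pred - true||_2 / ||true||_2, computed per sample."""
    pred = pred.reshape(len(pred), -1)
    true = true.reshape(len(true), -1)
    per_sample = np.linalg.norm(pred - true, axis=1) / np.linalg.norm(true, axis=1)
    return float(np.mean(per_sample))

true = np.ones((3, 40, 40))     # three dummy space-time fields
pred = 1.02 * true              # uniform 2 % overshoot
error = relative_l2(pred, true)
```

Normalising by the norm of the reference makes the number independent of the amplitude of each square wave, so samples with tall and short pulses contribute equally.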

Convergence behaviour

POD‑DeepONet stabilises after ~20 k training steps, while DeepONet is still converging slowly at 100 k steps. Both FNO variants descend quickly within the 500‑epoch mini‑batch regime.

Prediction visualisation

POD‑DeepONet reproduces the reference solution almost perfectly, with uniformly tiny error.

FNO‑2D reconstructs the overall shape well but exhibits slight Gibbs‑type oscillations near the square‑wave discontinuity.

FNO‑1D’s error grows toward later time steps because of autoregressive error propagation.

DeepONet shows noticeable diffusion around the wave front, reflecting limited expressive power for sharp discontinuities.

Efficiency‑accuracy trade‑off

Measured as error per parameter, FNO‑2D looks the most efficient, but only because its huge parameter count shrinks the ratio; the metric rewards model size rather than accuracy. POD‑DeepONet delivers the best overall cost‑performance ratio, achieving the lowest error with the fewest parameters and the shortest training time. DeepONet provides the poorest time‑efficiency.

Conclusions

Data‑driven priors (PCA) are crucial for problems with low‑rank solution manifolds – POD‑DeepONet is the preferred choice.

Global modelling via Fourier transforms (FNO) excels when resolution‑invariance is required, at the expense of parameter explosion.

Autoregressive schemes (FNO‑1D) introduce significant training overhead; replacing the unrolled Python loop with jax.lax.scan can reduce compilation cost.

Method selection should consider the underlying solution structure, desired query flexibility, and hardware constraints.

References

[1] Lu, L., Meng, X., Cai, S., Mao, Z., Goswami, S., Zhang, Z., & Karniadakis, G. E. (2022). A comprehensive and fair comparison of two neural operators (with practical extensions) based on FAIR data. Computer Methods in Applied Mechanics and Engineering, 393, 114778.

[2] Lu, L., Jin, P., Pang, G., Zhang, Z., & Karniadakis, G. E. (2021). Learning nonlinear operators via DeepONet based on the universal approximation theorem of operators. Nature Machine Intelligence, 3(3), 218‑229.

[3] Li, Z., Kovachki, N., Azizzadenesheli, K., Liu, B., Bhattacharya, K., Stuart, A., & Anandkumar, A. (2021). Fourier Neural Operator for Parametric Partial Differential Equations. International Conference on Learning Representations (ICLR).

[4] Chen, T., & Chen, H. (1995). Universal approximation to nonlinear operators by neural networks with arbitrary activation functions and its application to dynamical systems. IEEE Transactions on Neural Networks, 6(4), 911‑917.

[5] Bradbury, J., Frostig, R., Hawkins, P., Johnson, M. J., Leary, C., Maclaurin, D., … & Zhang, Q. (2018). JAX: composable transformations of Python+NumPy programs. http://github.com/jax-ml/jax