Adaptive-Weight NTK-PINN Solves High-Frequency Wave Equation Using JAX

This tutorial explains how the Neural Tangent Kernel (NTK) perspective reveals the loss‑balance problem in Physics‑Informed Neural Networks (PINNs), introduces an NTK‑based adaptive‑weight algorithm, provides a full JAX implementation for a 1‑D high‑frequency wave equation, and shows that input normalisation dramatically improves accuracy while only modestly increasing training time.

Abstract

Physics‑Informed Neural Networks (PINNs) embed PDE residuals into a composite loss. Standard PINNs suffer from severe convergence‑rate imbalance because the gradient scales of the boundary/initial‑condition terms and the PDE‑residual term differ by several orders of magnitude. Wang et al. [1] showed that, in the infinite‑width limit, the trace ratio of the NTK sub‑matrices associated with each loss component determines its effective learning speed. An NTK‑adaptive‑weight algorithm rescales the loss coefficients so that all loss components converge at comparable rates. This tutorial reproduces the algorithm in JAX, applies it to a 1‑D high‑frequency wave equation, and quantifies the effect of input z‑score normalisation.

1. Introduction

1.1 PINN basics

PINNs approximate the solution \(u(t,x)\) of a PDE by a neural network \(\hat u_{\theta}(t,x)\) and minimise a weighted sum of four loss components: PDE residual, initial displacement, initial velocity, and boundary conditions.
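Written out for the wave problem treated below (homogeneous Dirichlet boundaries are assumed here for illustration; \(u_0\) and \(v_0\) denote the prescribed initial displacement and velocity), the composite loss reads

\[
\mathcal{L}(\theta)
= \frac{\lambda_r}{N_r}\sum_{i}\bigl(\hat u_{tt}(t_r^i,x_r^i)-c^{2}\hat u_{xx}(t_r^i,x_r^i)\bigr)^{2}
+ \frac{\lambda_{u}}{N_{0}}\sum_{i}\bigl(\hat u(0,x_0^i)-u_0(x_0^i)\bigr)^{2}
+ \frac{\lambda_{u_t}}{N_{0}}\sum_{i}\bigl(\hat u_t(0,x_0^i)-v_0(x_0^i)\bigr)^{2}
+ \frac{\lambda_{b}}{N_{b}}\sum_{i}\hat u(t_b^i,x_b^i)^{2}.
\]

The weights \(\lambda_i\) are exactly the coefficients that the NTK analysis below adapts.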

1.2 Multi‑task loss imbalance

The PDE residual involves second‑order derivatives whose gradients decay exponentially with network depth, while the boundary/initial‑condition terms involve only the zeroth‑order network output. Consequently, early training over‑fits the boundary conditions while the PDE residual converges very slowly, producing non‑physical solutions.

1.3 NTK perspective

Wang et al. [1] proved that, in the infinite‑width limit, training dynamics are equivalent to kernel regression with the Neural Tangent Kernel (NTK) \(K = JJ^{\top}\), where \(J\) is the Jacobian of the network output w.r.t. parameters. The eigenvalues of \(K\) directly control the convergence speed of each loss direction; the trace ratio of the sub‑matrices associated with each loss component therefore quantifies the imbalance.

2. Method and Theory

2.1 Wave‑equation benchmark

The benchmark PDE is \(u_{tt}=c^{2}u_{xx}\) on \([0,T]\times[0,X]\) with an exact solution that contains a low‑frequency mode and a high‑frequency mode whose spatial and temporal frequencies are four times larger. The total loss is a weighted sum of the four components described above.
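A concrete solution of this form (the amplitudes are illustrative, matching the description rather than quoted from the source) is

\[
u(t,x) = \sin(\pi x)\cos(c\pi t) + \tfrac{1}{2}\sin(4\pi x)\cos(4c\pi t),
\]

each mode of which satisfies \(u_{tt}=c^{2}u_{xx}\) exactly; the second mode carries four times the spatial and temporal frequency of the first.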

2.2 Neural Tangent Kernel basics

For a network \(f_{\theta}(\mathbf{x})\) the NTK is defined as \(K(\mathbf{x},\mathbf{x}') = J(\mathbf{x})J(\mathbf{x}')^{\top}\), where \(J(\mathbf{x}) = \partial f_{\theta}(\mathbf{x})/\partial \theta\). Under continuous‑time gradient descent with mean‑squared loss the residual evolves as \(\dot{\mathbf{r}} = -K\mathbf{r}\). Decomposing \(K\) into eigen‑pairs shows that each eigenvalue \(\lambda_i\) dictates the exponential decay rate of the corresponding error mode.
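Diagonalising \(K=\sum_i \lambda_i \mathbf v_i \mathbf v_i^{\top}\) solves the residual ODE in closed form, which makes the role of each eigenvalue explicit:

\[
\mathbf r(t) = \sum_i e^{-\lambda_i t}\,\bigl(\mathbf v_i^{\top}\mathbf r(0)\bigr)\,\mathbf v_i ,
\]

so error components along large‑eigenvalue directions vanish quickly, while those along small‑eigenvalue directions barely move.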

2.3 NTK for PINNs

Because the PINN loss involves several distinct output operators, the full NTK matrix has a block structure. The diagonal blocks \(K_u\), \(K_{u_t}\), \(K_r\) are built from the Jacobians of, respectively, the direct network output (covering both the initial‑displacement and boundary terms), the initial‑velocity output \(\hat u_t\), and the PDE‑residual output. Their traces \(\mathrm{tr}(K_i)\) determine the effective learning speed of each loss component.

2.4 Adaptive‑weight algorithm

To enforce equal effective convergence rates we require the weighted traces to coincide: \(\lambda_u\,\mathrm{tr}(K_u) = \lambda_{u_t}\,\mathrm{tr}(K_{u_t}) = \lambda_r\,\mathrm{tr}(K_r)\). The NTK‑based convergence analysis then yields the closed‑form weight update

\[
\lambda_i \;=\; \frac{\mathrm{tr}(K_u)+\mathrm{tr}(K_{u_t})+\mathrm{tr}(K_r)}{\mathrm{tr}(K_i)}
\;=\; \frac{\mathrm{tr}(K)}{\mathrm{tr}(K_i)},
\qquad i \in \{u,\,u_t,\,r\},
\]

so that every weighted trace equals the total trace \(\mathrm{tr}(K)\).

Training procedure (performed every step unless noted; a runnable sketch follows the list):

1. Sample a batch of interior (PDE), initial‑condition and boundary points.
2. Compute the four loss components and back‑propagate.
3. Every 100 steps, sample NTK points and compute the Jacobians.
4. Assemble the three NTK sub‑matrices and compute their traces.
5. Update the three weights via \(\lambda_i = \mathrm{tr}(K)/\mathrm{tr}(K_i)\).
6. Repeat until convergence.
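A minimal training‑loop skeleton for this schedule (a sketch, not the source's exact script; the helper names listed in the leading comment are assumptions):

import jax
import optax

# Minimal sketch of the adaptive-weight schedule. `params`, `num_steps`,
# `sample_batch`, `sample_ntk_points`, `loss_components` and
# `compute_ntk_diag_blocks` are assumed to be defined as in the rest of
# this tutorial (the initial-displacement and boundary terms are grouped
# into l_u, matching the NTK block structure).
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)
lam_u = lam_ut = lam_r = 1.0  # all weights start at 1

for step in range(num_steps):
    batch = sample_batch(step)

    def total_loss(p):
        l_u, l_ut, l_r = loss_components(p, batch)
        return lam_u * l_u + lam_ut * l_ut + lam_r * l_r

    grads = jax.grad(total_loss)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    if step % 100 == 0:  # NTK update interval
        K_u, K_ut, K_r = compute_ntk_diag_blocks(params, *sample_ntk_points(step))
        tr_u, tr_ut, tr_r = (float(K.trace()) for K in (K_u, K_ut, K_r))
        total = tr_u + tr_ut + tr_r
        lam_u, lam_ut, lam_r = total / tr_u, total / tr_ut, total / tr_r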

2.5 Input normalisation and chain‑rule correction

Physical coordinates \((t,x)\) are transformed to z‑score normalised coordinates \((t_{norm},x_{norm})\). Because automatic differentiation in JAX operates on the normalised coordinates, the PDE residual must be multiplied by the scaling factors \(\sigma_T^{-2}\) and \(\sigma_X^{-2}\) to recover the physical derivatives.
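Explicitly, with \(t_{norm}=(t-\mu_T)/\sigma_T\) and \(x_{norm}=(x-\mu_X)/\sigma_X\), the chain rule gives

\[
\frac{\partial^{2} u}{\partial t^{2}}
= \frac{1}{\sigma_T^{2}}\,\frac{\partial^{2} u}{\partial t_{norm}^{2}},
\qquad
\frac{\partial^{2} u}{\partial x^{2}}
= \frac{1}{\sigma_X^{2}}\,\frac{\partial^{2} u}{\partial x_{norm}^{2}},
\]

which is exactly the rescaling applied in the residual code of Section 3.2.1.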

3. JAX implementation

3.1 Network architecture

A four‑layer fully‑connected network with layer sizes [2, 500, 500, 500, 1], tanh activation and Xavier initialisation is used. The total number of trainable parameters is 502 001.

Figure: Network architecture and dimension flow.
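A minimal plain‑JAX sketch of such a network, consistent with the net_u_single calls used below (initialisation details beyond "Xavier" are our assumptions):

import jax
import jax.numpy as jnp

LAYERS = [2, 500, 500, 500, 1]

def init_params(key, layers=LAYERS):
    # Xavier (Glorot) initialisation, one (W, b) pair per dense layer.
    params = []
    for d_in, d_out in zip(layers[:-1], layers[1:]):
        key, sub = jax.random.split(key)
        scale = jnp.sqrt(2.0 / (d_in + d_out))
        params.append((scale * jax.random.normal(sub, (d_in, d_out)),
                       jnp.zeros(d_out)))
    return params

def net_u_single(params, t_norm, x_norm):
    # Forward pass for a single normalised (t, x) point, tanh activations.
    h = jnp.array([t_norm, x_norm])
    for W, b in params[:-1]:
        h = jnp.tanh(h @ W + b)
    W, b = params[-1]
    return (h @ W + b)[0]  # scalar prediction u_hat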

3.2 Core code blocks

3.2.1 PDE residual

from jax import grad

def net_residual_single(params, t_norm, x_norm):
    # Build derivative functions w.r.t. the normalised coordinates.
    du_dt_norm = grad(net_u_single, argnums=1)
    du_dx_norm = grad(net_u_single, argnums=2)
    d2u_dt2_norm = grad(du_dt_norm, argnums=1)
    d2u_dx2_norm = grad(du_dx_norm, argnums=2)
    # Chain-rule correction back to physical derivatives (Section 2.5).
    u_tt_phys = d2u_dt2_norm(params, t_norm, x_norm) / (SIGMA_T ** 2)
    u_xx_phys = d2u_dx2_norm(params, t_norm, x_norm) / (SIGMA_X ** 2)
    return u_tt_phys - C_PARAM ** 2 * u_xx_phys

The function computes second‑order derivatives on the normalised coordinates, then rescales them to physical space using the factors \(\sigma_T^{-2}\) and \(\sigma_X^{-2}\).

3.2.2 Jacobian computation

from jax import jacrev
from jax.flatten_util import ravel_pytree

def compute_jacobian_r(params, t_pts, x_pts):
    flat_params, unravel = ravel_pytree(params)  # PyTree -> flat vector
    def f_flat(fp):
        return net_residual_batch(unravel(fp), t_pts, x_pts)
    return jacrev(f_flat)(flat_params)  # dense (N_r, P) Jacobian

ravel_pytree flattens the parameter PyTree into a single vector; jacrev then computes the Jacobian via reverse‑mode AD, which is efficient here because the number of sampled points is far smaller than the number of parameters.

3.2.3 NTK assembly

def compute_ntk_diag_blocks(params, t_bc_n, x_bc_n, t_ic_n, x_ic_n, t_r_n, x_r_n):
    J_u  = compute_jacobian_u(params, t_bc_n, x_bc_n)   # (N_u, P)
    J_ut = compute_jacobian_ut(params, t_ic_n, x_ic_n)  # (N_ut, P)
    J_r  = compute_jacobian_r(params, t_r_n, x_r_n)    # (N_r, P)
    K_u  = J_u @ J_u.T   # (N_u, N_u)
    K_ut = J_ut @ J_ut.T # (N_ut, N_ut)
    K_r  = J_r @ J_r.T   # (N_r, N_r)
    return K_u, K_ut, K_r

3.2.4 Adaptive weight update

import numpy as np  # traces are computed on host (NumPy) copies of the NTK blocks

trace_K_u  = np.trace(K_u_np)
trace_K_ut = np.trace(K_ut_np)
trace_K_r  = np.trace(K_r_np)
trace_total = trace_K_u + trace_K_ut + trace_K_r
if trace_K_u > 0 and trace_K_ut > 0 and trace_K_r > 0:
    lam_u  = float(trace_total / trace_K_u)
    lam_ut = float(trace_total / trace_K_ut)
    lam_r  = float(trace_total / trace_K_r)

The safety check prevents division by zero.
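The updated weights then scale the corresponding terms of the composite loss on subsequent steps, e.g. (a sketch; loss_u, loss_ut and loss_r denote the unweighted mean‑squared components):

def total_loss(params, batch):
    # Composite loss with the current NTK-adaptive weights.
    return (lam_u * loss_u(params, batch)
            + lam_ut * loss_ut(params, batch)
            + lam_r * loss_r(params, batch))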

4. Experiments

4.1 Benchmark: 1‑D wave equation

The exact solution contains a low‑frequency mode and a high‑frequency mode whose spatial and temporal frequencies are four times larger. Standard PINNs struggle to capture the high‑frequency component.

4.2 Settings

Network: [2,500,500,500,1] (502 001 parameters)

Optimizer: Adam with exponential decay (decay factor 0.9 every 1 000 steps)

Training steps: 80 001 (including step 0)

Batch size: 300

NTK update interval: every 100 steps, NTK sampled on 300 points

Input normalisation: z‑score (mean and std computed from uniformly sampled points)
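The learning‑rate schedule can be reproduced with optax's built‑in exponential decay (the initial learning rate of 1e‑3 and the staircase flag are assumptions; the source states only the decay factor and interval):

import optax

# Multiply the learning rate by 0.9 every 1 000 steps.
schedule = optax.exponential_decay(
    init_value=1e-3,       # assumed; not stated in the source
    transition_steps=1000,
    decay_rate=0.9,
    staircase=True,        # step-wise decay, matching "every 1 000 steps"
)
optimizer = optax.adam(learning_rate=schedule)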

4.3 Results

4.3.1 Loss curves and L2 error

Figure: Loss curves and relative L2 error.

With NTK‑adaptive weights the PDE residual loss (blue) and the boundary loss (orange) stay at comparable magnitudes, preventing one term from dominating. The relative L2 error begins to drop steadily after ~20 k steps and reaches the final reported value (≈4.8 × 10⁻³).

4.3.2 Prediction visualisation

Figure: Prediction comparison.

The NTK‑PINN solution reproduces both low‑ and high‑frequency modes; the largest absolute errors appear near the peaks of the high‑frequency component.

4.3.3 NTK eigenvalue evolution

Figure: NTK eigenvalue evolution.

The eigenvalue spectrum of the PDE‑residual block decays fastest, while the boundary/initial‑condition blocks retain larger eigenvalues, confirming the theoretical cause of the imbalance.

4.3.4 Adaptive weight dynamics

Figure: Adaptive weight dynamics.

All three weights are initialised to 1. Because the PDE‑residual NTK trace dominates the total trace, \(\lambda_r\) remains near 1 after the first update, while the weights for the boundary/initial‑condition terms jump to tens of times larger to compensate for their smaller NTK traces, then gradually stabilise.

4.3.5 Impact of normalisation

Best L2 error reduced by 54 % for the wave benchmark and by 68 % for a Poisson benchmark.

Training time increased by ~17 % for the wave case (588 s → 688 s) and ~9 % for the Poisson case (41.5 s → 45.1 s), mainly due to the extra normalisation calculations.

5. Conclusions and Outlook

NTK‑based adaptive weighting successfully balances the multi‑task loss in PINNs, enabling accurate solution of high‑frequency PDEs.

Simple z‑score normalisation of inputs yields substantial accuracy gains without changing the network architecture.

The NTK eigenvalue spectra provide a diagnostic tool for understanding and correcting loss imbalance.

Limitations: assembling the NTK scales quadratically with the number of sampled points and linearly with the number of parameters, and the NTK theory assumes infinite width, which holds only approximately for finite networks.

Future directions include efficient NTK approximations (e.g., Hutchinson trace estimator), comparison with other adaptive schemes (GradNorm, causal weighting), extension to higher‑dimensional and nonlinear PDEs, and theoretical analysis of finite‑width effects.

6. Frequently Asked Questions

Q1: NTK computation runs out of memory

Each Jacobian has size \(N \times P\). For the wave benchmark \(N\approx300\) and \(P=502{,}001\), a single Jacobian occupies ~572 MB in single precision. Three Jacobians plus the NTK matrices quickly exceed GPU memory.

Solutions:

Reduce the NTK sampling size (parameter KERNEL_SIZE).

Use jax.checkpoint to recompute intermediates instead of storing them.

Replace explicit NTK computation with stochastic trace estimators.
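For the last option, note that \(\mathrm{tr}(K)=\mathrm{tr}(JJ^{\top})\) equals the expected squared norm of \(J^{\top}\mathbf v\) for random probe vectors \(\mathbf v\), and a VJP delivers \(J^{\top}\mathbf v\) without ever materialising the \(N \times P\) Jacobian. A minimal Hutchinson‑style sketch (function and argument names are ours):

import jax
import jax.numpy as jnp

def ntk_trace_hutchinson(f_flat, flat_params, out_dim, key, num_probes=8):
    # Unbiased estimate of tr(J J^T) via E_v[ ||J^T v||^2 ] with
    # Rademacher probes v; uses VJPs instead of the full Jacobian.
    _, vjp_fn = jax.vjp(f_flat, flat_params)

    def single_probe(k):
        v = jax.random.rademacher(k, (out_dim,), dtype=jnp.float32)
        (jt_v,) = vjp_fn(v)
        return jnp.sum(jt_v ** 2)

    keys = jax.random.split(key, num_probes)
    return jnp.mean(jax.vmap(single_probe)(keys))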

Q2: L2 error does not decrease in early training

The adaptive weights need a few thousand steps to stabilise; large early fluctuations can cause alternating dominance of loss components.

Solutions:

Increase the total number of training steps.

Lower the learning rate to reduce weight oscillations.

Apply smoothing to the weight updates (e.g., exponential moving average).
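For the smoothing option, an exponential moving average of the raw trace‑ratio weights is a one‑line change (the factor 0.9 is a hypothetical choice):

BETA = 0.9  # hypothetical smoothing factor

def smooth_weight(lam_old, lam_new, beta=BETA):
    # Exponential moving average damps abrupt jumps in the adaptive weights.
    return beta * lam_old + (1.0 - beta) * lam_new

lam_u = smooth_weight(lam_u, trace_total / trace_K_u)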

Q3: Gradient values are incorrect after normalisation

Forgetting to multiply the physical derivatives by the normalisation correction factors leads to wrong residuals.

Solution: Follow the formulas in Section 2.5 and verify each derivative with a simple analytical test case.

Q4: When to use jax.jacrev vs jax.jacfwd?

Use jax.jacrev (reverse‑mode) when the number of sampled points is far smaller than the number of parameters (the usual PINN setting).

Use jax.jacfwd (forward‑mode) when the output dimension exceeds the parameter dimension.

Q5: Why Adam instead of SGD?

Adam’s per‑parameter adaptive learning rates complement the NTK‑adaptive weighting, Adam is more widely used in the PINN literature, and it helps avoid local minima in the multi‑frequency wave problem.

Tags: JAX, PINN, NTK, adaptive weighting, high-frequency PDE