How NTK Adaptive Weighting and Multi‑Scale Fourier Features Enable PINNs to Solve High‑Frequency PDEs

This tutorial explains why standard physics‑informed neural networks fail on high‑frequency partial differential equations due to spectral bias, and demonstrates how random Fourier feature embeddings, multi‑scale concatenation or spatio‑temporal separation, and Neural Tangent Kernel‑based adaptive loss weighting together overcome the bias and achieve accurate, stable solutions for heat, Poisson, and wave equations using JAX.

AI Agent Research Hub

Standard PINNs suffer from spectral bias: fully‑connected networks learn low‑frequency components first, so the high‑frequency parts of a PDE solution converge very slowly or stagnate. The article first reviews the mathematical origin of this bias through the NTK eigenvalue spectrum, in which high‑frequency eigenfunctions have small eigenvalues and therefore slow convergence rates, and cites Wang et al. (2021) and Rahaman et al.
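
The gradient-flow picture behind this argument can be written out explicitly. This is a sketch of the standard linearized NTK analysis. It assumes a squared-error loss over N training points with targets y, gradient flow with learning rate η, and a kernel K = JJ^T (J is the Jacobian of the N network outputs with respect to the parameters) that stays approximately constant during training, as in the wide-network limit. τ denotes training time, to avoid a clash with the PDE time t:

```latex
\frac{d\,\mathbf{u}(\tau)}{d\tau} = -\eta\, K \bigl(\mathbf{u}(\tau) - \mathbf{y}\bigr),
\qquad K = Q \Lambda Q^{\top},
\qquad
Q^{\top}\bigl(\mathbf{u}(\tau) - \mathbf{y}\bigr) = e^{-\eta \Lambda \tau}\, Q^{\top}\bigl(\mathbf{u}(0) - \mathbf{y}\bigr).
```

The error component along the i-th eigenvector of K therefore decays like e^{-η λ_i τ}. For standard tanh networks, the leading eigenvectors are smooth, low-frequency functions. High-frequency eigenvectors have small λ_i, so those components of the solution are learned very slowly.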

Random Fourier feature (RFF) mappings, γ(x) = [sin(xB), cos(xB)], where the entries of B are typically drawn from N(0, σ²), transform the input space and give the network a tunable bandwidth set by σ. With appropriately chosen frequency scales, the effective NTK becomes more uniform across frequencies, which mitigates the bias.
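
The bandwidth effect can be checked numerically. The sketch below assumes Gaussian-sampled frequencies, a common choice; the article's own initialisation of W_ff is not shown. It builds the embedding and measures the normalised inner product between the embeddings of two points 0.1 apart. For Gaussian B this inner product approximates exp(-σ²δ²/2), where δ is the distance between the points. A larger σ therefore decorrelates nearby points, which lets the downstream network represent sharper, higher-frequency variation. The helper names init_fourier_matrix, fourier_embed and embedding_kernel are hypothetical.

```python
import jax
import jax.numpy as jnp

def init_fourier_matrix(key, in_dim, num_features, sigma):
    # Entries drawn i.i.d. from N(0, sigma^2); sigma sets the embedding bandwidth
    return sigma * jax.random.normal(key, (in_dim, num_features))

def fourier_embed(x, B):
    # x: (n, in_dim) -> gamma(x) = [sin(xB), cos(xB)] with shape (n, 2 * num_features)
    z = x @ B
    return jnp.concatenate([jnp.sin(z), jnp.cos(z)], axis=1)

def embedding_kernel(x1, x2, B):
    # <gamma(x1), gamma(x2)> / m = mean_j cos((x1 - x2) . b_j);
    # for Gaussian B its expectation is exp(-sigma^2 |x1 - x2|^2 / 2)
    g1 = fourier_embed(x1.reshape(1, -1), B)
    g2 = fourier_embed(x2.reshape(1, -1), B)
    return float((g1 @ g2.T)[0, 0] / B.shape[1])

key = jax.random.PRNGKey(0)
B_low = init_fourier_matrix(key, 1, 256, sigma=1.0)    # narrow bandwidth
B_high = init_fourier_matrix(key, 1, 256, sigma=50.0)  # same draw scaled by 50: wide bandwidth

x = jnp.linspace(0.0, 1.0, 5).reshape(-1, 1)
gamma = fourier_embed(x, B_high)
print(gamma.shape)  # (5, 512)

# Two points 0.1 apart: strongly correlated at sigma = 1, nearly decorrelated at sigma = 50
x1, x2 = jnp.array([0.3]), jnp.array([0.4])
k_low = embedding_kernel(x1, x2, B_low)
k_high = embedding_kernel(x1, x2, B_high)
```

With σ = 1 the value k_low should come out close to 1, while k_high should be close to 0. A σ that is too large for a smooth target tends to give noisy fits, which is one reason a single bandwidth rarely suits every problem.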

Three embedding strategies are compared:

FF (single‑scale): a single Fourier matrix applied to all input coordinates.

mFF (multi‑scale concatenation): multiple Fourier branches with different bandwidths are concatenated before the final linear layer.

ST_FF / ST_MFF (spatio‑temporal separation): time and space coordinates are embedded independently. ST_FF passes both embeddings through the same hidden layers and multiplies the resulting features element‑wise; ST_MFF concatenates the features of several temporal scales and then multiplies them with the spatial features.
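
The mFF strategy can be sketched in the same style as apply_ff. This minimal version rests on stated assumptions: the branches share one stack of hidden layers, and the final layer takes the concatenated branch outputs. The helper init_layers, the widths and the scales σ ∈ {1, 20} are hypothetical illustrative choices, not the article's settings.

```python
import jax
import jax.numpy as jnp

def init_layers(key, sizes):
    # Hypothetical helper: (W, b) pairs with variance-scaled normal weights
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (d_in, d_out)) * jnp.sqrt(2.0 / (d_in + d_out)),
             jnp.zeros((d_out,)))
            for k, d_in, d_out in zip(keys, sizes[:-1], sizes[1:])]

def apply_mff(params, W_list, x):
    H0 = x.reshape(1, -1)
    branches = []
    for W_ff in W_list:
        # Fourier embedding at this branch's bandwidth, shape (1, 2m)
        H = jnp.concatenate([jnp.sin(H0 @ W_ff), jnp.cos(H0 @ W_ff)], axis=1)
        # Shared tanh hidden layers (assumption: same weights for every branch)
        for (W, b) in params[:-1]:
            H = jnp.tanh(H @ W + b)
        branches.append(H)
    # Concatenate all branch features, then apply the final linear layer
    H = jnp.concatenate(branches, axis=1)
    W_last, b_last = params[-1]
    return (H @ W_last + b_last)[0, 0]

# Illustrative 1-D setup (e.g. a Poisson problem): two scales, sigma = 1 and 20
key_h, key_o, key_1, key_2 = jax.random.split(jax.random.PRNGKey(0), 4)
m, width, sigmas = 64, 50, (1.0, 20.0)
W_list = [s * jax.random.normal(k, (1, m)) for s, k in zip(sigmas, (key_1, key_2))]
hidden = init_layers(key_h, [2 * m, width, width])
out_layer = init_layers(key_o, [len(sigmas) * width, 1])
params = hidden + out_layer

xs = jnp.linspace(0.0, 1.0, 11).reshape(-1, 1)
us = jax.vmap(lambda x: apply_mff(params, W_list, x))(xs)
```

Because the hidden layers are shared, the parameter count stays close to that of the single-scale network; only the output layer grows with the number of scales.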

All implementations use JAX; the core forward‑pass code is shown below.

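import jax.numpy as jnp

def apply_ff(params, W_ff, tx):
    # tx: one input point, e.g. (t, x); W_ff has shape (input_dim, m)
    H = tx.reshape(1, -1)
    # Fourier embedding [sin(tx W_ff), cos(tx W_ff)], shape (1, 2m)
    H = jnp.concatenate([jnp.sin(H @ W_ff), jnp.cos(H @ W_ff)], axis=1)
    # tanh hidden layers; params is a list of (W, b) pairs
    for (W, b) in params[:-1]:
        H = jnp.tanh(H @ W + b)
    # Linear output layer; return the scalar prediction u(tx)
    W_last, b_last = params[-1]
    return (H @ W_last + b_last)[0, 0]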
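
The next sketch shows how the scalar output of apply_ff feeds a physics loss. It differentiates the network with jax.grad and jax.hessian to form a heat-equation residual u_t - k u_xx at a batch of points. The diffusivity k = 0.01, the initialisation, the layer sizes and the helper heat_residual are hypothetical; the article's actual heat-equation setup is not given here.

```python
import jax
import jax.numpy as jnp

def apply_ff(params, W_ff, tx):
    # Same forward pass as above: Fourier embedding, tanh layers, linear output
    H = tx.reshape(1, -1)
    H = jnp.concatenate([jnp.sin(H @ W_ff), jnp.cos(H @ W_ff)], axis=1)
    for (W, b) in params[:-1]:
        H = jnp.tanh(H @ W + b)
    W_last, b_last = params[-1]
    return (H @ W_last + b_last)[0, 0]

def heat_residual(params, W_ff, tx, k=0.01):
    # Residual u_t - k * u_xx at one point tx = (t, x); k is a hypothetical diffusivity
    u = lambda z: apply_ff(params, W_ff, z)
    u_t = jax.grad(u)(tx)[0]            # the gradient is (u_t, u_x)
    u_xx = jax.hessian(u)(tx)[1, 1]     # second derivative in x
    return u_t - k * u_xx

# Illustrative initialisation: 32 Fourier features, two tanh layers of width 32
keys = jax.random.split(jax.random.PRNGKey(0), 4)
m, width = 32, 32
W_ff = 10.0 * jax.random.normal(keys[0], (2, m))
sizes = [2 * m, width, width, 1]
params = [(jax.random.normal(key, (d_in, d_out)) / jnp.sqrt(d_in), jnp.zeros((d_out,)))
          for key, d_in, d_out in zip(keys[1:], sizes[:-1], sizes[1:])]

# Mean squared residual over a few collocation points, vectorised with vmap
pts = jnp.array([[0.0, 0.25], [0.5, 0.5], [1.0, 0.75]])
res = jax.vmap(lambda tx: heat_residual(params, W_ff, tx))(pts)
loss_r = jnp.mean(res ** 2)
```

Because apply_ff returns a scalar, jax.grad and jax.hessian apply to it directly. For a wave equation, the same Hessian provides u_tt as entry [0, 0].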

For spatio‑temporal separation:

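import jax.numpy as jnp

def apply_st_ff(params, W_t, W_x, tx):
    # Split the input into time and space coordinates, each shaped (1, 1)
    t_val = tx[0:1].reshape(1, 1)
    x_val = tx[1:2].reshape(1, 1)
    # Separate Fourier embeddings; W_t and W_x need the same number of columns
    H_t = jnp.concatenate([jnp.sin(t_val @ W_t), jnp.cos(t_val @ W_t)], axis=1)
    H_x = jnp.concatenate([jnp.sin(x_val @ W_x), jnp.cos(x_val @ W_x)], axis=1)
    # Both embeddings pass through the same (shared) tanh hidden layers
    for (W, b) in params[:-1]:
        H_t = jnp.tanh(H_t @ W + b)
        H_x = jnp.tanh(H_x @ W + b)
    # Fuse temporal and spatial features by element-wise multiplication
    H = H_t * H_x
    W_last, b_last = params[-1]
    return (H @ W_last + b_last)[0, 0]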
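
The text describes ST_MFF only briefly. The sketch below follows one reading of it, under that assumption. The shared hidden layers are applied to each temporal scale, and the resulting temporal features are concatenated. They are then multiplied element-wise with the spatial features tiled to the same width. This is equivalent to multiplying every temporal branch by the spatial features and concatenating the products. The helpers embed and mlp_hidden, the sizes and the scales are hypothetical.

```python
import jax
import jax.numpy as jnp

def embed(z, W):
    # Fourier embedding of a (1, 1) coordinate: [sin(zW), cos(zW)]
    return jnp.concatenate([jnp.sin(z @ W), jnp.cos(z @ W)], axis=1)

def mlp_hidden(params, H):
    # Shared tanh hidden layers (every entry except the output layer)
    for (W, b) in params[:-1]:
        H = jnp.tanh(H @ W + b)
    return H

def apply_st_mff(params, W_t_list, W_x, tx):
    t_val = tx[0:1].reshape(1, 1)
    x_val = tx[1:2].reshape(1, 1)
    # Concatenated hidden features of all K temporal scales: (1, K * width)
    H_t = jnp.concatenate([mlp_hidden(params, embed(t_val, W_t)) for W_t in W_t_list], axis=1)
    # Spatial features (1, width), tiled K times so the shapes match
    H_x = jnp.tile(mlp_hidden(params, embed(x_val, W_x)), (1, len(W_t_list)))
    H = H_t * H_x
    W_last, b_last = params[-1]
    return (H @ W_last + b_last)[0, 0]

# Illustrative sizes: two temporal scales, one spatial scale
keys = jax.random.split(jax.random.PRNGKey(1), 6)
m, width = 16, 32
W_t_list = [1.0 * jax.random.normal(keys[0], (1, m)), 10.0 * jax.random.normal(keys[1], (1, m))]
W_x = 10.0 * jax.random.normal(keys[2], (1, m))
params = [
    (jax.random.normal(keys[3], (2 * m, width)) / jnp.sqrt(2 * m), jnp.zeros((width,))),
    (jax.random.normal(keys[4], (width, width)) / jnp.sqrt(width), jnp.zeros((width,))),
    (jax.random.normal(keys[5], (2 * width, 1)) / jnp.sqrt(2 * width), jnp.zeros((1,))),
]
u = apply_st_mff(params, W_t_list, W_x, jnp.array([0.5, 0.25]))
```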

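Because PINNs involve multiple loss terms (boundary, initial, residual), the terms can converge at very different speeds. The article adopts the trace‑based adaptive weighting built on the Neural Tangent Kernel (NTK) and proposed by Wang et al. (2022). Every 100 training steps, the trace of the NTK block for each loss is computed, and each loss weight is set to the total trace divided by that term's trace. Terms with a small trace, which would otherwise converge slowly, therefore receive larger weights, and the effective learning rates of all terms are equalised. In the wave‑equation setting the three blocks correspond to the boundary/initial‑value loss (K_u), the initial‑velocity loss on u_t (K_ut) and the PDE residual (K_r):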

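import jax.numpy as jnp

# NTK blocks K_u, K_ut, K_r (each K = J J^T) are computed elsewhere
tr_u  = float(jnp.trace(K_u))
tr_ut = float(jnp.trace(K_ut))
tr_r  = float(jnp.trace(K_r))
tr_total = tr_u + tr_ut + tr_r
# Weight = total trace / own trace: slowly converging terms
# (small trace) get proportionally larger weights
lam_u  = tr_total / tr_u
lam_ut = tr_total / tr_ut
lam_r  = tr_total / tr_r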
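
The weight update only needs the traces, not the full blocks. Each block is K = JJ^T, where row i of J is the parameter gradient of the network output (or of the PDE residual) at point i. So tr(K) equals the sum of the squared per-point gradient norms, and the N×N kernel never has to be materialised. The sketch below rests on these assumptions: u_fn is a hypothetical toy network, and r_fn is an illustrative second-derivative residual standing in for a real PDE operator. The names ntk_trace and ntk_weights are also hypothetical. The sketch computes the same quantity as jnp.trace(K) in the snippet above, but not necessarily the way the article does.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def ntk_trace(fn, params, pts):
    # fn(params, p) -> scalar: network output or PDE residual at point p.
    # K = J J^T with J[i] = grad_theta fn(params, pts[i]), hence
    # tr(K) = sum_i ||grad_theta fn(params, pts[i])||^2 (no N x N matrix needed).
    def sq_grad_norm(p):
        g_flat, _ = ravel_pytree(jax.grad(fn)(params, p))
        return jnp.sum(g_flat ** 2)
    # lax.map evaluates one point at a time, holding a single gradient in memory
    return jnp.sum(jax.lax.map(sq_grad_norm, pts))

def ntk_weights(traces):
    # lambda_i = (sum_j tr(K_j)) / tr(K_i), matching the update above
    total = sum(traces)
    return [total / tr for tr in traces]

# Hypothetical toy network u(x) and an illustrative residual r(x) = u_xx(x)
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {"W1": jax.random.normal(k1, (1, 20)), "b1": jnp.zeros(20),
          "w2": jax.random.normal(k2, (20,)) / jnp.sqrt(20.0)}

def u_fn(params, x):
    return jnp.dot(jnp.tanh(x @ params["W1"] + params["b1"]), params["w2"])

def r_fn(params, x):
    u_x = lambda z: jax.grad(u_fn, argnums=1)(params, z)[0]
    return jax.grad(u_x)(x)[0]

x_b = jnp.array([[0.0], [1.0]])                  # boundary points
x_r = jnp.linspace(0.0, 1.0, 16).reshape(-1, 1)  # collocation points
tr_u = float(ntk_trace(u_fn, params, x_b))
tr_r = float(ntk_trace(r_fn, params, x_r))
lam_u, lam_r = ntk_weights([tr_u, tr_r])
```

Calling ntk_trace every 100 steps and refreshing the weights reproduces the schedule described above. jax.lax.map trades speed for memory here; jax.vmap over chunks of points is a faster alternative if memory allows.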

Experiments on three 1‑D benchmark PDEs illustrate the impact of each component:

Heat equation (high‑frequency spatial mode): standard NN and single‑scale FF diverge; ST_FF converges to a low error; training time ~360 s.

Poisson equation (mixed low‑ and high‑frequency spatial modes): NN fails; FF improves slightly; multi‑scale mFF achieves the lowest error (≈1e‑3) in only 130 s.

Wave equation (spatio‑temporal multi‑modal oscillations): without adaptive weighting all architectures stall; only ST_MFF with NTK‑based weight adaptation reaches a relative error below 1e‑2.
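
The article does not spell out the error norm behind these figures. The sketch below assumes the relative L2 error over an evaluation grid, which is common in PINN benchmarks. Under that metric, a threshold of 1e‑2 means the error norm is under 1% of the solution norm. The function name relative_l2_error and the test signal are hypothetical.

```python
import jax.numpy as jnp

def relative_l2_error(u_pred, u_true):
    # ||u_pred - u_true||_2 / ||u_true||_2; for 2-D (t, x) grids jnp.linalg.norm
    # defaults to the Frobenius norm, i.e. the L2 norm of the flattened grid
    return float(jnp.linalg.norm(u_pred - u_true) / jnp.linalg.norm(u_true))

# Illustrative check: a prediction that is uniformly 1% too large has relative error 0.01
x = jnp.linspace(0.0, 1.0, 101)
u_true = jnp.sin(jnp.pi * x) + 0.5 * jnp.sin(8 * jnp.pi * x)
u_pred = 1.01 * u_true
err = relative_l2_error(u_pred, u_true)
```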

The results confirm three key findings:

Spectral bias is the fundamental obstacle for high‑frequency PDEs.

Random Fourier features are necessary, but a single bandwidth is insufficient for problems with disparate frequency scales; multi‑scale or spatio‑temporal separation is required.

When multiple loss components are present, NTK‑based adaptive weighting is essential to balance convergence, especially for wave‑type equations.

Limitations include the cost of exact NTK computation (memory‑intensive Jacobians) and the need to choose the Fourier frequency scales a priori. Suggested directions for future work are learnable Fourier frequencies, hierarchical multi‑scale designs, low‑rank NTK approximations, and extensions to high‑dimensional fluid‑dynamics problems.
