Why Do Physics-Informed Neural Networks Often Fail? Try Gradient‑Adaptive Weighted PINNs
Training Physics‑Informed Neural Networks (PINNs) often stalls because of gradient‑flow pathology. A gradient‑adaptive weighting scheme rebalances the loss contributions: on a nonlinear Klein‑Gordon equation, it reduces the relative L2 error by 95.5% compared with a standard PINN.
Gradient‑flow pathology
Physics‑Informed Neural Networks (PINNs) [2] minimise a multi‑task loss L = L_{res} + λ_{ics} L_{ics} + λ_{bcs} L_{bcs}, which combines the PDE residual loss L_{res}, the initial‑condition loss L_{ics} and the boundary‑condition loss L_{bcs}; a standard PINN uses λ_{ics} = λ_{bcs} = 1. When the gradients of these components differ in magnitude by orders of magnitude, the optimiser follows the dominant gradient (usually that of the residual) and the weaker IC/BC gradients are effectively ignored. This imbalance is known as gradient‑flow pathology [1].
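A toy analogue makes the imbalance concrete. It is a sketch under simplifying assumptions, not a PINN: two decoupled quadratic terms stand in for the residual and boundary losses, and plain gradient descent replaces Adam (in a real PINN the terms share all parameters). The function name and constants are illustrative. A step size small enough to keep the stiff term stable leaves the weak term almost untouched.

```python
# Toy analogue (not a PINN): plain gradient descent on
#     L(theta) = a * theta1**2 + lam * theta2**2,
# where the a-term stands in for the dominant residual loss and the
# lam-term for a weak boundary loss.

def gradient_descent(a, lam, lr, steps, theta0=(1.0, 1.0)):
    """Run fixed-step gradient descent and return (theta1, theta2)."""
    t1, t2 = theta0
    for _ in range(steps):
        g1 = 2.0 * a * t1    # large gradient of the dominant term
        g2 = 2.0 * lam * t2  # small gradient of the weak term
        t1 -= lr * g1
        t2 -= lr * g2
    return t1, t2

a = 1000.0
lr = 0.9 / a  # the a-term is only stable for lr < 1 / a

unweighted = gradient_descent(a, lam=1.0, lr=lr, steps=100)
reweighted = gradient_descent(a, lam=a, lr=lr, steps=100)  # lam = gradient ratio
print("unweighted:", unweighted)  # theta2 is still about 0.84
print("reweighted:", reweighted)  # both coordinates close to zero
```

Scaling the weak term by the gradient ratio (λ = a) lets both terms contract at the same rate, which is the intuition behind the adaptive weights described next.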
Gradient‑weighted adaptive loss
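Wang et al. [1] propose an adaptive weighting scheme based on simple gradient statistics. At each weight update (every 10 training steps in the experiments below), the following quantities are computed for every weight matrix W_l of the network (bias gradients are not used):
$$\mathrm{max\_res}_l = \max\left|\nabla_{W_l} L_{res}\right|, \qquad \mathrm{mean\_ics}_l = \operatorname{mean}\left|\nabla_{W_l} L_{ics}\right|, \qquad \mathrm{mean\_bcs}_l = \operatorname{mean}\left|\nabla_{W_l} L_{bcs}\right|$$
The maximum of max_res_l over all layers and the averages of mean_ics_l and mean_bcs_l over all N_L weight matrices then form the instantaneous weights
$$\hat{\lambda}_{ics} = \frac{\max_l \mathrm{max\_res}_l}{\frac{1}{N_L}\sum_{l} \mathrm{mean\_ics}_l + 10^{-10}}, \qquad \hat{\lambda}_{bcs} = \frac{\max_l \mathrm{max\_res}_l}{\frac{1}{N_L}\sum_{l} \mathrm{mean\_bcs}_l + 10^{-10}}$$
In the implementation below, the IC and BC gradients are taken of the currently weighted losses λ_ics·L_ics and λ_bcs·L_bcs. These raw weights are smoothed with an exponential moving average (EMA):
$$\lambda_{ics} \leftarrow \beta\,\lambda_{ics} + (1-\beta)\,\hat{\lambda}_{ics}, \qquad \lambda_{bcs} \leftarrow \beta\,\lambda_{bcs} + (1-\beta)\,\hat{\lambda}_{bcs}$$
A typical EMA coefficient is β = 0.9. The average then has a memory of roughly 1/(1 − β) = 10 weight updates (about 100 training steps when the weights are refreshed every 10 steps), which damps oscillations in the weights.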
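The two formulas can be checked on small numbers. This NumPy sketch computes λ̂ from per-layer weight-gradient arrays and then applies the EMA repeatedly toward that fixed target; the helper names and gradient values are illustrative, not taken from the post.

```python
import numpy as np

def instantaneous_weight(res_grads, term_grads, eps=1e-10):
    """lambda_hat = (max over layers of max|g_res|) / (mean over layers of mean|g_term|).

    res_grads, term_grads: lists of per-layer weight-gradient arrays.
    """
    max_res = max(np.max(np.abs(g)) for g in res_grads)
    mean_term = np.mean([np.mean(np.abs(g)) for g in term_grads])
    return max_res / (mean_term + eps)

def ema_update(lam, lam_hat, beta=0.9):
    """Exponential moving average; beta is the retention factor."""
    return beta * lam + (1.0 - beta) * lam_hat

# Illustrative gradients for a two-layer network.
res_grads = [np.array([[1.0, -4.0]]), np.array([[2.0]])]
bcs_grads = [np.array([[0.1, -0.3]]), np.array([[0.2]])]
lam_hat = instantaneous_weight(res_grads, bcs_grads)  # 4 / 0.2 = 20 (up to eps)

# Starting from lam = 1, each update closes 10% of the remaining gap,
# so after n updates a fraction 1 - 0.9**n of it has been closed.
lam = 1.0
for _ in range(10):
    lam = ema_update(lam, lam_hat)
print(lam_hat, lam)
```

After ten updates about 65% of the gap is closed (1 − 0.9^10 ≈ 0.65), which is what "a memory of roughly ten updates" means in practice.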
Network architecture
The experiments use a fully‑connected MLP with two inputs (t, x), five hidden layers of width 50 with tanh activation, and a scalar output u. Xavier (Glorot) initialisation [4] is applied to all weight matrices, and biases are initialised to zero. The network has 10 401 trainable parameters (10 150 weights + 251 biases).
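The parameter count can be checked with a small JAX sketch. The helper names (init_params, xavier_init) and the list-of-dicts layout with keys 'w' and 'b' are assumptions, chosen to match the g['w'] indexing and the net_u_scalar(params, t_n, x_n) calls in the code that follows; Xavier initialisation is implemented here as a normal distribution with variance 2/(fan_in + fan_out).

```python
import jax
import jax.numpy as jnp

# Assumed layout: inputs (t, x), five tanh hidden layers of width 50, scalar output u.
LAYERS = [2, 50, 50, 50, 50, 50, 1]

def xavier_init(key, d_in, d_out):
    """Glorot/Xavier normal initialisation with std = sqrt(2 / (d_in + d_out))."""
    std = jnp.sqrt(2.0 / (d_in + d_out))
    return std * jax.random.normal(key, (d_in, d_out))

def init_params(key, layers=LAYERS):
    """List of {'w', 'b'} dicts, one per layer; biases start at zero."""
    keys = jax.random.split(key, len(layers) - 1)
    return [{"w": xavier_init(k, d_in, d_out), "b": jnp.zeros(d_out)}
            for k, d_in, d_out in zip(keys, layers[:-1], layers[1:])]

def net_u_scalar(params, t_n, x_n):
    """Scalar network output u(t_n, x_n)."""
    h = jnp.stack([t_n, x_n])
    for layer in params[:-1]:
        h = jnp.tanh(h @ layer["w"] + layer["b"])
    return (h @ params[-1]["w"] + params[-1]["b"])[0]

params = init_params(jax.random.PRNGKey(1234))
n_weights = sum(p["w"].size for p in params)
n_biases = sum(p["b"].size for p in params)
print(n_weights, n_biases, n_weights + n_biases)  # 10150 251 10401
```

The four hidden-to-hidden layers contribute 4 · 50 · 50 = 10 000 weights, the input and output layers add 2 · 50 + 50 · 1 = 150, and the biases number 5 · 50 + 1 = 251, giving 10 150 + 251 = 10 401.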
Core JAX implementation
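Residual computation (second‑order time and space derivatives) and adaptive‑weight calculation are implemented as follows. The network takes normalised inputs t_n and x_n, so each derivative with respect to a normalised coordinate is divided by the matching scale sigma_t or sigma_x (chain rule):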
import jax.numpy as jnp
from jax import grad

# net_u_scalar, loss_res_fn, loss_ics_fn, loss_bcs_fn and the PDE constants
# ALPHA, BETA_PDE, GAMMA, K_EXP are assumed to be defined elsewhere.

def residual_single(params, t_n, x_n, f_val, sigma_t, sigma_x):
    """Klein-Gordon residual at one point (t_n, x_n are normalised inputs)."""
    u = net_u_scalar(params, t_n, x_n)
    # For t_n = (t - mu_t) / sigma_t, d/dt = (1 / sigma_t) d/dt_n; applied twice.
    # (The first derivative u_t is not needed in this residual.)
    u_tt = grad(lambda p, tn, xn: grad(net_u_scalar, 1)(p, tn, xn) / sigma_t,
                argnums=1)(params, t_n, x_n) / sigma_t
    u_xx = grad(lambda p, tn, xn: grad(net_u_scalar, 2)(p, tn, xn) / sigma_x,
                argnums=2)(params, t_n, x_n) / sigma_x
    pde = u_tt + ALPHA * u_xx + BETA_PDE * u + GAMMA * u ** K_EXP
    return pde - f_val


def compute_adaptive_lambdas(params, t_r, x_r, f_r, t_ic, x_ic, u_ic,
                             t_bc, x_bc, u_bc, sigma_t, sigma_x,
                             lam_ics_cur, lam_bcs_cur):
    """Instantaneous weights: max|grad L_res| / mean|grad(lam * L_i)|."""
    # Three extra backward passes, one per loss term.
    grads_res = grad(loss_res_fn)(params, t_r, x_r, f_r, sigma_t, sigma_x)

    def weighted_ics(p, t, x, u, s_t):
        return lam_ics_cur * loss_ics_fn(p, t, x, u, s_t)
    grads_ics = grad(weighted_ics)(params, t_ic, x_ic, u_ic, sigma_t)

    def weighted_bcs(p, t, x, u):
        return lam_bcs_cur * loss_bcs_fn(p, t, x, u)
    grads_bcs = grad(weighted_bcs)(params, t_bc, x_bc, u_bc)

    # Per-layer statistics over weight matrices only (biases are skipped).
    max_res_list, mean_ics_list, mean_bcs_list = [], [], []
    for g_r, g_i, g_b in zip(grads_res, grads_ics, grads_bcs):
        max_res_list.append(jnp.max(jnp.abs(g_r['w'])))
        mean_ics_list.append(jnp.mean(jnp.abs(g_i['w'])))
        mean_bcs_list.append(jnp.mean(jnp.abs(g_b['w'])))

    # Global maximum over layers; average of the per-layer means.
    max_grad_res = jnp.max(jnp.array(max_res_list))
    mean_grad_ics = jnp.mean(jnp.array(mean_ics_list))
    mean_grad_bcs = jnp.mean(jnp.array(mean_bcs_list))

    lam_ics_hat = max_grad_res / (mean_grad_ics + 1e-10)
    lam_bcs_hat = max_grad_res / (mean_grad_bcs + 1e-10)
    return lam_ics_hat, lam_bcs_hat
Training loops
Two variants are defined:
import jax
import jax.numpy as jnp
import optax
from jax import jit

# loss_total_fn, optimizer, sigma_t and sigma_x are assumed to be defined at
# module level; jit captures sigma_t and sigma_x as constants.

# M1: fixed weights (λ_i = λ_b = 1)
@jit
def train_step_m1(params, opt_state, t_r, x_r, f_r,
                  t_ic, x_ic, u_ic, t_bc, x_bc, u_bc):
    lam_i = jnp.float32(1.0)
    lam_b = jnp.float32(1.0)
    (loss, (l_res, l_ics, l_bcs)), grads = \
        jax.value_and_grad(loss_total_fn, has_aux=True)(
            params, t_r, x_r, f_r, t_ic, x_ic, u_ic,
            t_bc, x_bc, u_bc, sigma_t, sigma_x, lam_i, lam_b)
    updates, new_opt = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt, loss, l_res, l_ics, l_bcs

# M2: adaptive weights λ_ics, λ_bcs are passed in as arguments; the caller
# refreshes them (EMA of compute_adaptive_lambdas) every 10 steps
@jit
def train_step_m2(params, opt_state, t_r, x_r, f_r,
                  t_ic, x_ic, u_ic, t_bc, x_bc, u_bc,
                  lam_ics, lam_bcs):
    (loss, (l_res, l_ics, l_bcs)), grads = \
        jax.value_and_grad(loss_total_fn, has_aux=True)(
            params, t_r, x_r, f_r, t_ic, x_ic, u_ic,
            t_bc, x_bc, u_bc, sigma_t, sigma_x, lam_ics, lam_bcs)
    updates, new_opt = optimizer.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt, loss, l_res, l_ics, l_bcs
Key differences:
Weight update: fixed (M1) vs. EMA‑smoothed adaptive update every 10 steps (M2).
Extra gradient passes: none for M1; for M2, three additional backward passes (one per loss term) at each weight update.
Computational overhead: baseline for M1; ≈5‑10 % increase for M2.
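The post does not show the outer loop that applies the EMA between calls to train_step_m2. A framework-agnostic sketch follows; run_adaptive_training and its two callbacks are hypothetical names, and starting from λ = 1 with an update on step 0 are assumptions.

```python
def run_adaptive_training(state, n_steps, train_step, compute_lambda_hats,
                          update_every=10, beta=0.9):
    """Sketch of the M2 outer loop (hypothetical helper, not from the post).

    train_step(state, lam_ics, lam_bcs) -> new state
    compute_lambda_hats(state, lam_ics, lam_bcs) -> (lam_ics_hat, lam_bcs_hat)
    """
    lam_ics, lam_bcs = 1.0, 1.0  # start from the standard PINN weights
    history = []
    for step in range(n_steps):
        if step % update_every == 0:
            # When compute_lambda_hats wraps compute_adaptive_lambdas, the
            # three extra backward passes happen only on these steps.
            lam_ics_hat, lam_bcs_hat = compute_lambda_hats(state, lam_ics, lam_bcs)
            lam_ics = beta * lam_ics + (1.0 - beta) * lam_ics_hat
            lam_bcs = beta * lam_bcs + (1.0 - beta) * lam_bcs_hat
            history.append((step, lam_ics, lam_bcs))
        state = train_step(state, lam_ics, lam_bcs)
    return state, lam_ics, lam_bcs, history
```

With the post's functions, state would be the tuple (params, opt_state), train_step would unpack it and call train_step_m2 on the current batches, and compute_lambda_hats would call compute_adaptive_lambdas with the current weights as lam_ics_cur and lam_bcs_cur.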
Experimental setup
Benchmark problem: the nonlinear Klein‑Gordon equation u_tt + ALPHA·u_xx + BETA_PDE·u + GAMMA·u^K_EXP = f, with the coefficients named as in the code (exact values omitted for brevity). Each training batch contains 128 points for the residual, IC and BC terms. Training runs for 40 001 iterations with the Adam optimiser [5] and exponential learning‑rate decay (factor 0.9 every 1 000 steps). The EMA coefficient is 0.9, the weights are updated every 10 steps, and the random seed is 1234.
Results
Solution quality: Visual comparison (figures omitted) shows that the standard PINN (M1) deviates strongly from the exact solution across the domain, while the gradient‑weighted PINN (M2) matches it almost everywhere, with errors confined to small local regions.
Quantitative error reduction: M2 reduces the relative L2 error by 95.5% compared with M1.
Loss convergence (average of the three components):
Residual loss reduced by 80.9 %.
Initial‑condition loss reduced by 94.0 %.
Boundary‑condition loss reduced by 99.9 % (approximately three orders of magnitude).
Relative‑error evolution: M1’s error stalls after ≈5 000 steps, whereas M2’s error drops rapidly during the first 5 000 steps and then keeps decreasing more slowly, ending far below M1’s final error.
Adaptive weight evolution: During training, the boundary‑condition weight grows to roughly 200‑300× its initial value, indicating that the raw BC gradient is at least two orders of magnitude weaker than the residual gradient. The IC weight also increases but stabilises earlier.
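The headline numbers can be reproduced from model outputs and translated into "times smaller" factors. relative_l2_error uses the usual definition ‖u_pred − u_exact‖₂ / ‖u_exact‖₂, which is an assumption since the post does not spell out its formula; reduction_factor simply inverts the fraction of error or loss that remains. Both helper names are hypothetical.

```python
import math
import numpy as np

def relative_l2_error(u_pred, u_exact):
    """||u_pred - u_exact||_2 / ||u_exact||_2 over all evaluation points."""
    u_pred, u_exact = np.ravel(u_pred), np.ravel(u_exact)
    return np.linalg.norm(u_pred - u_exact) / np.linalg.norm(u_exact)

def reduction_factor(percent):
    """Turn an 'X % lower' figure into a 'k times smaller' factor."""
    return 1.0 / (1.0 - percent / 100.0)

reported = [("relative L2 error", 95.5), ("residual loss", 80.9),
            ("IC loss", 94.0), ("BC loss", 99.9)]
for name, pct in reported:
    k = reduction_factor(pct)
    print(f"{name}: {k:.1f}x smaller ({math.log10(k):.1f} orders of magnitude)")
```

For the reported figures this gives about 22× for the relative L2 error and about 5×, 17× and 1000× for the residual, IC and BC losses, so only the BC loss drops by three orders of magnitude.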
Comparison with NTK‑based adaptive weighting
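Each item gives the gradient‑weighted scheme first and the NTK‑based scheme of Wang et al. [3] second.
Mathematical basis: gradient statistics (max/mean) vs. NTK eigenvalues (used through the traces of the NTK blocks).
Computational cost: three extra backward passes per weight update vs. per‑point gradients to build the NTK blocks.
Implementation difficulty: low vs. medium.
Physical interpretability: direct gradient‑imbalance signal vs. effective learning rate of each task.
Update frequency: flexible (every step or at fixed intervals) vs. usually large intervals.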
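To make the cost difference concrete, here is a sketch of the quantity the NTK-based scheme needs. As we understand [3], the weights are λ_i = tr(K)/tr(K_ii), where K_ii is the NTK block of loss term i (u on IC/BC points, or the PDE residual on collocation points) and tr(K) is the sum of the block traces; each block trace is the sum, over that term's points, of the squared norm of the per-point parameter gradient. The function names are hypothetical. Unlike the max/mean statistics, this needs one parameter gradient per point rather than one per loss term.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def ntk_block_trace(point_fn, params, *batch):
    """Trace of one NTK block: sum over points j of ||d point_fn(x_j) / d params||^2.

    point_fn(params, *point) must return a scalar, e.g. u at an IC/BC point or
    the PDE residual at a collocation point; every array in `batch` is mapped
    over its leading axis.
    """
    def sq_norm_grad(*point):
        g = jax.grad(point_fn)(params, *point)  # per-point parameter gradient
        flat, _ = ravel_pytree(g)
        return jnp.sum(flat ** 2)
    return jnp.sum(jax.vmap(sq_norm_grad)(*batch))

def ntk_weights(block_traces):
    """lambda_i = tr(K) / tr(K_ii), with tr(K) the sum of all block traces."""
    total = sum(block_traces)
    return [total / tr for tr in block_traces]
```

For example, ntk_block_trace(net_u_scalar, params, t_bc, x_bc) would give the BC block trace for a network with the net_u_scalar(params, t_n, x_n) signature used in this post.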
Conclusions
Gradient‑flow pathology is a major cause of training failure for PINNs, particularly on high‑frequency PDEs.
Gradient‑weighted adaptive loss dramatically improves accuracy at modest extra cost (≈5‑10 % overhead from three extra backward passes every 10 steps).
The magnitude of the learned weights reveals the severity of gradient imbalance (boundary‑condition weight ≈ 200‑300×).
Limitations and future work
Using only max and mean may miss finer details of the gradient distribution.
EMA coefficient is problem‑dependent; too small causes oscillation, too large slows adaptation.
Update frequency (every 10 steps) may need tuning for other problems.
Future directions include combining gradient weighting with multi‑scale Fourier feature embeddings, exploring more robust statistics (e.g., quantiles, log‑means), and extending the method to higher‑dimensional PDEs such as Navier‑Stokes.
Resources
https://github.com/xgxgnpu/Physics-informed-vibe-coding
References
[1] S. Wang, Y. Teng, and P. Perdikaris, “Understanding and mitigating gradient flow pathologies in physics‑informed neural networks,” SIAM Journal on Scientific Computing, vol. 43, no. 5, pp. A3055‑A3081, 2021.
[2] M. Raissi, P. Perdikaris, and G. E. Karniadakis, “Physics‑informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations,” Journal of Computational Physics, vol. 378, pp. 686‑707, 2019.
[3] S. Wang, X. Yu, and P. Perdikaris, “When and why PINNs fail to train: A neural tangent kernel perspective,” Journal of Computational Physics, vol. 449, p. 110768, 2022.
[4] X. Glorot and Y. Bengio, “Understanding the difficulty of training deep feedforward neural networks,” in Proceedings of AISTATS, pp. 249‑256, 2010.
[5] D. P. Kingma and J. Ba, “Adam: A method for stochastic optimization,” in Proceedings of ICLR, 2015.
AI Agent Research Hub
Sharing AI, intelligent agents, and cutting-edge scientific computing
