Variable-Scaling PINN for 2D Navier‑Stokes: How Coordinate Rescaling Improves Stiff PDE Training

This tutorial explains how a simple coordinate scaling (VS‑PINN) reduces stiffness in physics‑informed neural networks, demonstrates its implementation in JAX for the 2D steady incompressible Navier‑Stokes cylinder‑flow benchmark, and shows that after 80 000 Adam iterations the relative errors drop to 2.10 % (u), 5.06 % (v) and 4.45 % (p).

AI Agent Research Hub

Overview

Physics‑informed neural networks (PINNs) [2] often struggle to train on stiff PDEs in which the convection and diffusion terms differ by several orders of magnitude. The Variable‑Scaling PINN (VS‑PINN) introduced by Ko and Park [1] multiplies the input coordinates by a constant scaling factor N, which amplifies the higher‑order derivative terms relative to the lower‑order ones and balances the magnitudes of the residual components.
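A one‑dimensional check makes the effect of the input scaling concrete. In this sketch a hypothetical smooth function f of the scaled coordinate x̃ = N·x stands in for the network; differentiating the same field with respect to the physical coordinate x picks up one factor of N per derivative order.

```python
import jax
import jax.numpy as jnp

N = 10.0  # scaling factor (the value used in this article's experiments)

def f_scaled(xt):
    """Hypothetical smooth field written in the scaled coordinate x~ = N * x."""
    return jnp.sin(xt)

def u_phys(x):
    """The same field viewed as a function of the physical coordinate x."""
    return f_scaled(N * x)

x0 = 0.3
du_dx = jax.grad(u_phys)(x0)                      # physical first derivative
d2u_dx2 = jax.grad(jax.grad(u_phys))(x0)          # physical second derivative
df_dxt = jax.grad(f_scaled)(N * x0)               # derivative w.r.t. x~
d2f_dxt2 = jax.grad(jax.grad(f_scaled))(N * x0)   # second derivative w.r.t. x~

# Chain rule: d/dx = N d/dx~ and d^2/dx^2 = N^2 d^2/dx~^2
print(float(du_dx / df_dxt))       # approximately 10.0
print(float(d2u_dx2 / d2f_dxt2))   # approximately 100.0
```

Second‑order terms therefore grow by N² while first‑order terms grow only by N, which is the imbalance VS‑PINN exploits.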

Mathematical Insight

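For a scaling factor N, the physical coordinates (x, y) are mapped to the scaled coordinates \(\tilde{x} = N x\), \(\tilde{y} = N y\), which are what the network receives as input. By the chain rule, \(\partial/\partial x = N\,\partial/\partial\tilde{x}\) and \(\partial^2/\partial x^2 = N^2\,\partial^2/\partial\tilde{x}^2\); put the other way round, a first‑order derivative taken with respect to the scaled coordinates is smaller by a factor 1/N and a second‑order derivative by 1/N^2. Substituting these relations into the incompressible Navier‑Stokes equations and dividing by N yields residuals in which the viscous term is multiplied by N relative to the convective and pressure terms, effectively reducing the convection‑to‑diffusion ratio that the network has to resolve.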
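Written out term by term for the steady equations, with \(\partial_{\tilde{x}}\) denoting differentiation with respect to the scaled coordinate, the scaled residuals are simply the physical residuals divided by N. These are the expressions r1, r2 and r3 evaluated by the JAX code in the next section:

```latex
\begin{aligned}
r_1 &= \frac{1}{N}\Big[\rho\,(u\,\partial_x u + v\,\partial_y u) + \partial_x p
        - \mu\,(\partial_x^2 u + \partial_y^2 u)\Big] \\
    &= \rho\,(u\,\partial_{\tilde{x}} u + v\,\partial_{\tilde{y}} u) + \partial_{\tilde{x}} p
        - N\mu\,(\partial_{\tilde{x}}^2 u + \partial_{\tilde{y}}^2 u), \\
r_2 &= \rho\,(u\,\partial_{\tilde{x}} v + v\,\partial_{\tilde{y}} v) + \partial_{\tilde{y}} p
        - N\mu\,(\partial_{\tilde{x}}^2 v + \partial_{\tilde{y}}^2 v), \\
r_3 &= \frac{1}{N}\,(\partial_x u + \partial_y v) = \partial_{\tilde{x}} u + \partial_{\tilde{y}} v .
\end{aligned}
```

In the network's own coordinates the viscosity is effectively N·μ, while the continuity residual keeps its form.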

Network Architecture (JAX)

The model is a fully connected multilayer perceptron (MLP) that maps the (scaled) coordinates (x, y) to the three outputs (u, v, p) through five hidden layers of 40 tanh units each, for a total of 6 803 trainable parameters. Xavier (Glorot uniform) initialization keeps the variance of the activations stable across layers. JAX automatic differentiation provides the required first‑ and second‑order spatial derivatives.
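A minimal sketch of such a network, assuming a single MLP with a three‑component linear output layer and zero‑initialized biases; init_params, net_uvp and count_params are hypothetical names. net_u, net_v and net_p are scalar wrappers of the kind the residual code differentiates with jax.grad. Counting the weights of the 2→40→…→40→3 layout reproduces the 6 803 parameters quoted above.

```python
import jax
import jax.numpy as jnp

# Layer widths: 2 inputs (scaled x, y), five hidden layers of 40, 3 outputs (u, v, p)
LAYERS = [2, 40, 40, 40, 40, 40, 3]

def init_params(key, layers=LAYERS):
    """Glorot (Xavier) uniform weights; zero biases are an assumption."""
    params = []
    for n_in, n_out in zip(layers[:-1], layers[1:]):
        key, sub = jax.random.split(key)
        limit = jnp.sqrt(6.0 / (n_in + n_out))   # Glorot uniform bound
        W = jax.random.uniform(sub, (n_in, n_out), minval=-limit, maxval=limit)
        params.append((W, jnp.zeros(n_out)))
    return params

def net_uvp(params, x, y):
    """Forward pass for a single point; x and y are scalars."""
    h = jnp.stack([x, y])
    for W, b in params[:-1]:
        h = jnp.tanh(h @ W + b)
    W, b = params[-1]
    return h @ W + b                              # linear output layer: (u, v, p)

# Scalar wrappers, so jax.grad can differentiate each field separately
def net_u(params, x, y):
    return net_uvp(params, x, y)[0]

def net_v(params, x, y):
    return net_uvp(params, x, y)[1]

def net_p(params, x, y):
    return net_uvp(params, x, y)[2]

def count_params(params):
    return sum(W.size + b.size for W, b in params)

params = init_params(jax.random.PRNGKey(0))
print(count_params(params))  # 6803
```

The count breaks down as 120 (input layer) + 4 × 1 640 (hidden‑to‑hidden) + 123 (output layer).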

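import jax

# Assumed to be defined elsewhere: constants RHO, MU, N_VS and scalar-output
# networks net_u, net_v, net_p(params, x, y). Because every derivative is
# multiplied by N below, x and y are the scaled coordinates (N*x_phys, N*y_phys).

def ns_residual_single(params, x, y):
    u = net_u(params, x, y)
    v = net_v(params, x, y)
    # First-order derivatives with respect to the scaled inputs
    u_x = jax.grad(net_u, argnums=1)(params, x, y)
    u_y = jax.grad(net_u, argnums=2)(params, x, y)
    v_x = jax.grad(net_v, argnums=1)(params, x, y)
    v_y = jax.grad(net_v, argnums=2)(params, x, y)
    p_x = jax.grad(net_p, argnums=1)(params, x, y)
    p_y = jax.grad(net_p, argnums=2)(params, x, y)
    # Second-order derivatives: differentiate the first derivative once more
    u_xx = jax.grad(jax.grad(net_u, argnums=1), argnums=1)(params, x, y)
    u_yy = jax.grad(jax.grad(net_u, argnums=2), argnums=2)(params, x, y)
    v_xx = jax.grad(jax.grad(net_v, argnums=1), argnums=1)(params, x, y)
    v_yy = jax.grad(jax.grad(net_v, argnums=2), argnums=2)(params, x, y)
    N = N_VS  # scaling factor
    # N * (scaled derivative) = physical derivative; dividing the residual by N
    # leaves the viscous term with an effective coefficient N * MU.
    r1 = (RHO * (u * N * u_x + v * N * u_y) + N * p_x - MU * (N * N * u_xx + N * N * u_yy)) / N
    r2 = (RHO * (u * N * v_x + v * N * v_y) + N * p_y - MU * (N * N * v_xx + N * N * v_yy)) / N
    r3 = (N * u_x + N * v_y) / N  # continuity: reduces to u_x + v_y
    return r1, r2, r3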
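A quick sanity check of the scaled residual, using hypothetical analytic fields in place of the network: the r1 formula above, evaluated with derivatives taken with respect to the scaled coordinates, should equal the ordinary physical x‑momentum residual divided by N, and the continuity residual should vanish for a divergence‑free field. The fields u = sin x cos y, v = −cos x sin y and p = x·y are arbitrary test choices, not part of the benchmark.

```python
import jax
import jax.numpy as jnp

RHO, MU, N = 1.0, 0.02, 10.0

# Hypothetical analytic fields in physical coordinates (u, v divergence-free)
def u_p(x, y):
    return jnp.sin(x) * jnp.cos(y)

def v_p(x, y):
    return -jnp.cos(x) * jnp.sin(y)

def p_p(x, y):
    return x * y

# The same fields as functions of the scaled inputs x~ = N x, y~ = N y
def u_s(xt, yt):
    return u_p(xt / N, yt / N)

def v_s(xt, yt):
    return v_p(xt / N, yt / N)

def p_s(xt, yt):
    return p_p(xt / N, yt / N)

def d(f, i):   # first derivative w.r.t. argument i
    return jax.grad(f, argnums=i)

def dd(f, i):  # second derivative w.r.t. argument i
    return jax.grad(jax.grad(f, argnums=i), argnums=i)

def r1_scaled(xt, yt):
    """r1 as in ns_residual_single, with derivatives w.r.t. the scaled inputs."""
    u, v = u_s(xt, yt), v_s(xt, yt)
    conv = RHO * (u * N * d(u_s, 0)(xt, yt) + v * N * d(u_s, 1)(xt, yt))
    visc = MU * (N * N * dd(u_s, 0)(xt, yt) + N * N * dd(u_s, 1)(xt, yt))
    return (conv + N * d(p_s, 0)(xt, yt) - visc) / N

def r1_physical(x, y):
    """Unscaled x-momentum residual."""
    u, v = u_p(x, y), v_p(x, y)
    conv = RHO * (u * d(u_p, 0)(x, y) + v * d(u_p, 1)(x, y))
    visc = MU * (dd(u_p, 0)(x, y) + dd(u_p, 1)(x, y))
    return conv + d(p_p, 0)(x, y) - visc

def r3_scaled(xt, yt):
    return d(u_s, 0)(xt, yt) + d(v_s, 1)(xt, yt)

x, y = 0.4, 0.25
print(float(r1_scaled(N * x, N * y)), float(r1_physical(x, y)) / N)  # should agree
print(float(r3_scaled(N * x, N * y)))                                # close to 0
```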

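Each collocation point needs ten derivatives (six first‑order and four second‑order), obtained through fourteen jax.grad calls because every second derivative nests two of them. Evaluating the full collocation set in a single vmap can exceed GPU memory, so the implementation uses jax.lax.scan to process the points in chunks of 2 000, which bounds the size of each batched evaluation while remaining JIT‑compatible.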

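import jax
import jax.numpy as jnp

CHUNK_SIZE = 2000

# Per-point residual vectorised over one chunk (params shared, x and y batched)
ns_residual_vmap = jax.vmap(ns_residual_single, in_axes=(None, 0, 0))

def ns_residual_chunked(params, x_arr, y_arr):
    n = x_arr.shape[0]
    n_chunks = -(-n // CHUNK_SIZE)       # ceiling division
    pad = n_chunks * CHUNK_SIZE - n      # pad so the reshape works for any n
    x_ch = jnp.pad(x_arr, (0, pad), mode="edge").reshape(n_chunks, CHUNK_SIZE)
    y_ch = jnp.pad(y_arr, (0, pad), mode="edge").reshape(n_chunks, CHUNK_SIZE)

    def body(carry, xy):  # xy = (x_chunk, y_chunk), sliced along axis 0 by scan
        return carry, ns_residual_vmap(params, xy[0], xy[1])

    _, (r1_all, r2_all, r3_all) = jax.lax.scan(body, None, (x_ch, y_ch))
    # Drop the padded entries so they do not enter the loss
    return r1_all.reshape(-1)[:n], r2_all.reshape(-1)[:n], r3_all.reshape(-1)[:n]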
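The padding‑and‑scan pattern can be checked in isolation. This sketch uses a hypothetical helper chunked_vmap and a toy per‑point function standing in for ns_residual_single, and compares the chunked result with a plain jax.vmap. It also wraps each chunk in jax.checkpoint: when the loss is differentiated, lax.scan otherwise keeps the intermediates of every chunk for the backward pass, so rematerialization is one way to keep training memory bounded, at the cost of recomputing each chunk.

```python
import jax
import jax.numpy as jnp

def chunked_vmap(fn, x, y, chunk):
    """Evaluate the per-point function fn(x, y) on all points, chunk points at a time."""
    n = x.shape[0]
    n_chunks = -(-n // chunk)                        # ceiling division
    pad = n_chunks * chunk - n
    xc = jnp.pad(x, (0, pad), mode="edge").reshape(n_chunks, chunk)
    yc = jnp.pad(y, (0, pad), mode="edge").reshape(n_chunks, chunk)
    batched = jax.vmap(fn)

    # jax.checkpoint: under reverse-mode AD, recompute each chunk's internals
    # in the backward pass instead of storing them for every chunk.
    @jax.checkpoint
    def per_chunk(xc_i, yc_i):
        return batched(xc_i, yc_i)

    def body(carry, xy):
        return carry, per_chunk(xy[0], xy[1])

    _, out = jax.lax.scan(body, None, (xc, yc))
    return out.reshape(-1)[:n]                       # drop padded entries

def toy_residual(x, y):
    """Hypothetical stand-in for ns_residual_single that uses a second derivative."""
    f = lambda a, b: jnp.sin(a) * jnp.exp(b)
    return jax.grad(jax.grad(f, argnums=0), argnums=0)(x, y) + x * y

x = jnp.linspace(0.0, 1.0, 8600)   # 8 600 points: not a multiple of 2 000
y = jnp.linspace(-1.0, 1.0, 8600)
chunked = chunked_vmap(toy_residual, x, y, 2000)
reference = jax.vmap(toy_residual)(x, y)
print(float(jnp.max(jnp.abs(chunked - reference))))  # expected to be close to 0
```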

Boundary‑Condition Sampling

Inlet: 200 uniformly spaced points with a parabolic velocity profile.

Outlet: 200 points on the right boundary, where a reference pressure is imposed.

Top/Bottom walls: 400 points each with no‑slip Dirichlet conditions.

Cylinder surface: 200 points uniformly distributed along the circumference.

Interior: 8 600 uniformly distributed random points, plus additional points in a dense region around the cylinder. All points are resampled each iteration.
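A sketch of this sampling step with NumPy. The cylinder centre (0.2, 0.2) follows the usual Schäfer‑Turek layout but is not stated in the text, and the peak inlet velocity U_MAX, the size of the dense box and its point count n_dense are hypothetical. The parabolic inlet profile is written in its standard form u = 4 U_max y (H − y)/H².

```python
import numpy as np

# Geometry from the text; the cylinder centre and U_MAX are assumptions
L, H = 1.1, 0.41            # channel length and height [m]
XC, YC, R = 0.2, 0.2, 0.05  # cylinder centre (assumed) and radius [m]
U_MAX = 1.0                 # hypothetical peak inlet velocity

def outside_cylinder(x, y):
    return (x - XC) ** 2 + (y - YC) ** 2 > R ** 2

def sample_points(rng, n_interior=8600, n_dense=2000):
    """Return (collocation_xy, bnd_xy, bnd_uv, outlet_xy); n_dense is hypothetical."""
    # Inlet (x = 0): parabolic profile u = 4 U_MAX y (H - y) / H^2, v = 0
    y_in = np.linspace(0.0, H, 200)
    inlet = np.stack([np.zeros_like(y_in), y_in], axis=1)
    inlet_uv = np.stack([4 * U_MAX * y_in * (H - y_in) / H**2, np.zeros_like(y_in)], axis=1)
    # Outlet (x = L): carries the reference-pressure condition
    outlet = np.stack([np.full(200, L), np.linspace(0.0, H, 200)], axis=1)
    # Bottom and top walls, 400 points each, no-slip
    x_w = np.linspace(0.0, L, 400)
    walls = np.concatenate([np.stack([x_w, np.zeros_like(x_w)], axis=1),
                            np.stack([x_w, np.full_like(x_w, H)], axis=1)])
    # Cylinder surface, 200 points, no-slip
    theta = np.linspace(0.0, 2 * np.pi, 200, endpoint=False)
    cyl = np.stack([XC + R * np.cos(theta), YC + R * np.sin(theta)], axis=1)

    def uniform_outside(n, lo, hi):
        # Rejection sampling: keep uniform candidates that fall outside the cylinder
        pts = np.empty((0, 2))
        while len(pts) < n:
            cand = rng.uniform(lo, hi, size=(2 * n, 2))
            pts = np.concatenate([pts, cand[outside_cylinder(cand[:, 0], cand[:, 1])]])
        return pts[:n]

    interior = uniform_outside(n_interior, [0.0, 0.0], [L, H])
    # Denser sampling in a box around the cylinder (box size is an assumption)
    dense = uniform_outside(n_dense, [XC - 2 * R, YC - 2 * R], [XC + 2 * R, YC + 2 * R])
    bnd_xy = np.concatenate([inlet, walls, cyl])
    bnd_uv = np.concatenate([inlet_uv, np.zeros((len(walls) + len(cyl), 2))])
    return np.concatenate([interior, dense]), bnd_xy, bnd_uv, outlet

rng = np.random.default_rng(0)
col_xy, bnd_xy, bnd_uv, outlet_xy = sample_points(rng)
print(col_xy.shape, bnd_xy.shape, outlet_xy.shape)  # (10600, 2) (1200, 2) (200, 2)
```

Calling sample_points with the same generator on every iteration draws fresh interior points; the evenly spaced boundary points come out identical each time.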

Loss Function and Training Loop

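import jax.numpy as jnp

BC_WEIGHT = 2.0  # boundary-loss weight used in the experiments

# net_uvp_batch(params, x, y) -> (u, v, p) is assumed to be a batched (vmapped)
# forward pass; ns_residual_chunked is the chunked residual defined above.
def loss_fn(params, xy_col, bnd_xy, bnd_uv, outlet_xy, outlet_p_ref):
    # PDE residuals at the interior collocation points
    x_col, y_col = xy_col[:, 0], xy_col[:, 1]
    r1, r2, r3 = ns_residual_chunked(params, x_col, y_col)
    mse_r1 = jnp.mean(r1**2)   # x-momentum
    mse_r2 = jnp.mean(r2**2)   # y-momentum
    mse_r3 = jnp.mean(r3**2)   # continuity
    # Velocity conditions on the Dirichlet boundaries (inlet, walls, cylinder)
    u_bnd, v_bnd, _ = net_uvp_batch(params, bnd_xy[:, 0], bnd_xy[:, 1])
    mse_bnd_u = jnp.mean((u_bnd - bnd_uv[:, 0])**2)
    mse_bnd_v = jnp.mean((v_bnd - bnd_uv[:, 1])**2)
    # Reference pressure at the outlet
    _, _, p_out = net_uvp_batch(params, outlet_xy[:, 0], outlet_xy[:, 1])
    mse_outlet = jnp.mean((p_out - outlet_p_ref)**2)
    loss_pde = mse_r1 + mse_r2 + mse_r3
    loss_bc = BC_WEIGHT * (mse_bnd_u + mse_bnd_v + mse_outlet)
    total = loss_pde + loss_bc
    # Individual terms are returned as auxiliary output for monitoring
    return total, (mse_r1, mse_r2, mse_r3, mse_bnd_u, mse_bnd_v, mse_outlet)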

The individual loss terms are returned as auxiliary output, so jax.value_and_grad(loss_fn, has_aux=True) can compute the gradient of the total loss while still reporting each term.
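A minimal training‑step sketch built around jax.value_and_grad(..., has_aux=True). To stay self‑contained it uses a hand‑written Adam update (the standard Kingma and Ba rule with bias correction [4]) instead of a library optimizer. make_train_step, adam_init and the toy loss are hypothetical, and the learning rate is a placeholder because the article does not state its value.

```python
import jax
import jax.numpy as jnp

def adam_init(params):
    zeros = jax.tree_util.tree_map(jnp.zeros_like, params)
    return zeros, zeros, jnp.zeros((), jnp.int32)   # first moment, second moment, step

def make_train_step(loss_fn, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """Build a jitted step: value_and_grad with has_aux, then one Adam update."""
    @jax.jit
    def step(params, state, *batch):
        (loss, aux), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, *batch)
        m, v, t = state
        t = t + 1
        tf = t.astype(jnp.float32)
        m = jax.tree_util.tree_map(lambda m_, g: b1 * m_ + (1 - b1) * g, m, grads)
        v = jax.tree_util.tree_map(lambda v_, g: b2 * v_ + (1 - b2) * g * g, v, grads)
        c1, c2 = 1 - b1 ** tf, 1 - b2 ** tf           # bias corrections
        params = jax.tree_util.tree_map(
            lambda p, m_, v_: p - lr * (m_ / c1) / (jnp.sqrt(v_ / c2) + eps),
            params, m, v)
        return params, (m, v, t), loss, aux
    return step

# Usage on a toy regression problem with the same (loss, aux) return convention
def toy_loss(params, x, y):
    mse = jnp.mean((params["w"] * x + params["b"] - y) ** 2)
    return mse, (mse,)

x = jnp.linspace(0.0, 1.0, 50)
y = 2.0 * x + 1.0
params = {"w": jnp.array(0.0), "b": jnp.array(0.0)}
state = adam_init(params)
step = make_train_step(toy_loss, lr=0.05)
for it in range(1000):
    params, state, loss, aux = step(params, state, x, y)
print(float(loss))
```

With the article's loss_fn, the batch arguments would be the collocation, boundary and outlet arrays, and aux would carry the six monitored terms.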

Experimental Setup

The benchmark is the Schäfer‑Turek flow past a cylinder [3], with the Reynolds number defined on the cylinder diameter. Parameters: density ρ = 1.0, dynamic viscosity μ = 0.02, domain ≈ 1.1 m × 0.41 m, cylinder radius 0.05 m. Training uses Adam [4] with a fixed learning rate for 80 000 iterations, a scaling factor N = 10, batches of 8 600 interior points plus the boundary points, and a boundary‑loss weight of 2.0.
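The stated settings can be collected in a small configuration object. The class and field names are hypothetical, and the learning rate is left out because its value is not given. effective_viscosity returns N·μ, the coefficient that, as explained under Mathematical Insight, multiplies the viscous term once the residual is written in scaled coordinates.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class VSPINNConfig:
    """Settings quoted in the article; the class and field names are hypothetical."""
    rho: float = 1.0          # density
    mu: float = 0.02          # dynamic viscosity
    length: float = 1.1       # channel length [m]
    height: float = 0.41      # channel height [m]
    radius: float = 0.05      # cylinder radius [m]
    n_vs: float = 10.0        # VS-PINN scaling factor N
    iterations: int = 80_000  # Adam iterations (fixed learning rate, value not stated)
    n_interior: int = 8_600   # uniformly sampled interior points per iteration
    bc_weight: float = 2.0    # boundary-loss weight
    chunk_size: int = 2_000   # residual chunk size for jax.lax.scan

    def effective_viscosity(self) -> float:
        """Viscous coefficient seen in scaled coordinates: N * mu."""
        return self.n_vs * self.mu

cfg = VSPINNConfig()
print(cfg.effective_viscosity())
```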

Results and Analysis

The loss curves (log scale) show a rapid drop of the boundary‑condition loss within the first 1 000 iterations, followed by a plateau in the continuity residual between roughly 10 000 and 15 000 iterations. After the plateau the continuity loss starts decreasing again and all residuals continue to fall. The final relative errors after 80 000 iterations, with the best value reached during training in parentheses, are:

u‑velocity: 2.10 % (best 1.78 % at iteration 68 100).

v‑velocity: 5.06 % (best 3.75 % at iteration 78 300).

Pressure: 4.45 % (best 3.59 % at iteration 77 400).
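The article does not spell out its error metric. Assuming the common relative L2 definition, the final and best‑iterate errors listed above could be tracked as follows; relative_l2 and BestTracker are hypothetical helpers.

```python
import jax.numpy as jnp

def relative_l2(pred, ref):
    """Relative L2 error ||pred - ref||_2 / ||ref||_2 over all evaluation points."""
    return jnp.linalg.norm(pred - ref) / jnp.linalg.norm(ref)

class BestTracker:
    """Remember the lowest error seen so far and the iteration where it occurred."""
    def __init__(self):
        self.best = float("inf")
        self.best_iter = None

    def update(self, iteration, err):
        err = float(err)
        if err < self.best:
            self.best, self.best_iter = err, iteration
        return self.best

# Usage: a prediction that is uniformly 2.1 % too large
ref = jnp.linspace(1.0, 2.0, 100)
print(float(relative_l2(1.021 * ref, ref)))  # about 0.021
```

Because the denominator is the norm of the reference field, a field with small magnitudes shows a larger relative error for the same absolute error.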

Visual comparison with a reference Fluent solution shows accurate velocity fields; the largest absolute errors appear in the low‑speed wake behind the cylinder, while pressure errors concentrate near the front stagnation point.

Scaling‑Factor Study

N = 1 (no scaling): strong stiffness, poor convergence.

N = 5: moderate improvement.

N = 10: best balance, low convection‑to‑diffusion ratio.

N = 50: over‑smoothing; diffusion dominates.

A short pre‑experiment is recommended to choose N such that the magnitudes of all residual components are comparable.
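One hypothetical way to automate that pre‑experiment: for each candidate N, measure the RMS magnitudes of the convective, pressure‑gradient and viscous terms (the viscous one including its effective N·μ coefficient), for example after a short pre‑training run, and keep the N whose terms are closest in size. choose_scaling_factor and its max/min balance score are illustrative choices, not part of VS‑PINN itself, and the magnitudes in the usage example are made up.

```python
def choose_scaling_factor(term_rms, candidates=(1, 5, 10, 50)):
    """Return the candidate N whose residual terms are most similar in size.

    term_rms(N) must return the RMS magnitudes of the convective, pressure-gradient
    and viscous terms for that N (for instance, measured after a short pre-training
    run). The balance score is max/min of those magnitudes; 1.0 is perfect balance.
    """
    scores = {}
    for n in candidates:
        mags = [float(m) for m in term_rms(n)]
        scores[n] = max(mags) / min(mags)
    best = min(scores, key=scores.get)
    return best, scores

# Toy usage with made-up magnitudes: the viscous term grows as N * mu * |Laplacian|
def toy_term_rms(n, mu=0.02, lap=5.0):
    return 1.0, 0.8, n * mu * lap

best, scores = choose_scaling_factor(toy_term_rms)
print(best)  # 10
```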

Limitations

The pressure field shows a higher relative error because its value range is small, so the same absolute error translates into a larger relative one.

A “pseudo‑convergence” phase occurs in which the loss decreases while the physical error remains high, because the network first fits the boundary conditions.

The scaling factor is a manual hyperparameter; no adaptive scheme is provided.

The fixed learning rate slows late‑stage convergence.

Future Directions

Develop an adaptive scaling‑factor selection based on NTK eigen‑spectra or residual ratios.

Combine VS‑PINN with NTK‑based adaptive loss weighting.

Apply L‑BFGS after Adam pre‑training for higher final accuracy.

Integrate Fourier feature embeddings to capture high‑frequency flow structures.

Extend the method to unsteady Navier‑Stokes (time scaling) and three‑dimensional geometries.

References

[1] Ko, S., & Park, S. (2025). VS‑PINN: A fast and efficient training of physics‑informed neural networks using variable‑scaling methods for solving PDEs with stiff behavior. Journal of Computational Physics, 529, 113860.

[2] Raissi, M., Perdikaris, P., & Karniadakis, G. E. (2019). Physics‑informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations. Journal of Computational Physics, 378, 686–707.

[3] Schäfer, M., & Turek, S. (1996). Benchmark computations of laminar flow around a cylinder. In E. H. Hirschel (Ed.), Flow Simulation with High‑Performance Computers II (Notes on Numerical Fluid Mechanics, Vol. 52, pp. 547–566). Vieweg+Teubner Verlag.

[4] Kingma, D. P., & Ba, J. (2015). Adam: A method for stochastic optimization. In Proceedings of the 3rd International Conference on Learning Representations (ICLR 2015).

[5] Bradbury, J., et al. (2018). JAX: Composable transformations of Python+NumPy programs. URL: http://github.com/jax-ml/jax

Written by

AI Agent Research Hub

Sharing AI, intelligent agents, and cutting-edge scientific computing
