Solving the Burgers Equation with TINN: High‑Precision Physics‑Informed Neural Networks in 380 seconds

This tutorial presents the Time‑Induced Neural Network (TINN) framework, which addresses the time‑entanglement problem of standard PINNs with a dedicated time sub‑network and FiLM modulation. Training uses a Levenberg‑Marquardt optimizer for second‑order updates, and the 1‑D viscous Burgers equation is solved to a relative L2 error of about 1 × 10⁻⁶ in 371 seconds on an RTX 4090.

AI Agent Research Hub

Abstract

Standard physics‑informed neural networks (PINNs) treat time as an additional input, forcing a single set of network parameters to represent the solution over the entire temporal domain. This creates a “time‑entanglement” problem that limits both accuracy and convergence speed for time‑dependent PDEs such as the one‑dimensional viscous Burgers equation.

Time‑Induced Neural Networks (TINN)

TINN introduces an independent time sub‑network that maps a scalar time value to a 5‑dimensional modulation vector. The modulation vector is applied to every layer of the spatial network via a Feature‑wise Linear Modulation (FiLM) mechanism, allowing the effective representation to evolve with time while keeping the overall parameter count low.

FiLM modulation

The spatial network’s weight matrices remain time‑invariant; each FiLM layer scales and shifts the linear transformation using the time‑subnet output. The scaling matrix is zero‑initialized, so the network initially behaves like a standard MLP and gradually learns non‑zero modulation coefficients.
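
A minimal JAX sketch of one FiLM‑modulated layer, assuming the feature‑wise scale and shift are linear in the time‑subnet output z(t) and that both modulation matrices start at zero (the article only states that the scaling matrix is zero‑initialized). The names `init_film_layer`, `film_layer`, `A_scale`, `A_shift` and the tanh activation are hypothetical, not taken from the TINN code. At initialization the layer reduces to an ordinary dense layer, which is the "behaves like a standard MLP" property described above.

```python
import jax
import jax.numpy as jnp

def init_film_layer(key, d_in, d_out, d_mod=5):
    """Hypothetical FiLM-modulated dense layer (names and shapes are assumptions)."""
    return {
        "W": jax.random.normal(key, (d_in, d_out)) / jnp.sqrt(d_in),  # time-invariant weights
        "b": jnp.zeros(d_out),
        "A_scale": jnp.zeros((d_mod, d_out)),  # zero-init: scale starts at exactly 1
        "A_shift": jnp.zeros((d_mod, d_out)),  # assumed zero-init: shift starts at exactly 0
    }

def film_layer(p, h, z):
    """h: (N, d_in) spatial features; z: (N, d_mod) modulation vectors, one per point's t."""
    pre = h @ p["W"] + p["b"]          # ordinary linear transform, same for all t
    gamma = 1.0 + z @ p["A_scale"]     # feature-wise scale, (N, d_out)
    beta = z @ p["A_shift"]            # feature-wise shift, (N, d_out)
    return jnp.tanh(gamma * pre + beta)
```

Keeping the modulation as "1 + (something that starts at 0)" means training begins from a plain MLP and only departs from it as the time dependence is learned.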

Time sub‑network

The time sub‑network consists of two linear layers and a learnable skip connection initialized to the identity mapping, which keeps early training stable.
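
One possible reading of this design, sketched in JAX: a first linear layer lifts t to a 5‑dimensional hidden vector, a second linear layer refines it, and a learnable matrix `S` initialized to the identity carries the hidden vector straight through. The widths, activation, skip placement and the names `init_time_subnet` and `time_subnet` are assumptions, so this sketch has 65 parameters rather than the 185 reported for the actual time sub‑network.

```python
import jax
import jax.numpy as jnp

def init_time_subnet(key, d_mod=5):
    """Hypothetical time subnet t -> z(t) in R^d_mod (layout is an assumption)."""
    k1, k2 = jax.random.split(key)
    return {
        "W1": jax.random.normal(k1, (1, d_mod)),            # first linear layer: t -> R^d_mod
        "b1": jnp.zeros(d_mod),
        "W2": 0.1 * jax.random.normal(k2, (d_mod, d_mod)),  # second linear layer
        "b2": jnp.zeros(d_mod),
        "S": jnp.eye(d_mod),                                 # learnable skip, starts as identity
    }

def time_subnet(p, t):
    """t: (N, 1) times -> z: (N, d_mod) modulation vectors."""
    h = jnp.tanh(t @ p["W1"] + p["b1"])
    return h @ p["S"] + (h @ p["W2"] + p["b2"])  # skip path + learned correction
```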

Parameter count

Time sub‑network (linear layers + skip): 185 parameters

Spatial main network (two hidden layers + output): 980 parameters

Total: 1,145 parameters (≈ 0.05 GB of Jacobian memory in float32)
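
The memory figure follows from the shapes involved: J has one row per residual and one column per parameter, and the normal matrix JᵀJ is P × P. A small sketch, assuming one residual row per collocation, initial and boundary point (10,000 + 500 + 400 = 10,900 rows); the helper names are hypothetical.

```python
def jacobian_bytes(n_residuals: int, n_params: int, bytes_per_entry: int = 4) -> int:
    """Dense Jacobian J with shape (n_residuals, n_params); 4 bytes per float32 entry."""
    return n_residuals * n_params * bytes_per_entry

def gauss_newton_bytes(n_params: int, bytes_per_entry: int = 4) -> int:
    """Dense normal matrix J^T J with shape (n_params, n_params)."""
    return n_params * n_params * bytes_per_entry

n_residuals = 10_000 + 500 + 400  # assumption: one residual row per training point
print(jacobian_bytes(n_residuals, 1_145) / 1e6, "MB for J")
print(gauss_newton_bytes(1_145) / 1e6, "MB for J^T J")
```

This gives about 49.9 MB for J, consistent with the ≈ 0.05 GB above, and about 5.2 MB for JᵀJ. The P² term is what hurts at scale: 100,000 parameters would already need 40 GB for JᵀJ alone.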

Levenberg‑Marquardt Optimizer

Training is cast as a nonlinear least‑squares problem. The full Jacobian of all residuals with respect to the parameters is assembled using JAX’s reverse‑mode jacrev and vectorized with vmap. The LM update solves the damped normal equations (JᵀJ + μI)Δθ = -Jᵀr and adapts the damping parameter μ every two steps based on loss reduction.
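
A minimal sketch of one LM iteration in JAX. It works on a flattened parameter vector via `jax.flatten_util.ravel_pytree`, builds J with `jax.jacrev` applied to the whole residual vector (which vectorizes the per‑row reverse passes internally) and solves the damped normal equations. The accept/reject rule, the damping factors 0.5 and 2.0, and the name `make_lm_step` are assumptions; the article adapts μ every two steps, whereas this sketch adapts it every step.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def make_lm_step(residual_fn, params_template, shrink=0.5, grow=2.0):
    """Build a jitted LM step for residual_fn(params) -> (N,) residual vector.

    shrink/grow are hypothetical damping factors; mu is adapted every step here.
    """
    _, unravel = ravel_pytree(params_template)

    def flat_res(theta):
        return residual_fn(unravel(theta))

    @jax.jit
    def lm_step(theta, mu):
        r = flat_res(theta)                              # (N,) residuals
        J = jax.jacrev(flat_res)(theta)                  # (N, P) full Jacobian
        A = J.T @ J + mu * jnp.eye(theta.size)           # damped normal matrix
        delta = jnp.linalg.solve(A, -J.T @ r)            # (JᵀJ + μI) Δθ = -Jᵀr
        loss_old = jnp.sum(r ** 2)
        loss_new = jnp.sum(flat_res(theta + delta) ** 2)
        accept = loss_new < loss_old                     # keep the step only if it helps
        theta = jnp.where(accept, theta + delta, theta)
        mu = jnp.where(accept, mu * shrink, mu * grow)   # less damping on success, more on failure
        return theta, mu, jnp.minimum(loss_old, loss_new)

    return lm_step
```

Usage: flatten the initial parameters with `theta, unravel = ravel_pytree(params)`, then call `theta, mu, loss = step(theta, mu)` in a Python loop. Small μ makes the update close to a Gauss‑Newton step; large μ makes it a short gradient‑descent step.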

Complexity comparison

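Work per step: Adam = one backward pass over the scalar loss; LM = full Jacobian assembly (one vectorized reverse pass per residual row), plus forming JᵀJ and solving a P × P linear system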

Memory for Jacobian: Adam = none, LM ≈ 49 MB (float32)

Total training time (5,000 LM steps): 371.1 s

Final relative L2 error: ≈ 1 × 10⁻⁶

Each LM step costs more than an Adam step, but second‑order updates typically reach a given accuracy in far fewer steps than first‑order methods.

Forward‑mode JVP for Derivatives

JAX provides forward‑mode (JVP) and reverse‑mode (VJP) automatic‑differentiation primitives. Because the network has only two inputs (x and t), a forward‑mode JVP along one input direction returns that partial derivative at every point in a single forward pass, without storing intermediate activations for a backward pass. Nesting JVP calls yields the exact second derivative u_xx needed for the Burgers residual.

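```python
import jax
import jax.numpy as jnp

# u_batch(params, xt) -> (N,) is the network defined elsewhere; it must act row-wise.
@jax.jit
def grads_jvp(params, xt_batch):
    N = xt_batch.shape[0]
    v_x = jnp.tile(jnp.array([1.0, 0.0]), (N, 1))   # tangent direction for x
    v_t = jnp.tile(jnp.array([0.0, 1.0]), (N, 1))   # tangent direction for t
    u_fn = lambda xs: u_batch(params, xs)
    u_x_fn = lambda xs: jax.jvp(u_fn, (xs,), (v_x,))[1]  # first derivative u_x
    u, u_x = jax.jvp(u_fn, (xt_batch,), (v_x,))
    _, u_t = jax.jvp(u_fn, (xt_batch,), (v_t,))
    _, u_xx = jax.jvp(u_x_fn, (xt_batch,), (v_x,))       # nested JVP gives u_xx
    return u, u_x, u_t, u_xx
```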
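
Combining these derivatives gives the Burgers residual r = u_t + u u_x − ν u_xx. The self‑contained sketch below uses a plain tanh MLP as a stand‑in for the TINN model and assumes ν = 0.01/π, a common value in PINN Burgers benchmarks; the article does not state ν, and the names `init_mlp`, `u_batch` and `burgers_residual` here are hypothetical.

```python
import jax
import jax.numpy as jnp

NU = 0.01 / jnp.pi  # assumed viscosity (common PINN benchmark value), not from the article

def init_mlp(key, sizes=(2, 20, 20, 1)):
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]

def u_batch(params, xt):
    """Row-wise tanh MLP standing in for the TINN model: (N, 2) -> (N,)."""
    h = xt
    for W, b in params[:-1]:
        h = jnp.tanh(h @ W + b)
    W, b = params[-1]
    return (h @ W + b)[:, 0]

@jax.jit
def burgers_residual(params, xt):
    e_x = jnp.zeros_like(xt).at[:, 0].set(1.0)   # tangent along x for every point
    e_t = jnp.zeros_like(xt).at[:, 1].set(1.0)   # tangent along t for every point
    u_fn = lambda z: u_batch(params, z)
    u_x_fn = lambda z: jax.jvp(u_fn, (z,), (e_x,))[1]
    u, u_x = jax.jvp(u_fn, (xt,), (e_x,))
    _, u_t = jax.jvp(u_fn, (xt,), (e_t,))
    _, u_xx = jax.jvp(u_x_fn, (xt,), (e_x,))     # nested JVP: second derivative in x
    return u_t + u * u_x - NU * u_xx
```

Because each output row depends only on its own input row, one batched JVP returns the per‑point derivative for all N points at once.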

Experimental Setup

Problem: 1‑D viscous Burgers equation, u_t + u u_x = ν u_xx (shock‑forming regime)

Collocation points: 10,000 (uniform)

Initial points: 500

Boundary points: 200 per side (total 400)

Validation points: 5,250 (used to monitor the error and trigger collocation resampling)

Optimizer: Levenberg‑Marquardt, 5,000 steps

Hardware: NVIDIA RTX 4090, float32
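
A minimal sketch of drawing these point sets, assuming the usual Burgers benchmark domain x ∈ [-1, 1], t ∈ [0, 1] and uniform random sampling (the article says only "uniform" and does not state the domain). The function name `sample_training_points` is hypothetical.

```python
import jax
import jax.numpy as jnp

def sample_training_points(key, n_col=10_000, n_init=500, n_bc_per_side=200,
                           x_lo=-1.0, x_hi=1.0, t_lo=0.0, t_hi=1.0):
    """Uniform random (x, t) points; the domain bounds are assumptions."""
    k_col, k_init, k_left, k_right = jax.random.split(key, 4)
    lo = jnp.array([x_lo, t_lo])
    hi = jnp.array([x_hi, t_hi])
    # Interior collocation points for the PDE residual
    col = lo + (hi - lo) * jax.random.uniform(k_col, (n_col, 2))
    # Initial-condition points on the line t = t_lo
    x0 = x_lo + (x_hi - x_lo) * jax.random.uniform(k_init, (n_init, 1))
    init = jnp.hstack([x0, jnp.full((n_init, 1), t_lo)])

    # Boundary points on one side x = x_val, with random times
    def side(k, x_val):
        t = t_lo + (t_hi - t_lo) * jax.random.uniform(k, (n_bc_per_side, 1))
        return jnp.hstack([jnp.full((n_bc_per_side, 1), x_val), t])

    bc = jnp.vstack([side(k_left, x_lo), side(k_right, x_hi)])
    return col, init, bc
```

Together these give 10,900 residual rows, which is the N that sets the Jacobian size in the LM step.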

Results

Total training time: 371.1 s (≈ 6.2 min)

Network parameters: 1,145

Final relative L2 error: ≈ 1 × 10⁻⁶

Convergence analysis

Rapid initial drop: first 100 steps reduce error by ~2.5 orders of magnitude.

Steady mid‑phase decline (steps 100–1,000) with monotonic loss reduction.

Resampling events (triggered by the validation error) mitigate over‑fitting to a fixed set of collocation points.

After 5,000 steps the error has decreased by ~5.7 orders of magnitude, consistent with the efficiency of the LM updates.
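
The two figures of merit used in this section can be computed as below; a sketch assuming the prediction and the reference solution are evaluated on the same grid. The helper names are hypothetical. For scale, a drop of ~5.7 orders of magnitude means the error shrank by a factor of 10^5.7 ≈ 5 × 10⁵.

```python
import jax.numpy as jnp

def relative_l2_error(u_pred, u_ref):
    """||u_pred - u_ref||_2 / ||u_ref||_2 over all evaluation points."""
    return jnp.linalg.norm(u_pred - u_ref) / jnp.linalg.norm(u_ref)

def orders_of_magnitude(err_start, err_end):
    """Number of powers of ten the error dropped between two checkpoints."""
    return jnp.log10(err_start / err_end)
```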

Visual comparison

Reference solution vs. TINN prediction vs. absolute error

The TINN prediction captures the shock front accurately; absolute error is confined to a narrow region ahead of the shock.

Limitations

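Parameter scalability: the Jacobian needs O(N·P) memory for N residuals and P parameters, and the normal matrix JᵀJ needs O(P²) memory plus an O(P³) solve per step; large‑scale 3‑D problems become prohibitive.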

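Computational cost: each LM step assembles the full Jacobian, whose cost grows with the number of residual points times the number of parameters; high‑dimensional PDEs need many more collocation points, which limits applicability.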

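Parallelism: LM steps are sequential, and each ends in a dense linear solve that parallelizes less well than the element‑wise updates of first‑order optimizers, reducing GPU utilization.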

High‑dimensional extension: FiLM modulation vector is fixed at dimension 5; effectiveness for multi‑dimensional spatial domains remains untested.

Future Directions

Sketch‑based LM: random sub‑sampling of residuals to build a sketched Jacobian, lowering per‑step cost.

Multi‑step time marching: divide the temporal domain into sub‑intervals and fine‑tune TINN locally.

Low‑rank FiLM: factorize modulation vectors to reduce extra parameters.

Integration with adaptive sampling (RVAAS, RAR) to concentrate points near shocks.

Mixed‑precision training: compute Jacobians in float16 while solving the normal equations in float32 to save memory.
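
The sketch‑based LM idea can be illustrated by evaluating residuals and the Jacobian only on a random subset of m points per step, so J has m rows instead of N. This is a hypothetical sketch (the name `sketched_lm_step` and plain uniform row sampling without replacement are assumptions), not the authors' method.

```python
import jax
import jax.numpy as jnp

def sketched_lm_step(residual_fn, theta, points, mu, m, key):
    """One LM step using only m randomly chosen residual points (illustrative).

    residual_fn(theta, pts) -> (len(pts),) residuals for a flat parameter vector theta.
    """
    idx = jax.random.choice(key, points.shape[0], (m,), replace=False)
    sub = points[idx]                              # random row subsample of the residual set
    r = residual_fn(theta, sub)                    # (m,)
    J = jax.jacrev(residual_fn)(theta, sub)        # (m, P): m rows instead of N
    A = J.T @ J + mu * jnp.eye(theta.size)         # same damped normal matrix, cheaper to form
    return theta + jnp.linalg.solve(A, -J.T @ r)
```

Only the Jacobian assembly gets cheaper; the P × P solve is unchanged, so this targets problems where N is much larger than P.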

References

Dai, C.-Y., Chang, C.-C., Lin, T.-S., Lai, M.-C., & Lai, C.-H. (2026). TINNs: Time‑Induced Neural Networks for Solving Time‑Dependent PDEs. arXiv:2601.20361. https://arxiv.org/abs/2601.20361

Raissi, M., Perdikaris, P., & Karniadakis, G. E. (2019). Physics‑informed neural networks: A deep learning framework for solving forward and inverse problems involving nonlinear partial differential equations. Journal of Computational Physics, 378, 686–707. https://doi.org/10.1016/j.jcp.2018.10.045

Marquardt, D. W. (1963). An algorithm for least‑squares estimation of nonlinear parameters. SIAM Journal on Applied Mathematics, 11(2), 431–441. https://doi.org/10.1137/0111030

Wang, S., Yu, X., & Perdikaris, P. (2022). When and why PINNs fail to train: A neural tangent kernel perspective. Journal of Computational Physics, 449, 110768. https://doi.org/10.1016/j.jcp.2021.110768

Bradbury, J., et al. (2018). JAX: composable transformations of Python+NumPy programs (Version 0.3.13) [Software]. https://github.com/jax-ml/jax

Code Repository

https://github.com/xgxgnpu/Physics-informed-vibe-coding
