Step-by-Step Explanation of Neural ODEs with Code Examples
This article introduces Neural Ordinary Differential Equations, explains their core idea of learning continuous dynamics via a neural derivative function, demonstrates Euler integration, compares naive unfolding with the adjoint method for training, provides a PyTorch implementation, and offers practical tips and extensions such as event handling and physics‑informed models.
Background
Continuous phenomena in physics, biology, and engineering are often described by ordinary differential equations (ODEs). When an analytical solution is unavailable or the exact ODE form is unknown, the dynamics can be learned directly from observed data.
Neural ODE formulation
The derivative function \(\frac{dz}{dt}=f(t, z;\theta)\) is parameterized by a neural network \(f\). By integrating this learned derivative—e.g., with the explicit Euler method—one recovers the system trajectory, reproducing the workflow of classical ODE solvers while allowing the model to capture complex, unknown dynamics.
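For a uniform step size \(h\), the explicit Euler update reads \(z_{n+1} = z_n + h\,f(t_n, z_n;\theta)\): the state is advanced by the learned derivative scaled by the step size. This is exactly the update rule implemented in the PyTorch example below.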
Training methods
Two main gradient‑computation strategies are used:
Unrolled back-propagation: each integration step is expanded into a computational graph (similar to a recurrent neural network) and gradients are obtained by standard back-propagation. This approach is straightforward but consumes large amounts of memory for long trajectories.
Adjoint method (originally proposed by Chen et al.): gradients are computed by solving a backward ODE, which drastically reduces memory usage at the expense of additional forward and backward solves and slower runtime for small networks; a minimal usage sketch follows this list.
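For reference, the torchdiffeq library (a separate dependency, not used in the Euler example below) exposes both strategies: odeint back-propagates through the unrolled solver, while odeint_adjoint implements the adjoint method. A minimal sketch, assuming torchdiffeq is installed:

import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint  # adjoint-method gradients

class Dynamics(nn.Module):
    """Toy derivative network; forward(t, y) must return dy/dt."""
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, 32), nn.Tanh(), nn.Linear(32, 2))

    def forward(self, t, y):
        return self.net(y)

func = Dynamics()
y0 = torch.randn(2)                           # initial state
t = torch.linspace(0.0, 1.0, 20)              # evaluation time points
traj = odeint(func, y0, t, method="dopri5")   # shape (20, 2)
# The backward pass solves an adjoint ODE instead of storing every
# internal solver step, trading extra compute for constant memory
# in the number of steps.
traj.sum().backward()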
PyTorch example (Euler integration)
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from tqdm import tqdm

def generate_spiral_data(n_points=50, noise=0.1):
    # Assumed helper: the article calls this function without showing it.
    # Returns a noisy 2-D spiral of shape (n_points, 2).
    theta = np.linspace(0, 4 * np.pi, n_points)
    r = np.linspace(0.5, 2.0, n_points)
    xy = np.stack([r * np.cos(theta), r * np.sin(theta)], axis=1)
    return xy + noise * np.random.randn(n_points, 2)

class ODEFunc(nn.Module):
    """Neural network parameterizing the derivative dz/dt = f(t, z; theta)."""
    def __init__(self, hidden_dim, input_dim):
        super(ODEFunc, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ELU(),
            nn.Linear(hidden_dim, input_dim)
        )

    def forward(self, t, x):
        # Autonomous dynamics: t is accepted for solver compatibility but unused.
        return self.net(x)

class ODESolver(nn.Module):
    """Explicit Euler integrator: x_{n+1} = x_n + h * f(t_n, x_n)."""
    def __init__(self, func):
        super(ODESolver, self).__init__()
        self.func = func

    def forward(self, x0, t):
        h = t[1] - t[0]  # uniform step size
        trajectory = [x0]
        x = x0
        for i in range(len(t) - 1):
            dx = self.func(t[i], x) * h
            x = x + dx
            trajectory.append(x)
        return torch.stack(trajectory)  # shape (len(t), *x0.shape)

def main():
    n_points = 50
    spiral_data = generate_spiral_data(n_points=n_points, noise=0.1)
    spiral_data = torch.tensor(spiral_data, dtype=torch.float32)
    t = torch.linspace(0, 1, n_points)
    func = ODEFunc(hidden_dim=100, input_dim=2)
    ode_solver = ODESolver(func)
    x0 = spiral_data[0]  # integrate forward from the first observation
    optimizer = optim.Adam(func.parameters(), lr=0.01)
    n_epochs = 300
    for epoch in tqdm(range(n_epochs)):
        optimizer.zero_grad()
        pred = ode_solver(x0, t)  # (n_points, 2), same shape as the targets
        loss = nn.MSELoss()(pred, spiral_data)
        loss.backward()
        optimizer.step()

if __name__ == "__main__":
    main()

Observed training behavior
On the synthetic spiral dataset the mean‑squared‑error loss exhibits instability and converges slowly, illustrating typical challenges of a vanilla Neural ODE on a toy problem.
Practical considerations
Batching: introduce a mask in the loss function and group trajectories of similar length to improve GPU utilization.
Activation choice: avoid activations with discontinuous derivatives such as ReLU; smooth functions (e.g., Swish) tend to yield better performance for continuous dynamics.
Curriculum learning: gradually increase the sequence length during training to mitigate exploding or vanishing gradients.
Augmented Neural ODEs: append zero-valued dimensions to the state vector when the original state lacks sufficient information (e.g., missing velocity), allowing the network to learn useful latent features; a minimal sketch follows this list.
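A minimal sketch of the zero-augmentation trick (the helper name and dimensions below are illustrative, not from any library):

import torch

def augment(x0, n_aug=2):
    # Zero-pad the state: (batch, dim) -> (batch, dim + n_aug).
    zeros = torch.zeros(x0.shape[0], n_aug)
    return torch.cat([x0, zeros], dim=-1)

x0 = torch.randn(16, 2)      # observed 2-D state (e.g., position only)
z0 = augment(x0, n_aug=2)    # 4-D augmented state fed to the ODE solver
# The derivative network must now map 4 -> 4; after integration, read
# predictions back out of the first two dimensions: z[..., :2].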
Advanced extensions
Neural event functions: detect predefined discontinuity conditions (e.g., a ball contacting the ground) and trigger state updates at the precisely located event time.
Paper: https://arxiv.org/abs/2011.03902
Neural jumps: extend event functions by learning the timing of state jumps, useful for stochastic or contact-dynamics scenarios.
Paper: https://arxiv.org/abs/2006.04727
Physics-informed Neural Networks (PINNs): embed known physical laws into the loss or the ODE function, providing an inductive bias that enforces conservation or force-balance constraints; a minimal loss sketch follows this list.
Paper: https://www.sciencedirect.com/science/article/abs/pii/S0021999118307125
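As a loss-level illustration, the sketch below adds a physics residual to the usual data-fitting term; the governing law (dx/dt = v for a state (x, v)) and all names are hypothetical stand-ins for whatever constraint a real problem provides:

import torch

def physics_residual(pred, t):
    # pred: (n_points, 2) trajectory of (position, velocity); t: (n_points,).
    h = t[1] - t[0]
    dx_dt = (pred[1:, 0] - pred[:-1, 0]) / h    # finite-difference dx/dt
    v_mid = 0.5 * (pred[1:, 1] + pred[:-1, 1])  # velocity at step midpoints
    return ((dx_dt - v_mid) ** 2).mean()        # penalize dx/dt != v

# total_loss = data_loss + lambda_phys * physics_residual(pred, t)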
Applications
Neural ODEs excel at prediction on irregularly sampled time series and have been applied to flow matching in generative image models, where they can outperform diffusion-based approaches. The technique remains a relatively niche research area, but ongoing work suggests broader adoption in the future.