Building a Flow Matching Model from Scratch: Complete Code Walkthrough
This article walks through a full PyTorch implementation of a flow‑matching generative model: constructing the dataset, defining a small MLP that learns a time‑dependent velocity field, writing the flow‑matching loss and training loop, sampling by integrating the learned ODE, visualising the learned vector field, and discussing the method's limitations and possible extensions.
In a previous post we introduced the theory of flow matching, highlighted its differences from diffusion models, and discussed its advantages. This article deepens that understanding by providing a concrete PyTorch implementation.
Dataset Construction
Two simple 2‑D distributions are used as source and target:
p₀: a standard 2‑D Gaussian (mean 0, std 1).
p₁: a checkerboard toy dataset composed of Gaussian clusters arranged in a grid.
import torch

def sample_source(batch_size):
    # Sample from a 2D standard Gaussian (mean=0, std=1)
    return torch.randn(batch_size, 2)

def sample_target(batch_size):
    # Uniform x‑coordinate in [-2, 2)
    x1 = torch.rand(batch_size) * 4 - 2
    # y‑coordinate with alternating rows, offset by the parity of floor(x1)
    x2_ = torch.rand(batch_size) - torch.randint(high=2, size=(batch_size,)) * 2
    x2 = x2_ + (torch.floor(x1) % 2)
    # Stack into (batch, 2) and rescale; data is already a tensor, so convert
    # its dtype directly rather than re-wrapping it with torch.tensor()
    data = torch.cat([x1[:, None], x2[:, None]], dim=1) / 0.45
    return data.to(torch.float32)

Visualising the sampled points shows the Gaussian cloud for p₀ and the checkerboard pattern for p₁.
Model Definition
A small multilayer perceptron learns the time‑dependent velocity field f(x, t). The scalar time t is first embedded into a higher‑dimensional space, concatenated with the 2‑D position x, and passed through several SiLU‑activated fully‑connected layers.
import torch.nn as nn

class FlowModel(nn.Module):
    def __init__(self, input_dim=2, time_embed_dim=64):
        super().__init__()
        # Embed the scalar time into a higher-dimensional space
        self.time_embed = nn.Sequential(
            nn.Linear(1, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        # Main network predicts the velocity at (x, t)
        self.net = nn.Sequential(
            nn.Linear(input_dim + time_embed_dim, 128), nn.SiLU(),
            nn.Linear(128, 128), nn.SiLU(),
            nn.Linear(128, 128), nn.SiLU(),
            nn.Linear(128, 128), nn.SiLU(),
            nn.Linear(128, 128), nn.SiLU(),
            nn.Linear(128, input_dim)  # output velocity
        )

    def forward(self, x, t):
        t_embed = self.time_embed(t)
        xt = torch.cat([x, t_embed], dim=-1)
        return self.net(xt)

Flow‑Matching Loss
The ground‑truth velocity is simply x₁ − x₀ (constant along the straight‑line interpolation). The loss is the mean squared error between the predicted velocity and this ground truth.
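Before writing the loss, the constant-velocity claim is easy to confirm numerically. The following is a minimal standalone sketch (independent of the model code; the `interp` helper is introduced here only for illustration) comparing a central finite difference of the interpolation path against x₁ − x₀:

```python
import torch

torch.manual_seed(0)
x0 = torch.randn(8, 2, dtype=torch.float64)
x1 = torch.randn(8, 2, dtype=torch.float64)

def interp(t):
    # straight-line path between x0 and x1
    return (1 - t) * x0 + t * x1

t, eps = 0.3, 1e-6
# central finite difference of the path at time t
v_fd = (interp(t + eps) - interp(t - eps)) / (2 * eps)
print(torch.allclose(v_fd, x1 - x0, atol=1e-6))  # True: the velocity does not depend on t
```

Because the path is linear in t, the same x₁ − x₀ comes out at every t, which is exactly why the loss below can use a single constant target.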
def flow_matching_loss(model, x0, x1, t):
    xt = (1 - t) * x0 + t * x1   # point on the straight-line path
    v_target = x1 - x0           # constant ground-truth velocity
    v_pred = model(xt, t)        # model prediction
    return ((v_pred - v_target) ** 2).mean()

Training Procedure
For each training step we sample a batch from p₀ and p₁, draw a random time t∈[0, 1], compute the loss, and update the model with Adam (lr = 5e‑4). The loss drops quickly in early iterations and then fluctuates.
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

num_steps = 10000
batch_size = 512
model = FlowModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

for step in tqdm(range(num_steps)):
    x0 = sample_source(batch_size).to(device)
    x1 = sample_target(batch_size).to(device)
    t = torch.rand(batch_size, 1).to(device)  # random interpolation time
    loss = flow_matching_loss(model, x0, x1, t)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(f"Step {step} | Loss: {loss.item():.4f}")

Sampling from the Trained Model
To generate a new sample we start from a point drawn from the source distribution and integrate the learned velocity field from t = 0 to t = 1 using scipy.integrate.solve_ivp (one point at a time). The resulting points follow the target checkerboard distribution.
from scipy.integrate import solve_ivp

def sample_flow(model, x0, t_span=(0, 1)):
    """Evolve x0 through the learned flow to produce a sample from p₁."""
    def ode_func(t, x):
        x_tensor = torch.tensor(x, dtype=torch.float32).unsqueeze(0).to(device)
        t_tensor = torch.tensor([[t]], dtype=torch.float32).to(device)
        with torch.no_grad():
            v = model(x_tensor, t_tensor)
        return v.squeeze(0).cpu().numpy()
    sol = solve_ivp(ode_func, t_span, x0.cpu().numpy(), t_eval=[t_span[1]])
    return sol.y[:, -1]

Qualitative results show generated samples matching the checkerboard shape.
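The solve_ivp routine above integrates one point at a time; when many samples are needed, a fixed-step Euler integrator that pushes the whole batch through the network at once is typically much faster. The sketch below is not part of the original walkthrough: `sample_flow_euler` is an illustrative name, and it is demonstrated on a stand-in linear velocity field rather than the trained FlowModel.

```python
import torch

@torch.no_grad()
def sample_flow_euler(model, x0, n_steps=100):
    """Batched fixed-step Euler integration of dx/dt = model(x, t) from t=0 to t=1."""
    x = x0.clone()
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = torch.full((x.shape[0], 1), i * dt)
        x = x + dt * model(x, t)  # one Euler step for the whole batch
    return x

# Usage with a stand-in field v(x, t) = -x (substitute the trained model here)
x0 = torch.randn(256, 2)
out = sample_flow_euler(lambda x, t: -x, x0)
# Each Euler step multiplies x by (1 - dt), so out ≈ x0 * (1 - dt)**n_steps
```

With the trained model, `x0` would come from `sample_source` and the tensors would be moved to `device` first; the per-sample `sample_flow` above remains useful when adaptive step-size control matters.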
Visualising the Velocity Field
The learned vector field evolves smoothly in time: at t = 0 it is roughly radial, pushing the Gaussian samples outward toward the checkerboard's extent, and by t = 1 it aligns with the checkerboard structure.
import numpy as np
import matplotlib.pyplot as plt

def plot_velocity_row(model, t_values=(0.0, 0.25, 0.5, 0.75, 1.0), grid_size=20):
    model.eval()
    with torch.no_grad():
        x = np.linspace(-4, 4, grid_size)
        y = np.linspace(-4, 4, grid_size)
        xx, yy = np.meshgrid(x, y)
        xy = np.stack([xx.flatten(), yy.flatten()], axis=1)
        xt = torch.tensor(xy, dtype=torch.float32).to(device)
        fig, axes = plt.subplots(1, len(t_values), figsize=(4 * len(t_values), 4))
        for i, t_val in enumerate(t_values):
            tt = torch.full((xt.shape[0], 1), t_val, dtype=torch.float32).to(device)
            v = model(xt, tt).cpu().numpy()
            ax = axes[i]
            ax.quiver(xx, yy,
                      v[:, 0].reshape(grid_size, grid_size),
                      v[:, 1].reshape(grid_size, grid_size), scale=20)
            ax.set_title(f"t = {t_val}")
            ax.axis('equal')
            ax.grid(True)
        plt.tight_layout()
        plt.show()

Discussion
Limitations
With a finite network and numerical ODE integration, sampling carries no theoretical guarantee of matching the target distribution exactly.
Inference requires solving an ODE for each sample, which is computationally heavier than a single forward pass.
The straight-line conditional velocity used here is trivial to compute for synthetic 2‑D data; designing informative probability paths and couplings for high-dimensional real-world datasets is considerably harder.
Potential Extensions
Combine flow matching with score‑based models to leverage both objectives.
Employ advanced neural ODE solvers or learned solvers to reduce inference cost.
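As a concrete instance of the solver point above, here is a sketch of Heun's second-order method, which often matches Euler's accuracy with far fewer function evaluations. `sample_flow_heun` is an illustrative name, demonstrated on a stand-in linear field rather than the trained model:

```python
import torch

@torch.no_grad()
def sample_flow_heun(model, x0, n_steps=20):
    """Heun's (2nd-order) method: predict with Euler, then average the two slopes."""
    x = x0.clone()
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t0 = torch.full((x.shape[0], 1), i * dt)
        t1 = torch.full((x.shape[0], 1), (i + 1) * dt)
        v0 = model(x, t0)               # slope at the start of the step
        v1 = model(x + dt * v0, t1)     # slope at the Euler-predicted endpoint
        x = x + 0.5 * dt * (v0 + v1)    # trapezoidal update
    return x

# With v(x, t) = -x, 20 Heun steps already closely track the exact solution x0 * exp(-1)
x0 = torch.randn(128, 2)
out = sample_flow_heun(lambda x, t: -x, x0)
```

Each step costs two network evaluations instead of one, but the higher order usually lets the step count drop by more than half for comparable sample quality.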
Conclusion
We have built a complete flow‑matching model from scratch, implemented the data pipeline, the neural velocity field, the training loop, ODE‑based sampling, and visualised the learned dynamics. The code repository is available at https://github.com/vickiiimu/checkerboard-FM-tutorial.git.
AI Algorithm Path
A public account focused on deep learning, computer vision, and autonomous driving perception algorithms, covering visual CV, neural networks, pattern recognition, related hardware and software configurations, and open-source projects.