Building a Flow Matching Model from Scratch: Complete Code Walkthrough

This article walks through the full implementation of a flow‑matching generative model in PyTorch, covering dataset creation, a small MLP that learns a time‑dependent velocity field, the flow‑matching loss, training loop, ODE‑based sampling, visualisation of the learned vector field, and a discussion of the method's limitations and possible extensions.

In a previous post we introduced the theory of flow matching, highlighted its differences from diffusion models, and discussed its advantages. This article deepens that understanding by providing a concrete PyTorch implementation.

Dataset Construction

Two simple 2‑D distributions are used as source and target:

p₀: a standard 2‑D Gaussian (mean 0, std 1).

p₁: a checkerboard toy dataset, with points distributed uniformly over the alternating squares of a grid.

def sample_source(batch_size):
    # Sample from a 2D standard Gaussian (mean=0, std=1)
    return torch.randn(batch_size, 2)

def sample_target(batch_size):
    # Uniform x‑coordinate in [-2, 2)
    x1 = torch.rand(batch_size) * 4 - 2
    # Create y‑coordinate with alternating rows
    # y-coordinate: uniform in [0, 1), shifted down by 2 for a random half of the points
    x2_ = torch.rand(batch_size) - torch.randint(high=2, size=(batch_size,)) * 2
    # Offset by 1 on alternating integer columns of x1 to form the checkerboard
    x2 = x2_ + (torch.floor(x1) % 2)
    # Stack the coordinates and rescale
    data = torch.cat([x1[:, None], x2[:, None]], dim=1) / 0.45
    return data

Visualising the sampled points shows the Gaussian cloud for p₀ and the checkerboard pattern for p₁.
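As a quick sanity check on the sampler's logic, the same checkerboard construction can be ported to NumPy (a sketch mirroring sample_target above; sample_target_np is our own name for it) and its coordinate ranges verified:

```python
import numpy as np

def sample_target_np(batch_size, seed=0):
    rng = np.random.default_rng(seed)
    # Uniform x-coordinate in [-2, 2)
    x1 = rng.random(batch_size) * 4 - 2
    # y-coordinate: uniform in [0, 1), shifted down by 2 for a random half of the points
    x2_ = rng.random(batch_size) - rng.integers(0, 2, size=batch_size) * 2
    # Offset by 1 on alternating integer columns to form the checkerboard
    x2 = x2_ + (np.floor(x1) % 2)
    # Stack the coordinates and rescale, matching the 1/0.45 factor above
    return np.stack([x1, x2], axis=1) / 0.45

pts = sample_target_np(10_000)
print(pts.shape)                        # (10000, 2)
print(np.abs(pts).max() <= 4.5)         # True: coordinates stay within ~[-4.45, 4.45]
```

Before rescaling, both coordinates live in [-2, 2), so the final samples span roughly [-4.45, 4.45] on each axis.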

Model Definition

A small multilayer perceptron learns the time‑dependent velocity field v(x, t). The scalar time t is first embedded into a higher‑dimensional space, concatenated with the 2‑D position x, and passed through several SiLU‑activated fully‑connected layers.

class FlowModel(nn.Module):
    def __init__(self, input_dim=2, time_embed_dim=64):
        super().__init__()
        # Embed time scalar
        self.time_embed = nn.Sequential(
            nn.Linear(1, time_embed_dim),
            nn.SiLU(),
            nn.Linear(time_embed_dim, time_embed_dim)
        )
        # Main network predicts velocity
        self.net = nn.Sequential(
            nn.Linear(input_dim + time_embed_dim, 128), nn.SiLU(),
            nn.Linear(128, 128), nn.SiLU(),
            nn.Linear(128, 128), nn.SiLU(),
            nn.Linear(128, 128), nn.SiLU(),
            nn.Linear(128, 128), nn.SiLU(),
            nn.Linear(128, input_dim)  # output velocity
        )
    def forward(self, x, t):
        t_embed = self.time_embed(t)
        xt = torch.cat([x, t_embed], dim=-1)
        return self.net(xt)

Flow‑Matching Loss

The ground‑truth velocity is simply x₁ − x₀ (constant along the straight‑line interpolation). The loss is the mean squared error between the predicted velocity and this ground truth.

def flow_matching_loss(model, x0, x1, t):
    xt = (1 - t) * x0 + t * x1               # interpolated point
    v_target = x1 - x0                       # constant velocity
    v_pred = model(xt, t)                    # model prediction
    return ((v_pred - v_target) ** 2).mean()
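To see why this objective is well‑posed, note that along xₜ = (1 − t)·x₀ + t·x₁ the true velocity dxₜ/dt = x₁ − x₀ does not depend on t. A small NumPy check (a sketch with stand‑in predictors, not the PyTorch module above) confirms that the oracle velocity drives the loss to zero, while a zero predictor pays the mean squared norm of x₁ − x₀:

```python
import numpy as np

rng = np.random.default_rng(0)
x0 = rng.standard_normal((512, 2))          # source batch
x1 = rng.standard_normal((512, 2)) + 3.0    # shifted stand-in target batch
t = rng.random((512, 1))

xt = (1 - t) * x0 + t * x1                  # interpolated points (unused by the oracle)
v_target = x1 - x0                          # constant ground-truth velocity

# Oracle predictor returns the true velocity -> loss is exactly zero
loss_oracle = ((v_target - v_target) ** 2).mean()

# Zero predictor -> loss equals the mean squared norm of x1 - x0
loss_zero = ((0.0 - v_target) ** 2).mean()

print(loss_oracle)                                    # 0.0
print(np.isclose(loss_zero, (v_target ** 2).mean()))  # True
```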

Training Procedure

For each training step we sample a batch from p₀ and p₁, draw a random time t∈[0, 1], compute the loss, and update the model with Adam (lr = 5e‑4). The loss drops quickly in early iterations and then fluctuates.

num_steps = 10000
batch_size = 512
model = FlowModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
for step in tqdm(range(num_steps)):
    x0 = sample_source(batch_size).to(device)
    x1 = sample_target(batch_size).to(device)
    t = torch.rand(batch_size, 1).to(device)  # random interpolation time
    loss = flow_matching_loss(model, x0, x1, t)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 100 == 0:
        print(f"Step {step} | Loss: {loss.item():.4f}")

Sampling from the Trained Model

To generate new samples we start from a point drawn from the source distribution and integrate the learned velocity field from t = 0 to t = 1 using scipy.integrate.solve_ivp. The resulting points lie in the target checkerboard distribution.

def sample_flow(model, x0, t_span=(0, 1)):
    """Evolve x0 through the learned flow to produce a sample from p₁."""
    def ode_func(t, x):
        x_tensor = torch.tensor(x, dtype=torch.float32).unsqueeze(0).to(device)
        t_tensor = torch.tensor([[t]], dtype=torch.float32).to(device)
        with torch.no_grad():
            v = model(x_tensor, t_tensor)
        return v.squeeze(0).cpu().numpy()
    sol = solve_ivp(ode_func, t_span, x0.cpu().numpy(), t_eval=[t_span[1]])
    return sol.y[:, -1]
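Note that solve_ivp integrates one point at a time. In practice, a fixed‑step Euler loop over the whole batch is often sufficient and costs only a handful of model calls. A generic NumPy sketch, where velocity_fn stands in for the trained model (any callable with signature (x, t) → dx/dt):

```python
import numpy as np

def euler_sample(velocity_fn, x0, n_steps=50):
    """Integrate dx/dt = velocity_fn(x, t) from t=0 to t=1 for a whole batch."""
    x = x0.copy()
    dt = 1.0 / n_steps
    for i in range(n_steps):
        t = i * dt
        x = x + dt * velocity_fn(x, t)  # one Euler step for every sample at once
    return x

# Constant fields (the ideal conditional velocity x1 - x0) are integrated exactly
v = np.array([2.0, -1.0])
print(np.allclose(euler_sample(lambda x, t: v, np.zeros((1, 2)), 10), v))  # True

# For a field with a known flow, dx/dt = -x, the result approaches exp(-1) * x0
x0 = np.ones((4, 2))
print(np.allclose(euler_sample(lambda x, t: -x, x0, n_steps=1000),
                  np.exp(-1.0) * x0, atol=1e-3))                           # True
```

With the trained model, velocity_fn would wrap a torch.no_grad() forward pass, converting the batch to a tensor once per step instead of once per sample.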

Qualitative results show generated samples matching the checkerboard shape.

Visualising the Velocity Field

The learned vector field evolves smoothly over time: near t = 0 it pushes the central Gaussian cloud outward toward the checkerboard squares, and by t = 1 it settles into a pattern aligned with the checkerboard.

def plot_velocity_row(model, t_values=[0.0, 0.25, 0.5, 0.75, 1.0], grid_size=20):
    model.eval()
    with torch.no_grad():
        x = np.linspace(-4, 4, grid_size)
        y = np.linspace(-4, 4, grid_size)
        xx, yy = np.meshgrid(x, y)
        xy = np.stack([xx.flatten(), yy.flatten()], axis=1)
        xt = torch.tensor(xy, dtype=torch.float32).to(device)
        fig, axes = plt.subplots(1, len(t_values), figsize=(4*len(t_values), 4))
        for i, t_val in enumerate(t_values):
            tt = torch.full((xt.shape[0], 1), t_val, dtype=torch.float32).to(device)
            v = model(xt, tt).cpu().numpy()
            ax = axes[i]
            ax.quiver(xx, yy, v[:,0].reshape(grid_size, grid_size), v[:,1].reshape(grid_size, grid_size), scale=20)
            ax.set_title(f"t = {t_val}")
            ax.axis('equal')
            ax.grid(True)
        plt.tight_layout()
        plt.show()

Discussion

Limitations

Sampling only approximates the target distribution: both the learned velocity field and the numerical ODE solver introduce error, so there is no guarantee of matching p₁ exactly.

Inference requires solving an ODE for each sample, which is computationally heavier than a single forward pass.

The straight‑line conditional velocity x₁ − x₀ is trivial to construct for the synthetic pairs used here; with independently drawn pairs on complex real‑world data, conditional paths cross, and learning a good marginal velocity field becomes harder.

Potential Extensions

Combine flow matching with score‑based models to leverage both objectives.

Employ advanced neural ODE solvers or learned solvers to reduce inference cost.
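As a rough illustration of why higher‑order solvers help, a midpoint (RK2) step reaches far better accuracy than Euler at the same number of steps. A NumPy sketch on a synthetic test field dx/dt = −x with known exact flow exp(−t)·x₀ (not the trained model):

```python
import numpy as np

def integrate(step_fn, x0, n_steps):
    x, dt = x0, 1.0 / n_steps
    for i in range(n_steps):
        x = step_fn(x, i * dt, dt)
    return x

f = lambda x, t: -x                      # test field; exact flow is exp(-t) * x0

def euler_step(x, t, dt):
    return x + dt * f(x, t)

def midpoint_step(x, t, dt):
    x_mid = x + 0.5 * dt * f(x, t)       # half step to estimate the midpoint slope
    return x + dt * f(x_mid, t + 0.5 * dt)

x0 = np.ones(2)
exact = np.exp(-1.0) * x0
err_euler = np.abs(integrate(euler_step, x0, 20) - exact).max()
err_mid = np.abs(integrate(midpoint_step, x0, 20) - exact).max()
print(err_mid < err_euler)   # True: RK2 is far more accurate at equal step count
```

Midpoint costs two function evaluations per step but its error shrinks as O(dt²) versus Euler's O(dt), so it wins whenever accuracy, not per‑step cost, is the bottleneck.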

Conclusion

We have built a complete flow‑matching model from scratch, implemented the data pipeline, the neural velocity field, the training loop, ODE‑based sampling, and visualised the learned dynamics. The code repository is available at https://github.com/vickiiimu/checkerboard-FM-tutorial.git.

Tags: flow matching, generative models, PyTorch, MLP, neural ODE
Written by

AI Algorithm Path

A public account focused on deep learning, computer vision, and autonomous‑driving perception algorithms, covering neural networks, pattern recognition, related hardware and software configuration, and open‑source projects.
