Building a Simple Probabilistic Programming Language in Python
This article explains the principles of probabilistic programming languages and walks through constructing a basic PPL in Python: model definition with latent and observed variables, distribution handling, DAG traversal for log‑density computation, and an evaluation of the result with example code and visualizations.
In this article I introduce how Probabilistic Programming Languages (PPLs) work and gradually demonstrate how to build a simple PPL using Python.
The intended audience is statisticians, AI researchers, and curious programmers who are interested in PPLs and familiar with Bayesian statistics and basic Python.
We will implement the following API:
<code>mu = LatentVariable("mu", Normal, [0.0, 5.0])
y_bar = ObservedVariable("y_bar", Normal, [mu, 1.0], observed=3.0)
evaluate_log_density(y_bar, {"mu": 4.0})</code>The first two lines define a statistical model; the last line evaluates the (unnormalized) probability density of the model at μ = 4 under the observed condition.
Related research
The book The Design and Implementation of Probabilistic Programming Languages focuses on language theory and uses JavaScript for implementation.
The article Anatomy of a Probabilistic Programming Framework provides a high‑level overview but lacks concrete code.
Junpeng Lao’s talk and the PyMC3 developer guide describe PyMC’s implementation details, which are useful but not straightforward to replicate.
Implementation – High‑level representation
We use the following model as a reference: a latent mean μ with a weakly informative Normal prior, and an observed sample mean ȳ drawn around μ:

μ ~ Normal(0, 5)
ȳ ~ Normal(μ, 1)

This expression defines a joint probability distribution that factorizes along the graph: p(μ, ȳ) = p(μ) · p(ȳ | μ) = Normal(μ; 0, 5) · Normal(ȳ; μ, 1).
The model can be visualized either as a probabilistic graphical model (PGM) or as a directed factor graph. Although PGMs are more common in the literature, the factor‑graph view is more useful for implementing a PPL because it makes explicit the distinction between observed variables (conventionally drawn shaded) and latent variables (drawn unshaded).
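Before building any machinery, the generative story of this model can be sketched directly with SciPy. This is illustration only, not part of the PPL we build below; the names mu and y_bar simply mirror the model above.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

# Latent variable: mu ~ Normal(0, 5)
mu = float(norm(loc=0.0, scale=5.0).rvs(random_state=rng))
# Observed variable: y_bar ~ Normal(mu, 1), conditioned on the draw of mu
y_bar = float(norm(loc=mu, scale=1.0).rvs(random_state=rng))

# The joint log-density factorizes along the graph edges:
log_p = norm.logpdf(mu, 0.0, 5.0) + norm.logpdf(y_bar, mu, 1.0)
```

Computing `log_p` as a sum of per-variable terms is exactly what the PPL automates: each factor in the graph contributes one log‑density term.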
Distribution
In code a distribution is a class that provides a log_density method returning the log‑PDF at a given point. New distributions inherit from an abstract Distribution base class. We implement a Normal distribution using SciPy:
<code>from scipy.stats import norm

class Distribution:
    @staticmethod
    def log_density(point, params):
        raise NotImplementedError("Must be implemented by a subclass")

class Normal(Distribution):
    @staticmethod
    def log_density(point, params):
        # params[0] is the mean, params[1] the standard deviation
        return float(norm.logpdf(point, params[0], params[1]))
</code>Variables and DAG
Variables have three aspects: an associated distribution, a type (latent or observed), and connections to child variables, forming a directed acyclic graph (DAG). The dist_class field points to the distribution class, and dist_args holds the parameters, which may be constants or other variables.
<code>class LatentVariable:
    def __init__(self, name, dist_class, dist_args):
        self.name = name
        self.dist_class = dist_class  # e.g. Normal
        self.dist_args = dist_args    # constants or parent variables

class ObservedVariable:
    def __init__(self, name, dist_class, dist_args, observed):
        self.name = name
        self.dist_class = dist_class
        self.dist_args = dist_args
        self.observed = observed      # the fixed, observed value
</code>When computing the joint log‑density we traverse the DAG depth‑first, collect all variables, and sum their individual log‑densities. Latent variables obtain their value from a latent_values dictionary, while observed variables use their fixed observed value.
<code>def evaluate_log_density(variable, latent_values):
    # Depth-first traversal that collects every variable in the DAG
    visited = set()
    variables = []

    def collect_variables(v):
        if isinstance(v, float):
            return  # constants carry no density term
        visited.add(v)
        variables.append(v)
        for arg in v.dist_args:
            if arg not in visited:
                collect_variables(arg)

    collect_variables(variable)

    log_density = 0.0
    for var in variables:
        # Resolve distribution parameters: constants are used as-is,
        # parent latent variables are looked up by name
        dist_params = []
        for arg in var.dist_args:
            if isinstance(arg, float):
                dist_params.append(arg)
            elif isinstance(arg, LatentVariable):
                dist_params.append(latent_values[arg.name])
        if isinstance(var, LatentVariable):
            log_density += var.dist_class.log_density(latent_values[var.name], dist_params)
        elif isinstance(var, ObservedVariable):
            log_density += var.dist_class.log_density(var.observed, dist_params)
    return log_density
</code>Testing the implementation:
<code>mu = LatentVariable("mu", Normal, [0.0, 5.0])
y_bar = ObservedVariable("y_bar", Normal, [mu, 1.0], observed=5.0)
latent_values = {"mu": 4.0}
print(evaluate_log_density(y_bar, latent_values))  # => -4.267314978843446

# Cross-check against SciPy directly:
print(
    norm.logpdf(4.0, 0.0, 5.0) +
    norm.logpdf(5.0, 4.0, 1.0)
)  # => -4.267314978843446
</code>Conclusion and future work
Distribution handling, variable DAGs, and log‑density computation are core components of a PPL. By implementing these concepts in Python we obtain a simple yet powerful probabilistic programming language.
Future extensions include support for tensors and variable transformations to enable models such as linear regression and hierarchical models, an API for prior predictive sampling, and integration with computational‑graph frameworks like Theano/Aesara, JAX, or TensorFlow for automatic differentiation and advanced samplers such as Hamiltonian Monte Carlo.
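As a taste of one such extension, prior predictive sampling can reuse the same DAG: give each distribution a sampling method and resolve parameters parents-first. The names `sample` and `sample_prior_predictive` are hypothetical and do not exist in the code above; this is a sketch of one possible design, with minimal stand-in classes so it is self-contained.

```python
import numpy as np
from scipy.stats import norm

class Normal:
    @staticmethod
    def sample(params, rng):
        # params[0] is the mean, params[1] the standard deviation
        return float(norm.rvs(params[0], params[1], random_state=rng))

class LatentVariable:
    def __init__(self, name, dist_class, dist_args):
        self.name = name
        self.dist_class = dist_class
        self.dist_args = dist_args

def sample_prior_predictive(variable, rng):
    # Resolve parameters recursively (parents before children), then draw.
    # A full version would cache draws so shared parents are sampled once.
    def draw(v):
        if isinstance(v, float):
            return v
        params = [draw(arg) for arg in v.dist_args]
        return v.dist_class.sample(params, rng)
    return draw(variable)

rng = np.random.default_rng(0)
mu = LatentVariable("mu", Normal, [0.0, 5.0])
y_bar = LatentVariable("y_bar", Normal, [mu, 1.0])
print(sample_prior_predictive(y_bar, rng))
```

For the reference model, draws of y_bar should scatter around 0 with standard deviation √(5² + 1²) ≈ 5.1, which is a quick sanity check on the traversal.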
Side note – Posterior grid approximation
The log‑density can be used to locate the mode of the posterior distribution. In a simple example with a weakly informative zero‑mean prior and an observed sample mean of 1.5, the MAP estimate is around 1.44.
<code>import numpy as np
import pandas as pd
import altair as alt
from smolppl import Normal, LatentVariable, ObservedVariable, evaluate_log_density

mu = LatentVariable("mu", Normal, [0.0, 5.0])
y_bar = ObservedVariable("y_bar", Normal, [mu, 1.0], observed=1.5)

grid = np.linspace(-4, 4, 20)
evals = [evaluate_log_density(y_bar, {"mu": m}) for m in grid]

data = pd.DataFrame({"grid": grid, "logdensity": evals})
chart = alt.Chart(data).mark_line(point=True).encode(
    x=alt.X('grid', title='mu'),
    y=alt.Y('logdensity', title='log density')
).interactive()
chart
</code>
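Because both the prior and the likelihood are normal, this model is conjugate and the posterior mode is available in closed form, which provides a check on the grid approximation. The variable names below are my own, not part of the smolppl API.

```python
import numpy as np
from scipy.stats import norm

# Conjugate normal-normal model: prior Normal(0, 5), likelihood
# Normal(mu, 1), observed sample mean 1.5.
prior_sd, lik_sd, y_obs = 5.0, 1.0, 1.5

# Posterior precision is the sum of prior and likelihood precisions;
# the mode (= mean) is the precision-weighted average of 0 and y_obs.
post_precision = 1.0 / prior_sd**2 + 1.0 / lik_sd**2
map_estimate = (y_obs / lik_sd**2) / post_precision
print(map_estimate)  # => 1.4423076923076923

# Cross-check by maximizing the log-density on a fine grid
fine_grid = np.linspace(-4, 4, 2001)
logp = norm.logpdf(fine_grid, 0.0, prior_sd) + norm.logpdf(y_obs, fine_grid, lik_sd)
print(fine_grid[np.argmax(logp)])  # close to the analytic mode
```

With the coarse 20-point grid used above, the argmax only lands near 1.4; refining the grid recovers the analytic mode ≈ 1.442.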