Implementing a Simple Probabilistic Programming Language in Python
This article introduces Probabilistic Programming Languages (PPLs), explains how they work, and walks through a step‑by‑step implementation of a minimal PPL in Python: defining a Bayesian model, representing latent and observed variables, traversing the model's DAG, and evaluating the joint log‑density, ending with a posterior grid illustration.
The target audience includes statisticians, AI researchers, and curious programmers familiar with Bayesian inference and basic Python. The example API is shown below:
mu = LatentVariable("mu", Normal, [0.0, 5.0])
y_bar = ObservedVariable("y_bar", Normal, [mu, 1.0], observed=3.0)
evaluate_log_density(y_bar, {"mu": 4.0})
Two graphical representations of the model are provided: a probabilistic graphical model (PGM) and a directed factor graph, highlighting the distinction between observed (gray) and latent (white) variables.
We define a Distribution abstract class using SciPy’s norm for a Normal distribution:
from scipy.stats import norm

class Distribution:
    @staticmethod
    def log_density(point, params):
        raise NotImplementedError("Must be implemented by a subclass")

class Normal(Distribution):
    @staticmethod
    def log_density(point, params):
        return float(norm.logpdf(point, params[0], params[1]))
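A quick sanity check that the wrapper agrees with SciPy directly (the class is repeated here only so the snippet runs on its own):

```python
from scipy.stats import norm

# Normal.log_density(point, [loc, scale]) simply delegates to
# scipy.stats.norm.logpdf, so the two calls must agree exactly.
class Normal:
    @staticmethod
    def log_density(point, params):
        return float(norm.logpdf(point, params[0], params[1]))

print(Normal.log_density(0.0, [0.0, 1.0]))  # log of the standard normal pdf at 0
print(float(norm.logpdf(0.0, 0.0, 1.0)))    # same value
```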
Variable classes capture the essential information for each node in the DAG:
class LatentVariable:
    def __init__(self, name, dist_class, dist_args):
        self.name = name
        self.dist_class = dist_class
        self.dist_args = dist_args

class ObservedVariable:
    def __init__(self, name, dist_class, dist_args, observed):
        self.name = name
        self.dist_class = dist_class
        self.dist_args = dist_args
        self.observed = observed
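The edges of the DAG live entirely in `dist_args`: passing one variable as a parameter of another links the two nodes. A small check (classes repeated so the snippet stands alone; `dist_class` is left as `None` since no density is evaluated here):

```python
class LatentVariable:
    def __init__(self, name, dist_class, dist_args):
        self.name = name
        self.dist_class = dist_class
        self.dist_args = dist_args

class ObservedVariable:
    def __init__(self, name, dist_class, dist_args, observed):
        self.name = name
        self.dist_class = dist_class
        self.dist_args = dist_args
        self.observed = observed

# mu appears inside y_bar's dist_args, so the parent node is
# reachable from the child -- that is the whole graph structure.
mu = LatentVariable("mu", None, [0.0, 5.0])
y_bar = ObservedVariable("y_bar", None, [mu, 1.0], observed=3.0)

print(y_bar.dist_args[0] is mu)  # True
```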
To compute the joint log‑density we traverse the DAG with a depth‑first search, collect all non‑float variables, and sum the log‑densities of each node. The core evaluation routine is:
def evaluate_log_density(variable, latent_values):
    visited = set()
    variables = []

    def collect_variables(v):
        if isinstance(v, float):
            return
        visited.add(v)
        variables.append(v)
        for arg in v.dist_args:
            if arg not in visited:
                collect_variables(arg)

    collect_variables(variable)
After gathering the variables we extract concrete parameter values (plain floats, or entries looked up in latent_values) and accumulate the log‑density contributions of latent and observed variables.
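One way to finish the routine along these lines (a self-contained sketch consistent with the description, not necessarily the article's exact code; the supporting classes are repeated so it runs standalone):

```python
from scipy.stats import norm

class Normal:
    @staticmethod
    def log_density(point, params):
        return float(norm.logpdf(point, params[0], params[1]))

class LatentVariable:
    def __init__(self, name, dist_class, dist_args):
        self.name = name
        self.dist_class = dist_class
        self.dist_args = dist_args

class ObservedVariable:
    def __init__(self, name, dist_class, dist_args, observed):
        self.name = name
        self.dist_class = dist_class
        self.dist_args = dist_args
        self.observed = observed

def evaluate_log_density(variable, latent_values):
    visited = set()
    variables = []

    def collect_variables(v):
        if isinstance(v, float):
            return
        visited.add(v)
        variables.append(v)
        for arg in v.dist_args:
            if arg not in visited:
                collect_variables(arg)

    collect_variables(variable)

    def concrete(arg):
        # Parameters are either plain floats or references to latent
        # variables, whose values come from the latent_values dict.
        return arg if isinstance(arg, float) else latent_values[arg.name]

    total = 0.0
    for v in variables:
        params = [concrete(arg) for arg in v.dist_args]
        if isinstance(v, ObservedVariable):
            point = v.observed               # observed nodes: evaluate at the data
        else:
            point = latent_values[v.name]    # latent nodes: evaluate at the supplied value
        total += v.dist_class.log_density(point, params)
    return total
```

With the model from the example that follows, this reproduces the manually computed value.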
The final result matches the manual calculation using SciPy’s norm.logpdf:
mu = LatentVariable("mu", Normal, [0.0, 5.0])
y_bar = ObservedVariable("y_bar", Normal, [mu, 1.0], observed=5.0)
latent_values = {"mu": 4.0}
print(evaluate_log_density(y_bar, latent_values)) # => -4.267314978843446
The article concludes with possible extensions such as supporting tensors, linear regression models, prior‑predictive sampling, and integrating with computational‑graph frameworks like Theano, JAX, or TensorFlow to enable automatic differentiation for advanced samplers.
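As a taste of one such extension, prior‑predictive sampling only requires each distribution to know how to draw from itself and the DAG to be sampled parents‑first. A standard‑library sketch (the `sample` method and `sample_prior_predictive` helper are illustrative names, not part of the article's code):

```python
import random

# Hypothetical extension: a sample() alongside log_density().
class Normal:
    @staticmethod
    def sample(params):
        return random.gauss(params[0], params[1])

class LatentVariable:
    def __init__(self, name, dist_class, dist_args):
        self.name = name
        self.dist_class = dist_class
        self.dist_args = dist_args

class ObservedVariable:
    def __init__(self, name, dist_class, dist_args, observed):
        self.name = name
        self.dist_class = dist_class
        self.dist_args = dist_args
        self.observed = observed

def sample_prior_predictive(variable, draws):
    # Resolve a node by recursively sampling its parents first;
    # the cache ensures each variable is drawn once per replicate.
    def draw(v, cache):
        if isinstance(v, float):
            return v
        if v.name not in cache:
            params = [draw(arg, cache) for arg in v.dist_args]
            cache[v.name] = v.dist_class.sample(params)
        return cache[v.name]
    return [draw(variable, {}) for _ in range(draws)]

mu = LatentVariable("mu", Normal, [0.0, 5.0])
y_bar = ObservedVariable("y_bar", Normal, [mu, 1.0], observed=5.0)
samples = sample_prior_predictive(y_bar, 1000)
```

Here the observed value is ignored: prior‑predictive draws come from the model itself, marginalizing over mu.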
A posterior grid approximation example demonstrates how the log‑density can be evaluated over a range of mu values and visualized with Altair:
import numpy as np
import pandas as pd
import altair as alt

from smolppl import Normal, LatentVariable, ObservedVariable, evaluate_log_density

mu = LatentVariable("mu", Normal, [0.0, 5.0])
y_bar = ObservedVariable("y_bar", Normal, [mu, 1.0], observed=1.5)

grid = np.linspace(-4, 4, 20)
evaluations = [evaluate_log_density(y_bar, {"mu": m}) for m in grid]

data = pd.DataFrame({"grid": grid, "evaluations": evaluations})
chart = alt.Chart(data).mark_line(point=True).encode(
    x=alt.X('grid', title='mu'),
    y=alt.Y('evaluations', title='logdensity')
).interactive()
chart
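To turn these log‑density evaluations into an actual grid approximation of the posterior, one can exponentiate and normalize. A standard‑library‑only sketch (the log‑density of the same model and data is written out directly from the normal pdf so the snippet stands alone):

```python
import math

def log_normal_pdf(x, loc, scale):
    return (-0.5 * math.log(2 * math.pi) - math.log(scale)
            - 0.5 * ((x - loc) / scale) ** 2)

# Same model as above: Normal(0, 5) prior on mu, Normal(mu, 1)
# likelihood with observed value 1.5.
def log_density(m):
    return log_normal_pdf(m, 0.0, 5.0) + log_normal_pdf(1.5, m, 1.0)

grid = [-4 + 8 * i / 199 for i in range(200)]
logs = [log_density(m) for m in grid]

# Exponentiate stably (subtract the max first) and normalize with the
# trapezoid rule so the curve integrates to 1 over the grid.
log_max = max(logs)
dens = [math.exp(lv - log_max) for lv in logs]
step = grid[1] - grid[0]
area = sum((a + b) / 2 * step for a, b in zip(dens, dens[1:]))
post = [d / area for d in dens]
```

Because this model is conjugate, the grid posterior should peak near the analytic posterior mean 25/26 * 1.5 ≈ 1.44, which makes a handy sanity check.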
References to related work, including a book on PPL design, a survey article, and PyMC3 developer guides, are provided for readers who wish to explore deeper implementations.