An Introduction to JAX: Features, Installation, and Comparison with TensorFlow and PyTorch
This article introduces Google’s JAX library, covering its origins, core features such as automatic differentiation, JIT compilation, parallel and vectorized mapping, installation steps, code examples, and a comparative overview with TensorFlow and PyTorch for deep‑learning practitioners.
JAX is a rapidly emerging library in the machine learning field, aiming to make ML programming more intuitive, structured, and concise. While TensorFlow and PyTorch dominate the landscape, JAX—originating from Google Brain researchers Matt Johnson, Roy Frostig, Dougal Maclaurin, and Chris Leary—has attracted significant attention, accumulating 13.7K stars on GitHub.
Built on the predecessor Autograd and tightly integrated with XLA, JAX provides automatic differentiation for Python and NumPy code, supporting loops, branches, recursion, closures, and even third‑order derivatives. Leveraging XLA, it can compile and run NumPy programs on GPUs and TPUs, and its grad function enables flexible forward and reverse mode differentiation.
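Because grad returns an ordinary Python function, derivatives compose: applying grad repeatedly yields higher-order derivatives, and jacfwd/jacrev select forward or reverse mode. A minimal sketch (the cubic f below is an illustrative stand-in, not an example from the original article):

```python
import jax

def f(x):
    return x**3 + 2*x**2 - 3*x + 1  # illustrative cubic

df  = jax.grad(f)    # f'(x)   = 3x^2 + 4x - 3
d2f = jax.grad(df)   # f''(x)  = 6x + 4
d3f = jax.grad(d2f)  # f'''(x) = 6

print(df(1.0))   # 4.0
print(d2f(1.0))  # 10.0
print(d3f(1.0))  # 6.0

# Forward- and reverse-mode differentiation are both available:
print(jax.jacfwd(f)(1.0))  # 4.0
print(jax.jacrev(f)(1.0))  # 4.0
```

Note that grad itself is reverse-mode; jacfwd and jacrev give explicit control over the mode when one is cheaper than the other.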
The motivation behind JAX is to combine NumPy’s ease of use and stable API with hardware acceleration, addressing NumPy’s lack of GPU support, built‑in back‑propagation, and Python’s speed limitations. Numerous open‑source projects now build on JAX, such as Google’s Haiku (object‑oriented deep‑learning library), RLax (reinforcement‑learning library), and JAXnet (GPU‑accelerated deep‑learning library).
JAX Installation
To use JAX, install it via pip in a Python environment or Google Colab. For CPU‑only:
<code>$ pip install --upgrade jax jaxlib</code>
For GPU support, ensure CUDA and cuDNN are installed and run:
<code>$ pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html</code>
Import JAX alongside NumPy:
<code>import jax
import jax.numpy as jnp
import numpy as np</code>
Key JAX Features
Automatic Differentiation (grad): Example of differentiating a quadratic function at the point 1.0, checked against the analytic derivative.
<code>from jax import grad

def f(x):
    return 3*x**2 + 2*x + 5

def f_prime(x):
    return 6*x + 2

grad(f)(1.0)          # DeviceArray(8., dtype=float32)
print(f_prime(1.0))   # 8.0</code>
Just‑In‑Time Compilation (jit): Compile code to XLA kernels for performance gains.
<code>from jax import jit

x = np.random.rand(1000, 1000)
y = jnp.array(x)

def f(x):
    for _ in range(10):
        x = 0.5*x + 0.1*jnp.sin(x)
    return x

g = jit(f)
g(y)  # first call traces and compiles f; later calls reuse the compiled kernel
%timeit -n 5 -r 5 f(y).block_until_ready()  # ~10.8 ms per loop
%timeit -n 5 -r 5 g(y).block_until_ready()  # ~341 µs per loop</code>
Parallel Mapping (pmap): Distribute computation across all available devices.
<code>from jax import pmap

def f(x):
    return jnp.sin(x) + x**2

f(np.arange(4))
# DeviceArray([0., 1.841471, 4.9092975, 9.14112], dtype=float32)

# pmap maps over the leading axis, one element per device,
# so this call needs at least 4 devices (e.g. TPU cores):
pmap(f)(np.arange(4))
# ShardedDeviceArray([...], dtype=float32)</code>
Vectorized Mapping (vmap): Automatic vectorization of functions written for single inputs.
<code>from jax import vmap

def f(x):
    return jnp.square(x)

f(jnp.arange(10))
# DeviceArray([0, 1, 4, 9, 16, 25, 36, 49, 64, 81], dtype=int32)

vmap(f)(jnp.arange(10))
# Same result, computed by the auto-vectorized version of f</code>
TensorFlow vs PyTorch vs JAX
All three frameworks are open‑source and rely on C/C++ back‑ends to overcome Python’s Global Interpreter Lock. They differ in execution models, APIs, and ecosystem focus. The article includes a comparative table (images) highlighting these differences.
TensorFlow (Google, 2015) offers a mature ecosystem, Keras integration, eager execution, and TensorBoard visualizations, though it has faced criticism for API stability and static graph complexity.
PyTorch (Meta) provides dynamic computation graphs, extensive low‑level control, strong GPU support, and a Pythonic debugging experience, making it popular for research.
JAX (Google) emphasizes composable function transformations—automatic differentiation, JIT compilation, parallel and vectorized mapping—while preserving NumPy‑like syntax. It serves as a foundation for higher‑level libraries such as Haiku, Flax, ObJax, and Elegy.
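As a rough sketch of that composability, the transformations can be stacked on a toy loss function (the model, data shapes, and names below are invented for illustration, not taken from the article):

```python
import jax
import jax.numpy as jnp

# A toy squared-error loss for a linear model (illustrative).
def loss(w, x, y):
    return jnp.mean((jnp.dot(x, w) - y) ** 2)

# Compose transformations: JIT-compile the gradient of the loss.
grad_loss = jax.jit(jax.grad(loss))

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 3))  # 32 samples, 3 features
w = jnp.zeros(3)
y = jnp.ones(32)

g = grad_loss(w, x, y)  # gradient w.r.t. w, computed by a compiled kernel
print(g.shape)          # (3,)
```

The same pattern extends to vmap (per-example gradients) and pmap (data-parallel training), which is why higher-level libraries build on these primitives.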
Overall, JAX combines NumPy’s simplicity with hardware acceleration and advanced transformation capabilities, positioning it as a compelling alternative for modern deep‑learning research.