
An Introduction to JAX: Features, Installation, and Comparison with TensorFlow and PyTorch

This article introduces Google’s JAX library, covering its origins, core features such as automatic differentiation, JIT compilation, parallel and vectorized mapping, installation steps, code examples, and a comparative overview with TensorFlow and PyTorch for deep‑learning practitioners.

Python Programming Learning Circle

JAX is a rapidly emerging library in the machine learning field, aiming to make ML programming more intuitive, structured, and concise. While TensorFlow and PyTorch dominate the landscape, JAX—originating from Google Brain researchers Matt Johnson, Roy Frostig, Dougal Maclaurin, and Chris Leary—has attracted significant attention, accumulating 13.7K stars on GitHub.

Built on its predecessor Autograd and tightly integrated with XLA, JAX provides automatic differentiation for native Python and NumPy code, handling loops, branches, recursion, and closures, and it can take derivatives of derivatives (third-order and beyond). Through XLA it compiles and runs NumPy programs on GPUs and TPUs, and its grad function supports both forward-mode and reverse-mode differentiation.
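As a quick sketch of what "third-order derivatives" means in practice, grad composes with itself; this minimal example (not from the original article) differentiates a cubic three times:

```python
from jax import grad

def f(x):
    return x ** 3 + 2.0 * x

# Each application of grad differentiates once more.
df = grad(f)       # f'(x)   = 3x^2 + 2
d2f = grad(df)     # f''(x)  = 6x
d3f = grad(d2f)    # f'''(x) = 6

print(float(df(2.0)))   # 14.0
print(float(d2f(2.0)))  # 12.0
print(float(d3f(2.0)))  # 6.0
```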

The motivation behind JAX is to pair NumPy’s ease of use and stable API with hardware acceleration, addressing NumPy’s lack of GPU support and built‑in back‑propagation as well as Python’s speed limitations. Numerous open‑source projects now build on JAX, such as DeepMind’s Haiku (an object‑oriented deep‑learning library) and RLax (a reinforcement‑learning library), and JAXnet (a GPU‑accelerated deep‑learning library).

JAX Installation

To use JAX, install it with pip in a local Python environment or in Google Colab. For a CPU‑only build:

<code>$ pip install --upgrade jax jaxlib</code>

For GPU support, first make sure CUDA and cuDNN are installed, then install a jaxlib wheel built against your CUDA version (the example below targets CUDA 11.0):

<code>$ pip install --upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html</code>

Import JAX alongside NumPy:

<code>import jax
import jax.numpy as jnp
import numpy as np</code>
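jax.numpy mirrors the NumPy API closely, with one difference worth flagging early: JAX arrays are immutable, so in‑place updates go through the functional .at[] syntax. A small illustrative sketch:

```python
import jax.numpy as jnp
import numpy as np

x = np.arange(5)
y = jnp.arange(5)

# Same API, same results.
print(int(np.dot(x, x)), int(jnp.dot(y, y)))  # 30 30

# NumPy arrays mutate in place; JAX arrays return an updated copy.
x[0] = 10
y = y.at[0].set(10)
print(int(y[0]))  # 10
```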

Key JAX Features

Automatic Differentiation (grad): differentiating a quadratic function at the point 1.0, checked against the analytic derivative.

<code>from jax import grad

def f(x):
    return 3*x**2 + 2*x + 5

def f_prime(x):
    return 6*x + 2

print(grad(f)(1.0))   # DeviceArray(8., dtype=float32)
print(f_prime(1.0))   # 8.0</code>

Just‑In‑Time Compilation (jit): compile code into XLA kernels for significant performance gains.

<code>from jax import jit
x = np.random.rand(1000, 1000)
y = jnp.array(x)

def f(x):
    for _ in range(10):
        x = 0.5*x + 0.1*jnp.sin(x)
    return x

g = jit(f)
%timeit -n 5 -r 5 f(y).block_until_ready()   # ~10.8 ms per loop
%timeit -n 5 -r 5 g(y).block_until_ready()   # ~341 µs per loop</code>
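One caveat worth knowing: jit works by tracing the function once per input shape and dtype, so ordinary Python side effects run only during that trace. A small sketch:

```python
from jax import jit
import jax.numpy as jnp

@jit
def double(x):
    print("tracing")  # a Python side effect: runs only while tracing
    return x * 2.0

double(jnp.ones(3))  # first call: traces and compiles, prints "tracing"
double(jnp.ones(3))  # second call: reuses the cached kernel, no print
```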

Parallel Mapping (pmap): distribute a computation across all available devices.

<code>from jax import pmap

def f(x):
    return jnp.sin(x) + x**2

f(np.arange(4))
# DeviceArray([0., 1.841471, 4.9092975, 9.14112], dtype=float32)

pmap(f)(np.arange(4))
# ShardedDeviceArray([...], dtype=float32)</code>
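Note that pmap assigns one slice of the leading axis to each device, so the mapped axis length cannot exceed the number of available devices (the example above assumes a 4‑device host). A minimal sketch that adapts to whatever is available:

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()
print(n)  # e.g. 1 on a CPU-only machine, 8 on a TPU host

# One slice of the leading axis per device.
out = jax.pmap(lambda x: x ** 2)(jnp.arange(n))
print(out.shape)  # (n,)
```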

Vectorized Mapping (vmap): automatic vectorization of a function over a batch axis.

<code>from jax import vmap

def f(x):
    return jnp.square(x)

f(jnp.arange(10))
# DeviceArray([0, 1, 4, 9, 16, 25, 36, 49, 64, 81], dtype=int32)

vmap(f)(jnp.arange(10))
# Same result, vectorized</code>
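These transformations also compose: grad expects a scalar‑in/scalar‑out function, and vmap lifts it to a whole batch of inputs. A small sketch (not from the original article):

```python
from jax import grad, vmap
import jax.numpy as jnp

def f(x):
    return x ** 2

# vmap(grad(f)) evaluates the derivative 2x at every batch element.
df_batch = vmap(grad(f))
print(df_batch(jnp.arange(4.0)))  # [0. 2. 4. 6.]
```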

TensorFlow vs PyTorch vs JAX

All three frameworks are open‑source and rely on C/C++ back ends to work around Python’s Global Interpreter Lock, but they differ in execution models, APIs, and ecosystem focus. The original article summarizes these differences in a comparative table (included as images).

TensorFlow (Google, 2015) offers a mature ecosystem, Keras integration, eager execution, and TensorBoard visualizations, though it has drawn criticism for API churn and the complexity of its earlier static‑graph model.

PyTorch (Meta) provides dynamic computation graphs, extensive low‑level control, strong GPU support, and a Pythonic debugging experience, making it popular for research.

JAX (Google) emphasizes composable function transformations (automatic differentiation, JIT compilation, parallel and vectorized mapping) while preserving NumPy‑like syntax. It serves as a foundation for higher‑level libraries such as Haiku, Flax, Objax, and Elegy.
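That composability claim can be made concrete with a per‑example‑gradient sketch; the loss function, names, and shapes here are illustrative, not from the original article:

```python
from jax import grad, jit, vmap
import jax.numpy as jnp

def loss(w, x, y):
    return (jnp.dot(w, x) - y) ** 2

# Stack the transformations: grad for the gradient in w,
# vmap to batch over (x, y) pairs, jit to compile the pipeline.
per_example_grads = jit(vmap(grad(loss), in_axes=(None, 0, 0)))

w = jnp.ones(3)
xs = jnp.ones((5, 3))
ys = jnp.zeros(5)
print(per_example_grads(w, xs, ys).shape)  # (5, 3)
```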

Overall, JAX combines NumPy’s simplicity with hardware acceleration and advanced transformation capabilities, positioning it as a compelling alternative for modern deep‑learning research.

Tags: Machine Learning, Python, Deep Learning, TensorFlow, GPU, PyTorch, Automatic Differentiation, JAX
Written by

Python Programming Learning Circle

A global community of Chinese Python developers offering technical articles, columns, original video tutorials, and problem sets. Topics include web full‑stack development, web scraping, data analysis, natural language processing, image processing, machine learning, automated testing, DevOps automation, and big data.
