Unlock Native TPU Inference with SGLang-Jax: A JAX-Powered Open-Source Engine
SGLang-Jax is a fully JAX-based open-source inference engine that delivers native TPU performance and integrates advanced features such as continuous batching, tensor and expert parallelism, and speculative decoding, along with installation and usage guidance for developers.
SGLang-Jax: An Open‑Source Native TPU Inference Solution
On October 19, the SGLang team announced SGLang-Jax, a new open-source inference engine built entirely on JAX and XLA that offers efficient native TPU inference. The project, co-developed with Ant Group's InclusionAI team, supports Ant's proprietary Ling and Ring models and will continue to evolve with community contributions.
Why Choose a JAX Backend?
JAX is designed from the ground up for TPUs, providing strong native performance. Major AI labs (Google DeepMind, xAI, Anthropic, Apple) already train with JAX, so using the same framework for inference reduces maintenance overhead.
Architecture
The stack is pure JAX with minimal dependencies. An OpenAI-compatible API layer receives requests, SGLang's RadixCache provides prefix caching, and an overlapping scheduler keeps batching overhead low. Models are implemented in Flax with shard_map for parallelism, and custom Pallas kernels implement the attention and MoE operators.
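To make the parallelism approach concrete, here is a minimal sketch of a tensor-parallel MLP written with shard_map; the mesh axis name, shapes, and function are illustrative assumptions, not SGLang-Jax's actual model code.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# One-dimensional mesh over all local devices; "tp" is the tensor-parallel axis.
mesh = Mesh(np.array(jax.devices()), ("tp",))

def mlp_shard(x, w_in, w_out):
    # Each device holds a column slice of w_in and a row slice of w_out.
    h = jax.nn.gelu(jnp.dot(x, w_in))  # local partial activation
    y = jnp.dot(h, w_out)              # local partial output
    # Sum the partial outputs across the tensor-parallel axis.
    return jax.lax.psum(y, axis_name="tp")

tp_mlp = shard_map(
    mlp_shard,
    mesh=mesh,
    in_specs=(P(), P(None, "tp"), P("tp", None)),  # replicate x, shard weights
    out_specs=P(),                                 # output replicated after psum
)
# Usage: y = jax.jit(tp_mlp)(x, w_in, w_out)
Sharding w_in by columns and w_out by rows is the standard Megatron-style layout: it needs only one all-reduce (the psum) per MLP block.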
Key Optimizations
Ragged Paged Attention v3 Integration: Adjusted the kernel grid configuration, added RadixCache compatibility, and implemented custom masks for EAGLE speculative decoding.
Reduced Scheduling Overhead: Overlapped CPU-side batch preparation with TPU computation, shrinking inter-batch gaps from milliseconds to microseconds (see the sketch after this list).
MoE Kernel Improvements: Integrated Megablox GMM for EPMoE, achieving a 3-4× end-to-end speedup over JAX's ragged_dot; also provided FusedMoE for dense expert scenarios.
Speculative Decoding: Implemented EAGLE-based multi-token prediction with tree-based MTP-Verify, adding non-causal mask support on top of RPA v3.
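The scheduling-overlap idea leans on JAX's asynchronous dispatch: a jitted call returns before the TPU finishes, so the host can prepare the next batch in the gap. A toy illustration (not SGLang-Jax's actual scheduler; shapes and the prepare step are placeholders):
import jax
import jax.numpy as jnp

@jax.jit
def forward(batch):
    # Stand-in for one model step on the TPU.
    return jnp.tanh(batch @ batch.T)

def prepare_next_batch(step):
    # Host-side work (tokenization, paging metadata, sampling params)
    # runs here while the device is still busy with the previous step.
    return jnp.full((256, 256), float(step))

batch = prepare_next_batch(0)
for step in range(1, 8):
    out = forward(batch)              # async dispatch: returns immediately
    batch = prepare_next_batch(step)  # overlaps with device execution
    out.block_until_ready()           # synchronize only when the result is needed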
Performance
Benchmarks show SGLang-Jax matching or exceeding the performance of other TPU inference solutions while remaining competitive with GPU-based alternatives.
Installation & Usage
Install via uv or from source:
# with uv
uv venv --python 3.12 && source .venv/bin/activate
uv pip install sglang-jax
# from source
git clone https://github.com/sgl-project/sglang-jax
cd sglang-jax
uv venv --python 3.12 && source .venv/bin/activate
uv pip install -e python/
Launch the server:
MODEL_NAME="Qwen/Qwen3-8B" # or "Qwen/Qwen3-32B"
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache \
uv run python -u -m sgl_jax.launch_server \
--model-path ${MODEL_NAME} \
--trust-remote-code \
--tp-size=4 \
--device=tpu \
--mem-fraction-static=0.8 \
--chunked-prefill-size=2048 \
--download-dir=/tmp \
--dtype=bfloat16 \
--max-running-requests=256 \
--page-size=128
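Once the server is up, the OpenAI-compatible endpoint can be exercised directly; a minimal sketch, assuming the server listens on SGLang's usual default port 30000 (adjust if you pass --port):
import requests

resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "Qwen/Qwen3-8B",
        "messages": [{"role": "user", "content": "What is a TPU?"}],
        "max_tokens": 128,
    },
)
print(resp.json()["choices"][0]["message"]["content"])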
Running on TPU via GCP or SkyPilot
Use the GCP console to create a TPU VM (runtime version v2-alpha-tpuv6e) and set up SSH keys, or use SkyPilot for automated spot-instance provisioning:
sky launch sgl-jax.sky.yaml --cluster=sgl-jax-skypilot-v6e-4 --infra=gcp -i 30 --down -y --use-spot
Roadmap
Model support and optimization (Grok2, Ling/Ring, DeepSeek V3, GPT‑OSS, MiMo‑Audio, Wan 2.1, Qwen3 VL)
TPU‑optimized kernels (quantization, communication‑compute overlap, MLA)
RL integration with tunix
Pathways and multi‑host support
Advanced service features (prefill‑decode decoupling, hierarchical KV cache, multi‑LoRA batching)
Acknowledgments
Thanks to the SGLang-Jax team, Google contributors, and InclusionAI members for their work.