How Awex Enables Sub‑Second TB‑Scale Weight Sync for Trillion‑Parameter RL Models
Awex is a high-performance Python framework that synchronizes training and inference weights for trillion-parameter reinforcement-learning models in seconds, using unified conversion, metadata management, and NCCL/RDMA transfer plans. It dramatically reduces RL training latency and supports diverse parallel strategies.
Background
Reinforcement learning (RL) is a core technique for extending the capabilities of large language models. Scaling RL to trillion-parameter models (e.g., Ling-1T, Ring-1T) creates system-level challenges, especially around weight synchronization between the training and inference engines.
Key Challenges
Heterogeneous parallel strategies: Training may use DP+PP/VPP+TP+EP+CP, while inference typically uses DP+TP+EP, requiring efficient resharding.
Precision and format mismatch: Inference often runs in low precision (FP8/INT4) whereas training uses higher precision (BF16/FP32), demanding format conversion.
Massive data volume: Trillion-parameter models generate >100 TB of weight data on thousand-GPU clusters, making transfer a bottleneck.
Co-locate vs. separate deployment: RL workloads may share GPUs (co-locate) or run on separate cards, each with different synchronization needs.
Dynamic inference scaling: As rollout steps increase, inference instances must be added, requiring weight sync to support new instances.
Limitations of Existing Solutions
File-system based exchange: Writing full checkpoints to a shared filesystem takes minutes for 10B models and hours for trillion-parameter models.
Layer‑wise NCCL + full‑tensor sync: Serial per‑layer transfers under‑utilize bandwidth and cause minute‑level delays for large models.
Awex Design
Awex is a pure‑Python library that provides ultra‑low‑latency weight exchange for RL pipelines. It consists of three core components:
WeightWriter: Runs inside each training process, collects shard metadata, performs weight conversion, builds resharding plans, and sends weights.
WeightReader: Runs on each inference instance, mirrors WeightWriter functionality to receive and apply weight shards.
MetaServer: Global job-level service for service discovery, metadata exchange, and event notification in co-locate scenarios.
Unified Weight Conversion
Awex provides a pluggable conversion layer that normalizes weights across differing parallel strategies and tensor layouts. It handles:
Splitting merged weights (e.g., FFN gate/up) for TP‑aware resharding.
Standardizing weight names to a common namespace.
Regrouping QKV tensors to match inference TP/DP attention strategies.
Quantization and precision conversion from training to inference formats, reducing transfer volume.
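As an illustration of the gate/up splitting step, the sketch below assumes a hypothetical layout in which the merged FFN weight stacks the gate and up projections along dimension 0; the function name and layout are assumptions for illustration, not Awex's actual API.

```python
import numpy as np

def split_gate_up(merged, tp_size):
    """Split a merged FFN gate/up weight into per-TP-rank shard pairs.

    Assumed layout: [gate; up] stacked along dim 0. Each half is split
    independently across TP ranks so every rank holds a matching
    (gate_shard, up_shard) pair -- the merged tensor cannot simply be
    chunked as-is, or ranks would receive mixed gate/up rows.
    """
    gate, up = np.split(merged, 2, axis=0)          # undo the merge
    gate_shards = np.split(gate, tp_size, axis=0)   # TP-shard each half
    up_shards = np.split(up, tp_size, axis=0)
    return list(zip(gate_shards, up_shards))

# Toy example: intermediate=8, hidden=4 -> merged shape (16, 4), TP=2
merged = np.arange(64, dtype=np.float32).reshape(16, 4)
shards = split_gate_up(merged, tp_size=2)
```

Naively splitting the merged tensor into `tp_size` chunks would hand rank 0 only gate rows, which is why the halves must be separated first.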
Global Metadata Management
Each process gathers shard metadata, performs an all_gather_object to obtain a global view, and reports it to the MetaServer. Consistency checks ensure that training and inference shards are compatible before transfer.
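A simplified stand-in for such a consistency check is sketched below: after metadata is gathered, every tensor's shards must tile its full extent with no gaps or overlaps before transfer begins. The metadata shape `(name, start, end)` is an assumption for illustration.

```python
from collections import defaultdict

def check_shard_coverage(shard_meta, tensor_sizes):
    """Verify that gathered shard metadata tiles each tensor exactly.

    shard_meta: list of (tensor_name, start, end) entries, as ranks
    might report after an all_gather_object; tensor_sizes maps each
    name to its full length. Returns False on any gap, overlap, or
    missing coverage -- a toy version of a pre-transfer check.
    """
    by_name = defaultdict(list)
    for name, start, end in shard_meta:
        by_name[name].append((start, end))
    for name, size in tensor_sizes.items():
        cursor = 0
        for start, end in sorted(by_name[name]):
            if start != cursor:   # gap (start > cursor) or overlap
                return False
            cursor = end
        if cursor != size:        # shards do not cover the full tensor
            return False
    return True
```

Catching a mismatch here, before any bytes move, is much cheaper than diagnosing a corrupted weight after a multi-second cluster-wide transfer.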
P2P Transfer Planning
Awex builds deterministic point-to-point (P2P) plans from the global metadata. In NCCL mode, the plan includes round-robin shard assignment, handling of overlapping shard intervals, and process-local filtering, so no single process has to construct or hold the full trillion-parameter-scale plan.
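The round-robin assignment and process-local filtering can be sketched as follows. Because every rank derives the same plan from the same global metadata, each can compute it independently and keep only its own entries; the function names and tuple layout here are illustrative assumptions, not Awex's real planner.

```python
def build_p2p_plan(shards, writer_ranks, reader_ranks):
    """Deterministic round-robin assignment of shards to (src, dst)
    rank pairs. Every process computes an identical plan from the
    shared global metadata, so no coordination is needed."""
    plan = []
    for i, shard in enumerate(shards):
        src = writer_ranks[i % len(writer_ranks)]
        dst = reader_ranks[i % len(reader_ranks)]
        plan.append((shard, src, dst))
    return plan

def local_view(plan, rank):
    """Process-local filtering: keep only the sends/receives this
    rank participates in, so plan size stays bounded per process."""
    return [entry for entry in plan if rank in (entry[1], entry[2])]
```

Overlap handling (when a training shard interval spans more than one inference shard) would add an interval-intersection step per entry, omitted here for brevity.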
NCCL vs. RDMA Transfer
NCCL: Uses send/recv APIs; straightforward, but requires matching NCCL versions and static topologies.
RDMA: Decouples NCCL versions, offers flexible load-balanced plans, and supports dynamic inference scaling. RDMA reduces 1T-model transfer time from ~20 s (NCCL) to ~6 s.
CUDA IPC for Co‑locate Zero‑Copy
When training and inference share GPUs, Awex employs CUDA IPC to map training memory directly into inference processes, avoiding extra copies. To mitigate IPC handle overhead, tensors are merged by shape and dtype before serialization.
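The merging step can be illustrated with a small grouping sketch: tensors with identical shape and dtype are bucketed together so each bucket can be packed into one contiguous buffer and exported through a single IPC handle. The descriptor format below is an assumption for illustration; a real implementation would also concatenate the device buffers and record per-tensor offsets.

```python
from collections import defaultdict

def group_for_ipc(tensor_descs):
    """Group tensors by (shape, dtype) to cut IPC handle count.

    tensor_descs: dict of name -> (shape, dtype). Each resulting group
    would be packed into one contiguous allocation, so the number of
    CUDA IPC handles scales with distinct (shape, dtype) combinations
    rather than with the number of tensors.
    """
    groups = defaultdict(list)
    for name, (shape, dtype) in tensor_descs.items():
        groups[(tuple(shape), dtype)].append(name)
    return dict(groups)
```

For a trillion-parameter model with many identically shaped per-layer weights, this collapses thousands of per-tensor handles into a handful of per-group handles.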
Performance Results
On a thousand‑GPU cluster, Awex achieves the following transfer times:
10B parameters (31 GB): NCCL 0.8 s, RDMA 0.5 s.
100B parameters (191 GB): NCCL 9 s, RDMA 3.2 s.
1T parameters (1000 GB, FP8): NCCL 20 s, RDMA 6 s.
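As a sanity check on these numbers, the implied aggregate transfer rates (cluster-wide, derived directly from the reported sizes and times, not per-link figures) can be computed:

```python
# Reported (size_GB, seconds) pairs from the results above.
results = {
    "10B / NCCL": (31, 0.8),
    "10B / RDMA": (31, 0.5),
    "100B / NCCL": (191, 9.0),
    "100B / RDMA": (191, 3.2),
    "1T / NCCL": (1000, 20.0),
    "1T / RDMA": (1000, 6.0),
}
for label, (gb, seconds) in results.items():
    # Aggregate throughput across the whole cluster for that run.
    print(f"{label}: {gb / seconds:.1f} GB/s aggregate")
```

The 1T RDMA case works out to roughly 167 GB/s of aggregate throughput, which is only achievable because transfers run in parallel across many links rather than through a single NIC.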
Getting Started
Awex can be installed with a single pip command and supports Python 3.8+:
pip install awex
Example for the Megatron training engine:
from awex import NCCLWeightsWriter
from awex.engine.mcore import MegatronEngine
train_engine = MegatronEngine(awex_config, hf_config, mcore_model)
writer = NCCLWeightsWriter(train_engine)
writer.initialize()
writer.write_weights(step_id=1)
Example for the SGLang inference engine:
from awex import WeightsReader
from awex.engine.sglang import SGLangEngine
import sglang as sgl
sgl_engine = sgl.Engine(model_path="xxx", tp_size=2, random_seed=42)
inference_engine = SGLangEngine(awex_config, sgl_engine)
reader = WeightsReader(inference_engine)
reader.initialize()
reader.update_weights(step_id=1)
Future Work
Awex will be extended to support additional training engines (e.g., DeepSpeed, FSDP) and inference engines (e.g., vLLM, TensorRT‑LLM), and more model adapters will be added.
References
GitHub repository: https://github.com/inclusionAI/awex
