Boosting RLHF Training Efficiency with Asynchronous vLLM and Ray Integration
This article explains how an asynchronous RLHF pipeline built on vLLM, Ray, and OpenRLHF reduces training bottlenecks by decoupling inference, environment interaction, and model updates, and walks through the implementation code and design choices behind scalable reinforcement learning.
Introduction
Reinforcement Learning from Human Feedback (RLHF) and its verifiable-reward variant RLVR have become hot topics, yet open-source frameworks that efficiently handle the heavy cost of environment interaction and long-output inference are still rare. By extending the OpenRLHF codebase, the author built the first RLHF framework to officially support asynchronous agent training natively, substantially improving end-to-end throughput.
Async vLLM Inference and Environment Interaction (LLMRayActorAsync)
The core module vllm_engine_async.py uses Python asyncio to run many agents concurrently on top of vLLM's AsyncLLM engine. The entry point add_requests fans out asynchronous calls to execute_agent, which in turn drives a remote Ray actor, AgentInstance, wrapping the user-provided step function. Concurrency is capped by an asyncio.Semaphore, and finished trajectories are collected through a shared result_queue. Unlike HTTP-based serving, generate_async is called directly inside the event loop, so no serialization or network round-trip is added on top of inference. Because AsyncLLM is non-blocking, many agent trajectories can be in flight at once.
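The user-provided step function must match the contract the framework code below assumes: an async step(state, action, label) coroutine returning (reward, next_state, done, extra_info), where reward is tensor-like because the trainer calls reward.item(). As a hypothetical illustration (not from the OpenRLHF repo), a minimal single-turn agent file could look like this:

import torch

async def step(state, action, label):
    """Single-turn toy agent: score the model's action against the label.

    Returns (reward, next_state, done, extra_info); reward is a tensor
    because the framework calls reward.item() on it.
    """
    reward = torch.tensor(1.0 if label in action else 0.0)
    next_state = state + action  # append the generated text to the dialogue state
    done = True                  # finish after a single interaction step
    return reward, next_state, done, {}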
import asyncio
import copy
import os

import ray

from .vllm_engine import BaseLLMRayActor


@ray.remote
class AgentInstance:
    def __init__(self, agent_func_path):
        # Dynamically load the user-provided agent file and bind its `step` coroutine.
        if agent_func_path.endswith(".py"):
            import importlib.util

            spec = importlib.util.spec_from_file_location("step", agent_func_path)
            agent_module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(agent_module)
            self.agent_step = agent_module.step
        else:
            raise ValueError("Agent path must be a Python file")

    async def step(self, state, action, label):
        return await self.agent_step(state, action, label)
@ray.remote
class LLMRayActorAsync(BaseLLMRayActor):
    async def add_requests(self, sampling_params, prompts, labels, max_length, hf_tokenizer=None, max_steps=10000):
        """Process requests from rank 0 and generate responses with multiple agent interactions.

        Args:
            sampling_params: Parameters for sampling
            prompts: List of prompts to process
            labels: List of labels corresponding to prompts
            max_length: Maximum total token length of a trajectory
            hf_tokenizer: Tokenizer used to measure the current state length
            max_steps: Maximum number of interaction steps
        """
        # Environment variables are strings, so cast before building the semaphore.
        num_tasks = int(os.environ.get("OPENRLHF_ASYNC_NUM_TASKS", 128))
        semaphore = asyncio.Semaphore(num_tasks)
        async def execute_agent(prompt, label, sampling_params):
            async with semaphore:
                agent_instance = AgentInstance.remote(self.agent_func_path)
                state = prompt
                action_ranges = []
                total_reward = 0
                for step_idx in range(max_steps):
                    # Budget the generation length against the remaining context window.
                    state_tokens_len = len(hf_tokenizer(state, add_special_tokens=False, return_tensors="pt")["input_ids"][0])
                    sampling_params.max_tokens = max_length - state_tokens_len
                    if sampling_params.max_tokens <= 0:
                        break
                    request_output = await self.generate_async(state, sampling_params)
                    action = request_output.outputs[0].text
                    # Record character offsets of the action so the trainer can mask loss to generated tokens.
                    action_ranges.append((len(state), len(state) + len(action)))
                    # Awaiting a Ray ObjectRef yields the remote result inside the event loop.
                    reward, state, done, extra_info = await agent_instance.step.remote(state, action, label)
                    total_reward += reward.item()
                    if done:
                        break
                ray.kill(agent_instance)
                final_response = {
                    "prompt": prompt,
                    "label": label,
                    "state": state,
                    "reward": total_reward,
                    "action_ranges": action_ranges,
                }
                await self.result_queue.put(final_response)

        # Each task gets its own copy of sampling_params because max_tokens is mutated per step.
        tasks = []
        for prompt, label in zip(prompts, labels):
            tasks.append(execute_agent(prompt, label, copy.deepcopy(sampling_params)))
        await asyncio.gather(*tasks)
    async def generate_async(self, prompts, sampling_params):
        from vllm.utils import random_uuid

        request_id = random_uuid()
        # AsyncLLM.generate returns an async generator; the last yielded item is the finished output.
        results_generator = self.llm.generate(prompts, sampling_params, request_id)
        final_output = None
        async for request_output in results_generator:
            final_output = request_output
        return final_output
    async def get_responses(self):
        # Drain whatever has finished so far; get_nowait avoids blocking on an empty queue
        # (awaiting get() would never raise QueueEmpty, it would just wait).
        results = []
        while not self.result_queue.empty():
            try:
                results.append(self.result_queue.get_nowait())
            except asyncio.QueueEmpty:
                break
        return results

Async Training Pipeline: PPOTrainerAsync
The training process is split into two Ray actors: GenerateSamplesActor (sample generation) and TrainingActor (policy update). They communicate through a Ray Queue whose maxsize is set to 1, which keeps the off-policy gap small and the training stable; tuning the queue size directly controls the degree of asynchrony.
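To see the backpressure mechanism in isolation, here is a toy sketch (standalone illustration, not OpenRLHF code) using ray.util.queue.Queue: with maxsize=1 the producer blocks until the consumer has taken the previous item, so rollouts can be at most one batch ahead of training.

import time

import ray
from ray.util.queue import Queue

ray.init()

# With maxsize=1, put() blocks while the queue is full, so the producer
# can never run more than one item ahead of the consumer.
queue = Queue(maxsize=1)

@ray.remote
def producer(q):
    for i in range(3):
        q.put(i)  # blocks until the consumer frees the slot
        print(f"produced {i}")
    q.put("done")

@ray.remote
def consumer(q):
    while (item := q.get()) != "done":
        time.sleep(1)  # pretend to train on the item
        print(f"consumed {item}")

ray.get([producer.remote(queue), consumer.remote(queue)])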
import os

import ray
from tqdm import tqdm


@ray.remote
class GenerateSamplesActor(BasePPOTrainer):
    def generate_samples(self, prompts, labels, **generate_kwargs):
        return self.samples_generator.generate_samples(prompts, labels, **generate_kwargs)

    def fit(self, start_episode, consumed_samples, queue):
        for episode in range(start_episode, self.args.num_episodes):
            # data loading logic omitted for brevity
            rollout_samples = self.generate_samples(rand_prompts, labels, **self.generate_kwargs)
            queue.put(rollout_samples)
        # Sentinel value tells the trainer that generation has finished.
        queue.put("done")


@ray.remote
class TrainingActor(BasePPOTrainer):
    def fit(self, queue, steps, pbar_steps):
        pbar = tqdm(range(pbar_steps), desc="Training Process", disable=False)
        while True:
            # Blocks until the generator produces a rollout; maxsize=1 bounds the off-policy gap.
            rollout_samples = queue.get()
            if rollout_samples == "done":
                break
            experiences = self.experience_maker.make_experience_list(rollout_samples)
            # Asynchronously append experiences to the actor (and critic) model groups.
            refs = self.actor_model_group.async_run_method_batch(method_name="append", experience=experiences)
            if self.critic_model_group is not None:
                refs.extend(self.critic_model_group.async_run_method_batch(method_name="append", experience=experiences))
            ray.get(refs)
            status = self.ppo_train(steps)
            # additional logging omitted
@ray.remote
class PPOTrainerAsync:
    def __init__(self, ...):
        from ray.util.queue import Queue

        # Environment variables are strings; cast before sizing the queue.
        # A size of 1 lets generation run at most one batch ahead of training.
        self.queue = Queue(maxsize=int(os.environ.get("OPENRLHF_ASYNC_QUEUE_SIZE", 1)))
        # other initializations omitted

    def fit(self) -> None:
        # Launch rollout generation and training concurrently, then wait for both.
        generator_actor_ref = self.generator_actor.fit.remote(start_episode, consumed_samples, self.queue)
        trainer_actor_ref = self.trainer_actor.fit.remote(self.queue, steps, pbar_steps)
        ray.get([generator_actor_ref, trainer_actor_ref])

Choosing an Async RL Algorithm
Empirical tests with community partners suggest that in asynchronous, off-policy settings, either REINFORCE++ with a baseline or GRPO with dynamic sampling gives the most stable results. It is also advisable to keep the number of async samples per update moderate, since a large off-policy gap destabilizes training.
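For intuition, the shared core of both recommended estimators is a per-prompt group baseline. Below is a minimal sketch (illustrative only, not the framework's implementation) of that advantage computation; GRPO additionally divides by the per-group standard deviation.

import torch

def group_baseline_advantages(rewards: torch.Tensor, normalize_std: bool = True) -> torch.Tensor:
    """rewards: (num_prompts, samples_per_prompt) rewards for groups of
    responses to the same prompt. Subtracting the per-prompt mean removes
    the prompt-difficulty component of the reward, so the estimator depends
    only on relative quality within the group."""
    baseline = rewards.mean(dim=1, keepdim=True)
    adv = rewards - baseline
    if normalize_std:
        # GRPO-style normalization; epsilon guards against zero-variance groups.
        adv = adv / (rewards.std(dim=1, keepdim=True) + 1e-8)
    return adv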
Acknowledgments
Technical support for the vLLM Async Engine: https://www.zhihu.com/people/176cf88046a1cae595b55e12d58c95e9
Original pipeline design ideas: https://www.zhihu.com/people/8b9fca5f4a5f47e58e9c2e0aeac7065f