Demystifying OpenRLHF Loss Functions: From GPTLM to KTO and Beyond
This article walks through the various loss functions used in OpenRLHF—including GPTLMLoss, KDLoss, DPOLoss, KTOLoss, and reward model losses—explaining their mathematical foundations, implementation details, and practical considerations for RLHF training.
Introduction
This is the first in a series of technical notes on learning RLHF with OpenRLHF; it focuses on the loss functions that drive model training.
Basic Concepts
Before diving into code, it is useful to keep in mind the core formula for each loss and the shape of its gradient; the sections below pair each implementation with a brief explanation and, where helpful, the formula it computes.
SFT Family
GPTLMLoss
```python
# Shared imports for all of the loss snippets in this article
from typing import Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F


class GPTLMLoss(nn.Module):
    """GPT Language Model Loss"""

    def __init__(self):
        super().__init__()
        self.IGNORE_INDEX = -100
        self.loss = nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Shift so that the prediction at position t is scored against the token at t+1
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten batch and sequence dims, then compute token-level cross entropy
        return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
```

This is the standard next-token cross-entropy used for both GPT pre-training and SFT. The prompt is excluded upstream: the dataset sets the labels of prompt tokens to IGNORE_INDEX (-100), so only response tokens contribute to the loss.
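A minimal usage sketch (the shapes and the prompt-masking convention below are illustrative, not taken from OpenRLHF's dataset code):

```python
# Hypothetical example: mask the prompt so only response tokens are scored.
prompt_len, seq_len, vocab_size = 3, 8, 32
input_ids = torch.randint(0, vocab_size, (1, seq_len))
labels = input_ids.clone()
labels[0, :prompt_len] = -100                  # prompt tokens are ignored by CrossEntropyLoss

logits = torch.randn(1, seq_len, vocab_size)   # stand-in for model(input_ids).logits
loss = GPTLMLoss()(logits, labels)
print(loss.item())
```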
KDLoss
```python
# Adapted from https://github.com/microsoft/LMOps/blob/main/minillm/finetune.py#L166
class KDLoss(nn.Module):
    """Language Model Knowledge Distillation Loss"""

    def __init__(self):
        super().__init__()
        self.IGNORE_INDEX = -100

    def forward(self, logits: torch.Tensor, teacher_logits: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
        inf_mask = torch.isinf(logits)
        logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
        # Zero out entries where the student logit is infinite to avoid 0 * (-inf) = NaN
        prod_probs = torch.masked_fill(teacher_probs * logprobs, inf_mask, 0)
        x = torch.sum(prod_probs, dim=-1).view(-1)
        # Only count tokens whose label is not IGNORE_INDEX (i.e., response tokens)
        mask = (label != self.IGNORE_INDEX).int()
        distil_loss = -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
        return distil_loss
```

This loss distills a teacher model into a student by maximizing the teacher-weighted log-probabilities of the student. Per token it is the cross-entropy between the teacher's distribution and the student's, which differs from the forward KL divergence only by the teacher's entropy (a constant with respect to the student). In training it is usually mixed with the ordinary language-model loss, e.g.:
```python
# alpha and beta are user-chosen weights mixing the SFT and distillation terms
lm_loss = F.cross_entropy(
    logits.view(-1, logits.size(-1)),
    label.view(-1),
    ignore_index=self.IGNORE_INDEX,
)
total_loss = alpha * lm_loss + beta * distil_loss
```
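In symbols, the distillation term computed above is the masked, teacher-weighted negative log-likelihood (forward-KL style), averaged over response tokens:

```latex
\mathcal{L}_{\mathrm{KD}}
  = -\frac{1}{\sum_t m_t} \sum_t m_t \sum_{v \in \mathcal{V}}
      p_{\mathrm{teacher}}(v \mid x_{<t}) \,\log p_{\mathrm{student}}(v \mid x_{<t}),
\qquad
m_t = \mathbb{1}\!\left[\, y_t \neq \mathrm{IGNORE\_INDEX} \,\right]
```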
DPO Family

DPOLoss
```python
class DPOLoss(nn.Module):
    """DPO Loss"""

    def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None:
        super().__init__()
        self.beta = beta
        self.label_smoothing = label_smoothing
        self.ipo = ipo

    def forward(
        self,
        policy_chosen_logps: torch.Tensor,
        policy_rejected_logps: torch.Tensor,
        reference_chosen_logps: torch.Tensor,
        reference_rejected_logps: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps
        logits = pi_logratios - ref_logratios

        if self.ipo:
            losses = (logits - 1 / (2 * self.beta)) ** 2  # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
        else:
            # Eq. 3 of https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 recovers the original DPO loss
            # (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        loss = losses.mean()

        # Implicit rewards, detached from the graph
        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
        return loss, chosen_rewards, rejected_rewards
```

DPO can be used in its original form (ipo=False, label_smoothing=0) or with two variants: IPO, which replaces the logistic loss with a squared regression of the log-ratio margin toward 1/(2β), and cDPO, which applies label smoothing to tolerate noisy preference labels. The returned chosen/rejected rewards are the implicit rewards β·log(π/π_ref), detached and typically used only for logging metrics such as reward margin and accuracy.
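For reference, the default branch (ipo=False, label_smoothing=0) implements Eq. 7 of the DPO paper:

```latex
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}})
  = -\,\mathbb{E}_{(x,\, y_w,\, y_l)}\left[
      \log \sigma\!\left(
        \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
        - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
      \right)
    \right]
```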
KTOLoss
```python
# Adapted from https://github.com/ContextualAI/HALOs/blob/ca9b7e3eeea220c0944ad8095d641da33f907a7e/trainers.py#L770
class KTOLoss(nn.Module):
    """KTO loss for uneven sampling"""

    def __init__(
        self, beta: float, desirable_weight: float, undesirable_weight: float, world_size: int, device: torch.device
    ) -> None:
        super().__init__()
        self.beta = beta
        self.world_size = world_size
        self.device = device
        self.desirable_weight = desirable_weight
        self.undesirable_weight = undesirable_weight

    def forward(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        policy_KL_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        reference_KL_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        # Estimate the reference-point KL, average it across all ranks, and clamp at zero.
        # It is detached, so it only shifts the sigmoid's operating point and carries no gradient.
        KL = (policy_KL_logps - reference_KL_logps).mean().detach()
        dist.all_reduce(KL, op=dist.ReduceOp.SUM)
        KL = (KL / self.world_size).clamp(min=0)

        if policy_chosen_logps.shape[0] != 0:
            chosen_logratios = policy_chosen_logps - reference_chosen_logps
            chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - KL))
            chosen_rewards = self.beta * chosen_logratios.detach()
        else:
            # No desirable samples on this rank
            chosen_losses = torch.Tensor([]).to(policy_rejected_logps.dtype).to(self.device)
            chosen_rewards = torch.Tensor([]).to(policy_rejected_logps.dtype).to(self.device)

        if policy_rejected_logps.shape[0] != 0:
            rejected_logratios = policy_rejected_logps - reference_rejected_logps
            rejected_losses = 1 - F.sigmoid(self.beta * (KL - rejected_logratios))
            rejected_rewards = self.beta * rejected_logratios.detach()
        else:
            # No undesirable samples on this rank
            rejected_losses = torch.Tensor([]).to(policy_chosen_logps.dtype).to(self.device)
            rejected_rewards = torch.Tensor([]).to(policy_chosen_logps.dtype).to(self.device)

        losses = torch.cat(
            (self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0
        ).mean()
        return losses, chosen_rewards, rejected_rewards, KL
```

KTO handles uneven sampling: each rank may hold any mix (including none) of desirable and undesirable examples, the KL reference point is all-reduced and averaged across ranks, and the two classes are weighted separately via desirable_weight and undesirable_weight to compensate for class imbalance.
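In symbols (following the KTO paper), write r(x, y) = log π_θ(y|x)/π_ref(y|x) and let z_ref = max(0, E[r]) be the averaged KL estimate computed on the separate KL samples; the per-example loss is then:

```latex
\mathcal{L}_{\mathrm{KTO}}(x, y) =
\begin{cases}
  \lambda_D \left(1 - \sigma\!\big(\beta\,(r_\theta(x, y) - z_{\mathrm{ref}})\big)\right) & \text{if } y \text{ is desirable} \\[4pt]
  \lambda_U \left(1 - \sigma\!\big(\beta\,(z_{\mathrm{ref}} - r_\theta(x, y))\big)\right) & \text{if } y \text{ is undesirable}
\end{cases}
```

where λ_D and λ_U correspond to desirable_weight and undesirable_weight.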
VanillaKTOLoss
```python
# Adapted from https://github.com/ContextualAI/HALOs/blob/ca9b7e3eeea220c0944ad8095d641da33f907a7e/trainers.py#L742
class VanillaKTOLoss(nn.Module):
    """KTO loss for even sampling"""

    def __init__(self, beta: float) -> None:
        super().__init__()
        self.beta = beta

    def forward(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        # KL reference points are estimated from the batch itself (no separate KL samples)
        chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
        rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)

        chosen_logratios = policy_chosen_logps - reference_chosen_logps
        rejected_logratios = policy_rejected_logps - reference_rejected_logps

        # Note the cross-over: chosen samples are referenced against rejected_KL, and vice versa
        losses = torch.cat(
            (
                1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)),
                1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)),
            ),
            0,
        ).mean()

        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
        return losses, chosen_rewards, rejected_rewards
```

The vanilla variant assumes every batch contains matched numbers of chosen and rejected samples. Rather than dropping the KL term, it estimates the reference points directly from the batch (chosen losses use rejected_KL and rejected losses use chosen_KL) and applies no per-class weighting or cross-device reduction.
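A quick smoke test with made-up log-probabilities (one chosen/rejected pair each; the values are illustrative only):

```python
# Hypothetical sequence log-probabilities under the policy and the reference model
loss_fn = VanillaKTOLoss(beta=0.1)
policy_chosen   = torch.tensor([-12.0])
policy_rejected = torch.tensor([-15.0])
ref_chosen      = torch.tensor([-13.0])
ref_rejected    = torch.tensor([-14.0])

loss, chosen_r, rejected_r = loss_fn(policy_chosen, policy_rejected, ref_chosen, ref_rejected)
print(loss.item(), chosen_r.item(), rejected_r.item())
```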
RLHF Family
PolicyLoss
```python
class PolicyLoss(nn.Module):
    """Policy Loss for PPO"""

    def __init__(self, clip_eps: float = 0.2) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(
        self,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        advantages: torch.Tensor,
        action_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Importance ratio between the current policy and the behavior (rollout) policy
        ratio = (log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
        # Pessimistic (clipped) surrogate, negated so that gradient descent maximizes it
        loss = -torch.min(surr1, surr2)
        # masked_mean (an OpenRLHF utility) averages over generated (action) tokens only
        loss = masked_mean(loss, action_mask, dim=-1).mean()
        return loss
```

The PPO policy loss exponentiates the log-probability ratio between the current and the old (rollout) policy, forms the clipped and unclipped surrogate terms, takes their element-wise minimum, and negates the result so that minimizing the loss maximizes the clipped surrogate; the per-token values are then averaged over the action tokens selected by action_mask.
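This is the standard clipped surrogate objective from the PPO paper, written per action token t:

```latex
\mathcal{L}_{\mathrm{policy}}
  = -\,\mathbb{E}_t\!\left[
      \min\!\Big( r_t(\theta)\,\hat{A}_t,\;
                  \operatorname{clip}\!\big(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\big)\,\hat{A}_t \Big)
    \right],
\qquad
r_t(\theta) = \exp\!\big(\log \pi_\theta(a_t \mid s_t) - \log \pi_{\theta_{\mathrm{old}}}(a_t \mid s_t)\big)
```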
ValueLoss
```python
class ValueLoss(nn.Module):
    """Value Loss for PPO"""

    def __init__(self, clip_eps: float = None) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(
        self,
        values: torch.Tensor,
        old_values: torch.Tensor,
        returns: torch.Tensor,
        action_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if self.clip_eps is not None:
            # Keep the new value prediction within clip_eps of the rollout-time estimate
            values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
            surr1 = (values_clipped - returns) ** 2
            surr2 = (values - returns) ** 2
            # Pessimistic update: take the larger of the clipped and unclipped errors
            loss = torch.max(surr1, surr2)
        else:
            loss = (values - returns) ** 2
        loss = masked_mean(loss, action_mask, dim=-1).mean()
        return 0.5 * loss
```

Value clipping keeps the critic's new estimates within clip_eps of the estimates recorded during rollout, and taking the larger of the clipped and unclipped squared errors makes the update pessimistic; the factor of 0.5 is the conventional scaling for a squared-error loss.
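Written out, with V_old the value predicted at rollout time and R_t the return target:

```latex
\mathcal{L}_{\mathrm{value}}
  = \tfrac{1}{2}\,\mathbb{E}_t\!\left[
      \max\!\Big( \big(V_\theta(s_t) - R_t\big)^2,\;
                  \big(\operatorname{clip}\!\big(V_\theta(s_t),\, V_{\mathrm{old}}(s_t) - \epsilon,\, V_{\mathrm{old}}(s_t) + \epsilon\big) - R_t\big)^2 \Big)
    \right]
```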
PairWiseLoss
```python
class PairWiseLoss(nn.Module):
    """Pairwise Loss for Reward Model"""

    def forward(
        self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor, margin: torch.Tensor = None
    ) -> torch.Tensor:
        if margin is not None:
            loss = -F.logsigmoid(chosen_reward - reject_reward - margin)
        else:
            loss = -F.logsigmoid(chosen_reward - reject_reward)
        return loss.mean()
```

This is the Bradley-Terry pairwise ranking loss used to train reward models: it pushes the reward of the chosen response above that of the rejected one, optionally by at least a margin (as in Llama 2, where the margin reflects how strongly the annotator preferred the chosen response).
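In formula form, with an optional margin m:

```latex
\mathcal{L}_{\mathrm{RM}}
  = -\,\mathbb{E}_{(x,\, y_w,\, y_l)}\Big[ \log \sigma\big( r_\phi(x, y_w) - r_\phi(x, y_l) - m \big) \Big]
```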
LogExpLoss
```python
class LogExpLoss(nn.Module):
    """Pairwise Loss for Reward Model (log-exp formulation)"""

    def forward(
        self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor, margin: torch.Tensor = None
    ) -> torch.Tensor:
        # margin is accepted for interface compatibility but not used here
        loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
        return loss
```

Since -log σ(z) = log(1 + e^(-z)), this is algebraically identical to PairWiseLoss without a margin; the margin argument is accepted but ignored.
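The equivalence is just the following identity applied to z = chosen_reward - reject_reward:

```latex
-\log \sigma(z) \;=\; -\log \frac{1}{1 + e^{-z}} \;=\; \log\!\big(1 + e^{-z}\big)
```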
PRMLoss
```python
class PRMLoss(nn.Module):
    """Process Reward Model Loss"""

    def __init__(self, placeholder_token_id: int, reward_token_ids: Optional[list[int]] = None):
        super().__init__()
        self.IGNORE_INDEX = -100
        self.loss = nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)
        self.placeholder_token_id = placeholder_token_id
        self.reward_token_ids = reward_token_ids

    def forward(self, inputs: torch.Tensor, logits: torch.Tensor, labels: torch.Tensor, *, return_acc: bool = False):
        # Keep only the positions of the step-level placeholder token
        placeholder_mask = inputs == self.placeholder_token_id
        logits = logits[placeholder_mask]
        labels = labels[placeholder_mask]

        if labels.dtype == torch.float:
            # Soft labels: each label is the probability that the step is correct
            assert len(self.reward_token_ids) == 2, "reward_token_ids should have 2 tokens for soft labels"
            logits = logits[..., self.reward_token_ids]
            positive_labels = labels.to(logits.dtype)
            negative_labels = 1 - positive_labels
            negative_labels[positive_labels != -100] = 1 - positive_labels[positive_labels != -100]
            labels = torch.stack([positive_labels, negative_labels], dim=-1)
        elif self.reward_token_ids is not None:
            # Hard labels: restrict logits to the reward tokens and map token ids to class indices
            logits = logits[..., self.reward_token_ids]
            for i, token in enumerate(self.reward_token_ids):
                labels = torch.where(labels == token, i, labels)

        loss = self.loss(logits, labels)
        if not return_acc:
            return loss

        if labels.dtype == logits.dtype:
            # Soft labels: recover the hard class for accuracy computation
            labels = labels.argmax(dim=-1)
        acc = (logits.argmax(dim=-1) == labels).float().mean()
        return loss, acc
```

A process reward model scores intermediate reasoning steps rather than only the final answer. The input interleaves each step with a placeholder token (e.g., "ки" in Math-Shepherd-style data), and the label at each placeholder marks whether that step is correct. The loss gathers the logits at the placeholder positions, optionally restricts them to the reward tokens, and computes cross-entropy; hard labels are the reward-token ids themselves, while soft labels are float probabilities of correctness and require exactly two reward tokens. With return_acc=True it also reports step-level accuracy.
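A minimal hard-label usage sketch; the placeholder and reward token ids depend on your tokenizer, so the values below are made up for illustration:

```python
# Hypothetical token ids: 9000 = placeholder ("ки"), 101 = "+", 102 = "-"
prm_loss = PRMLoss(placeholder_token_id=9000, reward_token_ids=[101, 102])

vocab_size = 10000
inputs = torch.tensor([[5, 6, 9000, 7, 8, 9000]])            # two reasoning steps
labels = torch.tensor([[-100, -100, 101, -100, -100, 102]])  # step 1 correct, step 2 wrong
logits = torch.randn(1, 6, vocab_size)                        # stand-in for model logits

loss, acc = prm_loss(inputs, logits, labels, return_acc=True)
print(loss.item(), acc.item())
```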
To fully understand these loss functions, it is worth reading the corresponding trainer files (e.g., kto_trainer.py, ppo_trainer.py, prm_trainer.py) to see how the tensors fed into each loss are prepared.