Demystifying OpenRLHF Loss Functions: From GPTLM to KTO and Beyond

This article walks through the various loss functions used in OpenRLHF—including GPTLMLoss, KDLoss, DPOLoss, KTOLoss, and reward model losses—explaining their mathematical foundations, implementation details, and practical considerations for RLHF training.

Introduction

The author shares a series of technical notes on learning RLHF with OpenRLHF, focusing first on the loss functions that drive model training.

Basic Concepts

Before diving into the code, it helps to keep the core formula for each loss and the shape of its gradient in mind, as illustrated in the figures below.

[Figure: loss diagram]
[Figure: loss gradient diagram]
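
The original figures are not reproduced here, but the logistic loss -logsigmoid(x), which recurs in the DPO and reward-model losses below, makes a good reference point: its gradient is sigmoid(x) - 1, which saturates as the margin x grows. A quick autograd check (illustrative only):

import torch
import torch.nn.functional as F

x = torch.tensor([-2.0, 0.0, 2.0], requires_grad=True)
loss = -F.logsigmoid(x).sum()
loss.backward()
print(x.grad)                   # matches sigmoid(x) - 1
print(torch.sigmoid(x) - 1)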

SFT Family

GPTLMLoss

# Shared imports used by the loss snippets in this article
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from typing import Optional, Tuple

class GPTLMLoss(nn.Module):
    """GPT Language Model Loss"""
    def __init__(self):
        super().__init__()
        self.IGNORE_INDEX = -100
        self.loss = nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)
    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

This is the standard next-token cross-entropy loss used for GPT pre-training and SFT: logits and labels are shifted by one position, and any token whose label equals IGNORE_INDEX (typically the prompt portion, set by the data pipeline) is excluded from the loss.
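
A minimal toy call, where the shapes and token ids are made up for illustration:

loss_fn = GPTLMLoss()

# Toy batch: 1 sequence of 6 tokens over a vocabulary of 10.
# The first 3 positions are the prompt, so their labels are -100 and are
# skipped by CrossEntropyLoss; inside forward, logits and labels are shifted
# so that position t predicts token t + 1.
logits = torch.randn(1, 6, 10)
labels = torch.tensor([[-100, -100, -100, 4, 7, 2]])

loss = loss_fn(logits, labels)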

KDLoss

# Adapted from https://github.com/microsoft/LMOps/blob/main/minillm/finetune.py#L166
class KDLoss(nn.Module):
    """Language Model Knowledge Distillation Loss"""
    def __init__(self):
        super().__init__()
        self.IGNORE_INDEX = -100
    def forward(self, logits: torch.Tensor, teacher_logits: torch.Tensor, label: torch.Tensor) -> torch.Tensor:
        teacher_probs = F.softmax(teacher_logits, dim=-1, dtype=torch.float32)
        inf_mask = torch.isinf(logits)
        logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
        prod_probs = torch.masked_fill(teacher_probs * logprobs, inf_mask, 0)
        x = torch.sum(prod_probs, dim=-1).view(-1)
        mask = (label != self.IGNORE_INDEX).int()
        distil_loss = -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
        return distil_loss

This loss distills a teacher into the student by maximizing the expected student log-probability under the teacher's distribution, i.e. the cross-entropy H(teacher, student), which equals the forward KL divergence up to the constant teacher entropy; positions whose label is IGNORE_INDEX are masked out. It is typically combined with the ordinary language-model loss:

lm_loss = F.cross_entropy(
    logits.view(-1, logits.size(-1)),
    label.view(-1),
    ignore_index=self.IGNORE_INDEX
)
total_loss = alpha * lm_loss + beta * distil_loss
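
A hedged sketch of how this might be wired into a training step, with the teacher run under torch.no_grad(); the model objects, batch fields, and the 0.9/0.1 weights are illustrative, not OpenRLHF's actual trainer code:

kd_loss_fn = KDLoss()

student_out = student_model(input_ids, attention_mask=attention_mask)
with torch.no_grad():  # the teacher only provides targets, no gradients
    teacher_out = teacher_model(input_ids, attention_mask=attention_mask)

distil_loss = kd_loss_fn(student_out.logits, teacher_out.logits, label)
lm_loss = F.cross_entropy(
    student_out.logits.view(-1, student_out.logits.size(-1)),
    label.view(-1),
    ignore_index=-100,
)
total_loss = 0.9 * lm_loss + 0.1 * distil_loss  # illustrative weighting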

DPO Family

DPOLoss

class DPOLoss(nn.Module):
    """DPO Loss"""
    def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None:
        super().__init__()
        self.beta = beta
        self.label_smoothing = label_smoothing
        self.ipo = ipo
    def forward(
        self,
        policy_chosen_logps: torch.Tensor,
        policy_rejected_logps: torch.Tensor,
        reference_chosen_logps: torch.Tensor,
        reference_rejected_logps: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        pi_logratios = policy_chosen_logps - policy_rejected_logps
        ref_logratios = reference_chosen_logps - reference_rejected_logps
        logits = pi_logratios - ref_logratios
        if self.ipo:
            losses = (logits - 1 / (2 * self.beta)) ** 2  # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
        else:
            # Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
            losses = (
                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
            )
        loss = losses.mean()
        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
        return loss, chosen_rewards, rejected_rewards

The forward pass supports plain DPO plus two variants: IPO, which replaces the logistic loss with a squared loss that pushes the policy-vs-reference log-ratio margin toward 1/(2*beta), and cDPO, which applies label smoothing to account for noisy preference labels.
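
The four log-probability inputs are per-sequence sums of token log-probabilities over the response tokens, computed once with the policy and once with the frozen reference model. A minimal sketch of that computation (not OpenRLHF's exact helper):

def sequence_logps(logits: torch.Tensor, labels: torch.Tensor, response_mask: torch.Tensor) -> torch.Tensor:
    # logits: (B, T, V); labels: (B, T) raw token ids (no -100 masking here);
    # response_mask: 1 on response tokens, 0 on prompt/padding.
    # Shift so that position t predicts token t + 1, gather the log-prob of the
    # actual next token, and sum over the response portion of each sequence.
    logp = F.log_softmax(logits[:, :-1, :], dim=-1)
    token_logps = torch.gather(logp, 2, labels[:, 1:].unsqueeze(-1)).squeeze(-1)
    return (token_logps * response_mask[:, 1:]).sum(dim=-1)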

KTOLoss

# Adapted from https://github.com/ContextualAI/HALOs/blob/ca9b7e3eeea220c0944ad8095d641da33f907a7e/trainers.py#L770
class KTOLoss(nn.Module):
    """KTO loss for uneven sampling"""
    def __init__(self, beta: float, desirable_weight: float, undesirable_weight: float, world_size: int, device: torch.device) -> None:
        super().__init__()
        self.beta = beta
        self.world_size = world_size
        self.device = device
        self.desirable_weight = desirable_weight
        self.undesirable_weight = undesirable_weight
    def forward(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        policy_KL_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
        reference_KL_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        KL = (policy_KL_logps - reference_KL_logps).mean().detach()
        dist.all_reduce(KL, op=dist.ReduceOp.SUM)
        KL = (KL / self.world_size).clamp(min=0)
        if policy_chosen_logps.shape[0] != 0:
            chosen_logratios = policy_chosen_logps - reference_chosen_logps
            chosen_losses = 1 - F.sigmoid(self.beta * (chosen_logratios - KL))
            chosen_rewards = self.beta * chosen_logratios.detach()
        else:
            chosen_losses = torch.Tensor([]).to(policy_rejected_logps.dtype).to(self.device)
            chosen_rewards = torch.Tensor([]).to(policy_rejected_logps.dtype).to(self.device)
        if policy_rejected_logps.shape[0] != 0:
            rejected_logratios = policy_rejected_logps - reference_rejected_logps
            rejected_losses = 1 - F.sigmoid(self.beta * (KL - rejected_logratios))
            rejected_rewards = self.beta * rejected_logratios.detach()
        else:
            rejected_losses = torch.Tensor([]).to(policy_chosen_logps.dtype).to(self.device)
            rejected_rewards = torch.Tensor([]).to(policy_chosen_logps.dtype).to(self.device)
        losses = torch.cat((self.desirable_weight * chosen_losses, self.undesirable_weight * rejected_losses), 0).mean()
        return losses, chosen_rewards, rejected_rewards, KL

This variant handles uneven sampling, where a batch may contain different numbers of desirable and undesirable examples (or none of one kind). The KL baseline is estimated from mismatched prompt-completion pairs, all-reduced and averaged across devices, clamped at zero, and detached so that it only shifts the reward reference point; desirable and undesirable losses are then weighted separately before averaging.
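
A toy call with uneven sampling (2 desirable samples, 1 undesirable sample), using a single-process gloo group only so that the all_reduce inside forward succeeds; all values are made up:

import os

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

loss_fn = KTOLoss(beta=0.1, desirable_weight=1.0, undesirable_weight=1.0,
                  world_size=1, device=torch.device("cpu"))

# 2 chosen (desirable) samples, 1 rejected (undesirable) sample, and 3 KL
# samples built from mismatched prompt/completion pairs.
policy_chosen = torch.tensor([-12.0, -15.0])
reference_chosen = torch.tensor([-13.0, -14.5])
policy_rejected = torch.tensor([-20.0])
reference_rejected = torch.tensor([-18.0])
policy_kl = torch.tensor([-16.0, -16.5, -17.0])
reference_kl = torch.tensor([-16.2, -16.4, -17.1])

loss, chosen_rewards, rejected_rewards, kl = loss_fn(
    policy_chosen, policy_rejected, policy_kl,
    reference_chosen, reference_rejected, reference_kl,
)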

VanillaKTOLoss

# Adapted from https://github.com/ContextualAI/HALOs/blob/ca9b7e3eeea220c0944ad8095d641da33f907a7e/trainers.py#L742
class VanillaKTOLoss(nn.Module):
    """KTO loss for even sampling"""
    def __init__(self, beta: float) -> None:
        super().__init__()
        self.beta = beta
    def forward(
        self,
        policy_chosen_logps: torch.FloatTensor,
        policy_rejected_logps: torch.FloatTensor,
        reference_chosen_logps: torch.FloatTensor,
        reference_rejected_logps: torch.FloatTensor,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
        chosen_KL = (policy_chosen_logps - reference_chosen_logps).mean().clamp(min=0)
        rejected_KL = (policy_rejected_logps - reference_rejected_logps).mean().clamp(min=0)
        chosen_logratios = policy_chosen_logps - reference_chosen_logps
        rejected_logratios = policy_rejected_logps - reference_rejected_logps
        losses = torch.cat((
            1 - F.sigmoid(self.beta * (chosen_logratios - rejected_KL)),
            1 - F.sigmoid(self.beta * (chosen_KL - rejected_logratios)),
        ), 0).mean()
        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
        return losses, chosen_rewards, rejected_rewards

Vanilla KTO assumes balanced (even) sampling of positive and negative examples and estimates the chosen/rejected KL baselines directly from the current batch (clamped at zero), without a separate set of mismatched completions or any cross-device reduction.

RLHF Family

PolicyLoss

class PolicyLoss(nn.Module):
    """Policy Loss for PPO"""
    def __init__(self, clip_eps: float = 0.2) -> None:
        super().__init__()
        self.clip_eps = clip_eps
    def forward(
        self,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        advantages: torch.Tensor,
        action_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        ratio = (log_probs - old_log_probs).exp()
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
        loss = -torch.min(surr1, surr2)
        loss = masked_mean(loss, action_mask, dim=-1).mean()
        return loss

This is the standard PPO clipped surrogate objective: the probability ratio exp(log_probs - old_log_probs) is scaled by the advantage, the ratio is also clamped to [1 - clip_eps, 1 + clip_eps], and the pointwise minimum of the two surrogates is negated and averaged over the response tokens with masked_mean.
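
Both PPO losses rely on a masked_mean helper from OpenRLHF's model utilities to average only over the generated action tokens; a minimal sketch of such a helper, which may differ from the actual implementation:

def masked_mean(tensor: torch.Tensor, mask: Optional[torch.Tensor], dim: int = -1) -> torch.Tensor:
    # Mean over positions where mask == 1 (the response/action tokens),
    # ignoring prompt and padding positions; falls back to a plain mean
    # when no mask is provided.
    if mask is None:
        return tensor.mean(dim=dim)
    return (tensor * mask).sum(dim=dim) / mask.sum(dim=dim).clamp(min=1)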

ValueLoss

class ValueLoss(nn.Module):
    """Value Loss for PPO"""
    def __init__(self, clip_eps: float = None) -> None:
        super().__init__()
        self.clip_eps = clip_eps
    def forward(
        self,
        values: torch.Tensor,
        old_values: torch.Tensor,
        returns: torch.Tensor,
        action_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if self.clip_eps is not None:
            values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
            surr1 = (values_clipped - returns) ** 2
            surr2 = (values - returns) ** 2
            loss = torch.max(surr1, surr2)
        else:
            loss = (values - returns) ** 2
        loss = masked_mean(loss, action_mask, dim=-1).mean()
        return 0.5 * loss

Clipping the value prediction around old_values prevents large jumps in the critic's estimates; the loss takes the larger of the clipped and unclipped squared errors, averages it over the response tokens, and scales the result by 0.5.

PairWiseLoss

class PairWiseLoss(nn.Module):
    """Pairwise Loss for Reward Model"""
    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor, margin: torch.Tensor = None) -> torch.Tensor:
        if margin is not None:
            loss = -F.logsigmoid(chosen_reward - reject_reward - margin)
        else:
            loss = -F.logsigmoid(chosen_reward - reject_reward)
        return loss.mean()

This is the standard pairwise ranking loss for reward models, -logsigmoid(chosen_reward - reject_reward): it pushes the chosen response's reward above the rejected one's, optionally requiring a minimum gap of margin.

LogExpLoss

class LogExpLoss(nn.Module):
    """Pairwise Loss for Reward Model (log‑exp formulation)"""
    def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor, margin: torch.Tensor = None) -> torch.Tensor:
        loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
        return loss

Mathematically equivalent to PairWiseLoss without a margin, since log(1 + exp(r_rejected - r_chosen)) = -logsigmoid(r_chosen - r_rejected); note that the margin argument is accepted but ignored here.
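
A quick numerical check of that equivalence with made-up rewards (illustrative only):

chosen_reward = torch.tensor([1.2, -0.3, 0.8])
reject_reward = torch.tensor([0.5, 0.1, -0.4])

pairwise = -F.logsigmoid(chosen_reward - reject_reward).mean()           # PairWiseLoss, no margin
logexp = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()  # LogExpLoss
assert torch.allclose(pairwise, logexp)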

PRMLoss

class PRMLoss(nn.Module):
    """Process Reward Model Loss"""
    def __init__(self, placeholder_token_id: int, reward_token_ids: Optional[list[int]] = None):
        super().__init__()
        self.IGNORE_INDEX = -100
        self.loss = nn.CrossEntropyLoss(ignore_index=self.IGNORE_INDEX)
        self.placeholder_token_id = placeholder_token_id
        self.reward_token_ids = reward_token_ids
    def forward(self, inputs: torch.Tensor, logits: torch.Tensor, labels: torch.Tensor, *, return_acc: bool = False):
        placeholder_mask = inputs == self.placeholder_token_id
        logits = logits[placeholder_mask]
        labels = labels[placeholder_mask]
        if labels.dtype == torch.float:
            # soft label handling
            assert len(self.reward_token_ids) == 2, "reward_token_ids should have 2 tokens for soft labels"
            logits = logits[..., self.reward_token_ids]
            positive_labels = labels.to(logits.dtype)
            negative_labels = 1 - positive_labels
            negative_labels[positive_labels != -100] = 1 - positive_labels[positive_labels != -100]
            labels = torch.stack([positive_labels, negative_labels], dim=-1)
        elif self.reward_token_ids is not None:
            # hard label handling
            logits = logits[..., self.reward_token_ids]
            for i, token in enumerate(self.reward_token_ids):
                labels = torch.where(labels == token, i, labels)
        loss = self.loss(logits, labels)
        if not return_acc:
            return loss
        if labels.dtype == logits.dtype:
            labels = labels.argmax(dim=-1)
        acc = (logits.argmax(dim=-1) == labels).float().mean()
        return loss, acc

PRM computes the loss only at positions where the input contains a special placeholder token (e.g., "ки" in Math-Shepherd-style data). With hard labels, each placeholder's label is the token id of the correct reward token (e.g., "+" or "-"), which is remapped to a class index; with soft labels (float probabilities), logits are restricted to the two reward tokens and a two-class target distribution is built. Cross-entropy is then computed over those positions, with optional accuracy reporting, as in the toy example below.
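
A toy hard-label call, where the placeholder and reward token ids are made up for illustration:

# Illustrative ids: placeholder "ки" -> 100, "+" -> 648, "-" -> 387.
loss_fn = PRMLoss(placeholder_token_id=100, reward_token_ids=[648, 387])

inputs = torch.tensor([[5, 8, 100, 9, 3, 100]])   # two reasoning steps, each ending in the placeholder
labels = torch.full_like(inputs, -100)
labels[0, 2] = 648   # step 1 judged correct   -> "+"
labels[0, 5] = 387   # step 2 judged incorrect -> "-"

logits = torch.randn(1, 6, 1000)                  # (batch, seq_len, vocab_size)
loss, acc = loss_fn(inputs, logits, labels, return_acc=True)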

The article concludes that a thorough understanding of these loss functions requires reading the corresponding trainer files (e.g., kto_trainer.py, ppo_trainer.py, prm_trainer.py) to see how tensors are prepared and fed into each loss.

Written by Baobao Algorithm Notes, author of the BaiMian large model, offering technology and industry insights.