diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py
index 1cbb0c36c..1723b5541 100644
--- a/src/liger_kernel/chunked_loss/dpo_loss.py
+++ b/src/liger_kernel/chunked_loss/dpo_loss.py
@@ -45,9 +45,12 @@ def preference_loss_fn(
         chosen_logratios = chosen_logps - ref_chosen_logps
         rejected_logratios = rejected_logps - ref_rejected_logps
 
+        chosen_rewards = beta * (chosen_logps - ref_chosen_logps)
+        rejected_rewards = beta * (rejected_logps - ref_rejected_logps)
+
         logits_diff = beta * (chosen_logratios - rejected_logratios)
         loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
-        return loss
+        return loss, chosen_rewards, rejected_rewards
 
     @staticmethod
     def forward(
diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py
index ab18a0f24..895c02dea 100644
--- a/test/chunked_loss/test_dpo_loss.py
+++ b/test/chunked_loss/test_dpo_loss.py
@@ -56,9 +56,12 @@ def alignment_loss(
         chosen_logratios = policy_chosen_logps - ref_chosen_logps
         rejected_logratios = policy_rejected_logps - ref_rejected_logps
 
+        chosen_rewards = self.beta * (policy_chosen_logps - ref_chosen_logps)
+        rejected_rewards = self.beta * (policy_rejected_logps - ref_rejected_logps)
+
         logits_diff = self.beta * (chosen_logratios - rejected_logratios)
         losses = -F.logsigmoid(logits_diff)
-        return losses
+        return losses, chosen_rewards, rejected_rewards
 
 
 class TorchLMHeadDPO(torch.nn.Module):
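For context, the values this patch adds to the return tuple are the implicit DPO rewards: beta times the policy-to-reference log-probability ratio for each chosen and rejected response. Below is a minimal standalone sketch of that same computation, with hypothetical names and not the Liger Kernel chunked API; it only illustrates how the rewards relate to the loss and how they are typically consumed as reward-margin/accuracy metrics.

# Hedged sketch (assumed helper name dpo_loss_with_rewards; not the Liger Kernel implementation).
import torch
import torch.nn.functional as F


def dpo_loss_with_rewards(
    chosen_logps: torch.Tensor,        # policy log-probs of chosen responses
    rejected_logps: torch.Tensor,      # policy log-probs of rejected responses
    ref_chosen_logps: torch.Tensor,    # reference-model log-probs of chosen responses
    ref_rejected_logps: torch.Tensor,  # reference-model log-probs of rejected responses
    beta: float = 0.1,
):
    chosen_logratios = chosen_logps - ref_chosen_logps
    rejected_logratios = rejected_logps - ref_rejected_logps

    # Implicit DPO rewards: beta * log(pi_policy / pi_ref) per response.
    chosen_rewards = beta * chosen_logratios
    rejected_rewards = beta * rejected_logratios

    # Standard DPO objective: -log sigmoid of the beta-scaled log-ratio gap.
    logits_diff = beta * (chosen_logratios - rejected_logratios)
    loss = -F.logsigmoid(logits_diff).mean()
    return loss, chosen_rewards, rejected_rewards


if __name__ == "__main__":
    torch.manual_seed(0)
    chosen, rejected = torch.randn(4), torch.randn(4)
    ref_chosen, ref_rejected = torch.randn(4), torch.randn(4)
    loss, cr, rr = dpo_loss_with_rewards(chosen, rejected, ref_chosen, ref_rejected)
    # Reward accuracy (fraction where chosen reward > rejected reward) is the kind of
    # training metric these extra return values make possible.
    print(loss.item(), (cr > rr).float().mean().item())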