From f8bb86fa5d351d3965cbea700a9234b984e52ff5 Mon Sep 17 00:00:00 2001
From: Chun-Mao Lai <72752478+Mecoli1219@users.noreply.github.com>
Date: Tue, 7 Jan 2025 23:28:19 -0800
Subject: [PATCH] Add `aux_outputs` for CPO and SimPO (#492)

## Summary

In the trl [implementation](https://github.com/huggingface/trl/blob/8c49ea39ec2e11ce4c61291ff2ad59edb46522aa/trl/trainer/cpo_trainer.py#L669C1-L672C56), CPO returns two extra values (`chosen_rewards`, `rejected_rewards`) alongside the loss, but these were not implemented in Liger-Kernel. This patch adds them as auxiliary outputs of the CPO and SimPO preference loss functions and updates the test reference implementation to match.

## Testing Done

- Hardware Type:
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Mecoli1219
Co-authored-by: Austin Liu
---
 src/liger_kernel/chunked_loss/cpo_loss.py   | 6 +++++-
 src/liger_kernel/chunked_loss/simpo_loss.py | 5 ++++-
 test/chunked_loss/test_cpo_loss.py          | 6 +++++-
 3 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py
index 41ec78a9d..aa7f078ba 100644
--- a/src/liger_kernel/chunked_loss/cpo_loss.py
+++ b/src/liger_kernel/chunked_loss/cpo_loss.py
@@ -33,7 +33,11 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1, labe
         loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
             full_target.shape[0] // 2
         )
-        return loss
+
+        chosen_rewards = beta * chosen_logps
+        rejected_rewards = beta * rejected_logps
+
+        return loss, chosen_rewards, rejected_rewards
 
     @staticmethod
     def forward(
diff --git a/src/liger_kernel/chunked_loss/simpo_loss.py b/src/liger_kernel/chunked_loss/simpo_loss.py
index 975bcefab..7f7c75bc1 100644
--- a/src/liger_kernel/chunked_loss/simpo_loss.py
+++ b/src/liger_kernel/chunked_loss/simpo_loss.py
@@ -42,7 +42,10 @@ def preference_loss_fn(
             full_target.shape[0] // 2
         )
 
-        return loss
+        chosen_rewards = beta * chosen_logps
+        rejected_rewards = beta * rejected_logps
+
+        return loss, chosen_rewards, rejected_rewards
 
     @staticmethod
     def forward(
diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py
index 5f5e15ad1..a59e3db19 100644
--- a/test/chunked_loss/test_cpo_loss.py
+++ b/test/chunked_loss/test_cpo_loss.py
@@ -73,7 +73,11 @@ def alignment_loss(
             )
         else:
             raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid']")
-        return losses
+
+        chosen_rewards = self.beta * policy_chosen_logps
+        rejected_rewards = self.beta * policy_rejected_logps
+
+        return losses, chosen_rewards, rejected_rewards
 
 
 class TorchLMHeadCPO(torch.nn.Module):
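
Note for reviewers: a minimal, self-contained sketch of the new return contract. The loss and reward formulas are taken verbatim from the diff above; the standalone helper `cpo_preference_loss`, the `logits` line, and the toy tensors are illustrative assumptions (the patched functions actually live inside the chunked-loss classes), not part of this patch:

```python
import torch
import torch.nn.functional as F


def cpo_preference_loss(chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0):
    # Label-smoothed sigmoid preference loss, as in cpo_loss.py.
    logits = beta * (chosen_logps - rejected_logps)
    loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
        full_target.shape[0] // 2
    )
    # New auxiliary outputs: the implicit rewards, matching trl's CPOTrainer.
    chosen_rewards = beta * chosen_logps
    rejected_rewards = beta * rejected_logps
    return loss, chosen_rewards, rejected_rewards


# Toy usage: the rewards can feed logging metrics such as reward accuracy and margin.
chosen_logps = torch.tensor([-1.0, -0.5])
rejected_logps = torch.tensor([-2.0, -1.5])
full_target = torch.empty(4)  # batch of 2 chosen + 2 rejected sequences concatenated
loss, chosen_rewards, rejected_rewards = cpo_preference_loss(chosen_logps, rejected_logps, full_target)
reward_accuracy = (chosen_rewards > rejected_rewards).float().mean()  # 1.0 here
reward_margin = (chosen_rewards - rejected_rewards).mean()            # 0.10 here
```

Returning the rewards directly from the loss function keeps parity with trl, where `chosen_rewards` and `rejected_rewards` are available for metrics without a second forward pass.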