# Add aux_outputs for CPO and SimPO (#492)
## Summary
In the trl
[implementation](https://github.com/huggingface/trl/blob/8c49ea39ec2e11ce4c61291ff2ad59edb46522aa/trl/trainer/cpo_trainer.py#L669C1-L672C56),
CPO returns two extra values (`chosen_rewards`, `rejected_rewards`)
alongside the loss, but this is not yet implemented in Liger-Kernel.
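
For reference, a minimal standalone sketch of the CPO sigmoid preference loss with the two extra reward values this change adds. The function name and the `logits` line are illustrative assumptions; the loss expression and the reward definitions mirror the diff below.

```python
import torch.nn.functional as F


def cpo_preference_loss(chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0):
    # Sigmoid preference loss over the chosen/rejected log-prob margin.
    logits = beta * (chosen_logps - rejected_logps)
    loss = (
        -F.logsigmoid(logits) * (1 - label_smoothing)
        - F.logsigmoid(-logits) * label_smoothing
    ).sum() / (full_target.shape[0] // 2)

    # New aux outputs: implicit rewards, matching the trl reference linked above.
    chosen_rewards = beta * chosen_logps
    rejected_rewards = beta * rejected_logps

    return loss, chosen_rewards, rejected_rewards
```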

## Testing Done

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Mecoli1219 <[email protected]>
Co-authored-by: Austin Liu <[email protected]>
Mecoli1219 and austin362667 authored Jan 8, 2025
1 parent 5c5a7b4 commit f8bb86f
Showing 3 changed files with 14 additions and 3 deletions.
6 changes: 5 additions & 1 deletion src/liger_kernel/chunked_loss/cpo_loss.py
```diff
@@ -33,7 +33,11 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1, labe
         loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
             full_target.shape[0] // 2
         )
-        return loss
+
+        chosen_rewards = beta * chosen_logps
+        rejected_rewards = beta * rejected_logps
+
+        return loss, chosen_rewards, rejected_rewards
 
     @staticmethod
     def forward(
```
5 changes: 4 additions & 1 deletion src/liger_kernel/chunked_loss/simpo_loss.py
```diff
@@ -42,7 +42,10 @@ def preference_loss_fn(
             full_target.shape[0] // 2
         )
 
-        return loss
+        chosen_rewards = beta * chosen_logps
+        rejected_rewards = beta * rejected_logps
+
+        return loss, chosen_rewards, rejected_rewards
 
     @staticmethod
     def forward(
```
6 changes: 5 additions & 1 deletion test/chunked_loss/test_cpo_loss.py
```diff
@@ -73,7 +73,11 @@ def alignment_loss(
             )
         else:
             raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid']")
-        return losses
+
+        chosen_rewards = self.beta * policy_chosen_logps
+        rejected_rewards = self.beta * policy_rejected_logps
+
+        return losses, chosen_rewards, rejected_rewards
 
 
 class TorchLMHeadCPO(torch.nn.Module):
```

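A quick usage example of the sketch above on dummy tensors, showing the three values callers now receive. The shapes are assumptions for illustration, not the shapes Liger-Kernel's chunked loss actually uses.

```python
import torch

chosen_logps = torch.randn(4)      # per-example chosen log-probs
rejected_logps = torch.randn(4)    # per-example rejected log-probs
full_target = torch.zeros(8, 16)   # only shape[0] (chosen + rejected count) is used here

loss, chosen_rewards, rejected_rewards = cpo_preference_loss(
    chosen_logps, rejected_logps, full_target
)
print(loss.item(), chosen_rewards.shape, rejected_rewards.shape)
# -> scalar loss, torch.Size([4]), torch.Size([4])
```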