Skip to content

Commit

Permalink
Add average_log_prob args for cpo (#510)
Browse files Browse the repository at this point in the history
## Summary
<!--- This is a required section; please describe the main purpose of
this proposed code change. --->
The `trl` CPO implementation does not average the log-probs, whereas the
Liger kernel averages them when computing the loss. This causes a mismatch
when integrating the two.
<!---
## Details
This is an optional section; is there anything specific that reviewers
should be aware of?
--->

## Testing Done
<!--- This is a required section; please describe how this change was
tested. --->
- [x] Updating the unit test (still investigating why the unit test fails
locally)
<!-- 
Replace BLANK with your device type. For example, A100-80G-PCIe

Complete the following tasks before sending your PR, and replace `[ ]`
with
`[x]` to indicate you have done them. 
-->

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence

---------

Signed-off-by: Mecoli1219 <[email protected]>
Co-authored-by: Kashif Rasul <[email protected]>
Co-authored-by: Austin Liu <[email protected]>
  • Loading branch information
3 people authored Jan 8, 2025
1 parent f8bb86f commit 23e3772
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/liger_kernel/chunked_loss/cpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def forward(
beta=beta,
label_smoothing=label_smoothing,
compute_nll_loss=compute_nll_loss,
average_log_prob=False,
compiled=compiled,
)

Expand Down
17 changes: 14 additions & 3 deletions src/liger_kernel/chunked_loss/fused_linear_preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def forward(
ref_input=None,
ref_weight=None,
ref_bias=None,
average_log_prob=True,
**loss_kwargs,
):
"""
Expand Down Expand Up @@ -61,6 +62,7 @@ def forward(
use_ref_model (bool): Whether to use a reference model for the alignment loss.
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
average_log_prob (bool): Whether to average log probabilities or to sum them over the completion.
loss_kwargs (dict): Other possible arguments that a loss function might need
"""
# TODO: Tune CHUNK_SIZE to fully utilize the GPU
Expand Down Expand Up @@ -94,6 +96,7 @@ def forward(
use_ref_model=use_ref_model,
ref_weight=ref_weight,
ref_bias=ref_bias,
average_log_prob=average_log_prob,
**loss_kwargs,
)

Expand Down Expand Up @@ -265,6 +268,7 @@ def chunk_forward(
bias=None,
ignore_index=-100,
compute_nll_loss=True,
average_log_prob=True,
):
len_chosen_chunk = target_chunk.shape[0] // 2
logits_chunk = input_chunk @ weight.t()
Expand All @@ -285,10 +289,13 @@ def chunk_forward(
label_chunk = torch.where(loss_mask, target_chunk, 0)

per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
if average_log_prob:
log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
else:
log_prob = (per_token_logps * loss_mask).sum(-1)

chosen_logps = average_log_prob[:len_chosen_chunk]
rejected_logps = average_log_prob[len_chosen_chunk:]
chosen_logps = log_prob[:len_chosen_chunk]
rejected_logps = log_prob[len_chosen_chunk:]

chosen_logits = logits_chunk[:len_chosen_chunk]
rejected_logits = logits_chunk[len_chosen_chunk:]
Expand Down Expand Up @@ -317,6 +324,7 @@ def _compute_loss(
ref_input_chunk=None,
ref_weight=None,
ref_bias=None,
average_log_prob=True,
**loss_kwargs,
):
"""
Expand All @@ -335,6 +343,7 @@ def _compute_loss(
use_ref_model (bool): Whether to use a reference model for the alignment loss.
ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
average_log_prob (bool): Whether to average log probabilities or the sum.
loss_kwargs (dict): Additional arguments for the loss function.
"""
(
Expand All @@ -350,6 +359,7 @@ def _compute_loss(
bias=bias,
ignore_index=ignore_index,
compute_nll_loss=compute_nll_loss,
average_log_prob=average_log_prob,
)
chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
Expand All @@ -372,6 +382,7 @@ def _compute_loss(
ref_bias,
ignore_index=ignore_index,
compute_nll_loss=False, # We don't need NLL loss for the reference model
average_log_prob=average_log_prob,
)
loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
Expand Down
5 changes: 3 additions & 2 deletions test/chunked_loss/test_cpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,9 +103,10 @@ def __init__(
label_smoothing=label_smoothing,
simpo_gamma=simpo_gamma,
).get_batch_loss_metrics
self.average_log_prob = loss_type == "simpo"

def forward(self, x, y):
return self.cpo_loss(self.lin.weight, x, y, self.lin.bias)
return self.cpo_loss(self.lin.weight, x, y, self.lin.bias, average_log_prob=self.average_log_prob)


class LigerLMHeadCPO(torch.nn.Module):
Expand Down Expand Up @@ -143,7 +144,7 @@ def forward(self, x, y):
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
(1.0, torch.bfloat16, 5e-3, 5e-3),
(1.0, torch.bfloat16, 5e-2, 5e-2),
(1.0, torch.float32, 1e-5, 5e-4),
],
)
Expand Down

0 comments on commit 23e3772

Please sign in to comment.