diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py
index aa7f078ba..5a83ab026 100644
--- a/src/liger_kernel/chunked_loss/cpo_loss.py
+++ b/src/liger_kernel/chunked_loss/cpo_loss.py
@@ -65,6 +65,7 @@ def forward(
             beta=beta,
             label_smoothing=label_smoothing,
             compute_nll_loss=compute_nll_loss,
+            average_log_prob=False,
             compiled=compiled,
         )
 
diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py
index f389050c0..31d95480c 100644
--- a/src/liger_kernel/chunked_loss/fused_linear_preference.py
+++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -32,6 +32,7 @@ def forward(
         ref_input=None,
         ref_weight=None,
         ref_bias=None,
+        average_log_prob=True,
         **loss_kwargs,
     ):
         """
@@ -61,6 +62,7 @@ def forward(
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            average_log_prob (bool): Whether to average log probabilities or to sum them over the completion.
             loss_kwargs (dict): Other possible arguments that a loss function might need
         """
         # TODO: Tune CHUNK_SIZE to fully utilize the GPU
@@ -94,6 +96,7 @@ def forward(
             use_ref_model=use_ref_model,
             ref_weight=ref_weight,
             ref_bias=ref_bias,
+            average_log_prob=average_log_prob,
             **loss_kwargs,
         )
 
@@ -265,6 +268,7 @@ def chunk_forward(
         bias=None,
         ignore_index=-100,
         compute_nll_loss=True,
+        average_log_prob=True,
     ):
         len_chosen_chunk = target_chunk.shape[0] // 2
         logits_chunk = input_chunk @ weight.t()
@@ -285,10 +289,13 @@ def chunk_forward(
         label_chunk = torch.where(loss_mask, target_chunk, 0)
 
         per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1)
-        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+        if average_log_prob:
+            log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+        else:
+            log_prob = (per_token_logps * loss_mask).sum(-1)
 
-        chosen_logps = average_log_prob[:len_chosen_chunk]
-        rejected_logps = average_log_prob[len_chosen_chunk:]
+        chosen_logps = log_prob[:len_chosen_chunk]
+        rejected_logps = log_prob[len_chosen_chunk:]
 
         chosen_logits = logits_chunk[:len_chosen_chunk]
         rejected_logits = logits_chunk[len_chosen_chunk:]
@@ -317,6 +324,7 @@ def _compute_loss(
         ref_input_chunk=None,
         ref_weight=None,
         ref_bias=None,
+        average_log_prob=True,
         **loss_kwargs,
     ):
         """
@@ -335,6 +343,7 @@ def _compute_loss(
             use_ref_model (bool): Whether to use a reference model for the alignment loss.
             ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size).
             ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,).
+            average_log_prob (bool): Whether to average log probabilities or to sum them over the completion.
             loss_kwargs (dict): Additional arguments for the loss function.
         """
         (
@@ -350,6 +359,7 @@ def _compute_loss(
             bias=bias,
             ignore_index=ignore_index,
             compute_nll_loss=compute_nll_loss,
+            average_log_prob=average_log_prob,
         )
         chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
         chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0])
@@ -372,6 +382,7 @@ def _compute_loss(
             ref_bias,
             ignore_index=ignore_index,
             compute_nll_loss=False,  # We don't need NLL loss for the reference model
+            average_log_prob=average_log_prob,
         )
         loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
         loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py
index a59e3db19..c996e57f9 100644
--- a/test/chunked_loss/test_cpo_loss.py
+++ b/test/chunked_loss/test_cpo_loss.py
@@ -103,9 +103,10 @@ def __init__(
             label_smoothing=label_smoothing,
             simpo_gamma=simpo_gamma,
         ).get_batch_loss_metrics
+        self.average_log_prob = loss_type == "simpo"
 
     def forward(self, x, y):
-        return self.cpo_loss(self.lin.weight, x, y, self.lin.bias)
+        return self.cpo_loss(self.lin.weight, x, y, self.lin.bias, average_log_prob=self.average_log_prob)
 
 
 class LigerLMHeadCPO(torch.nn.Module):
@@ -143,7 +144,7 @@ def forward(self, x, y):
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
-        (1.0, torch.bfloat16, 5e-3, 5e-3),
+        (1.0, torch.bfloat16, 5e-2, 5e-2),
         (1.0, torch.float32, 1e-5, 5e-4),
     ],
 )