diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py
index 2b8052e25..dd84a4dbf 100644
--- a/src/liger_kernel/chunked_loss/cpo_loss.py
+++ b/src/liger_kernel/chunked_loss/cpo_loss.py
@@ -9,7 +9,9 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):

     @staticmethod
-    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
+    def preference_loss_fn(
+        chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0
+    ):
         """
         Paper: https://arxiv.org/pdf/2401.08417

@@ -30,9 +32,14 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
             rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
             full_target (torch.Tensor): Non chunked full target tensor
             beta (float): Weight for the CPO loss
+            label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
         """
         logits = beta * (chosen_logps - rejected_logps)
-        loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
+        loss = (
+            - F.logsigmoid(logits) * (1 - label_smoothing)
+            - F.logsigmoid(-logits) * label_smoothing
+        ).sum() / (full_target.shape[0] // 2)
+
         return loss

     @staticmethod
@@ -45,6 +52,7 @@ def forward(
         ignore_index=-100,
         beta=0.1,
         alpha=1.0,
+        label_smoothing=0.0,
         compute_nll_loss=True,
         compiled=True,
     ):
@@ -58,6 +66,7 @@ def forward(
             ignore_index=ignore_index,
             alpha=alpha,
             beta=beta,
+            label_smoothing=label_smoothing,
             compute_nll_loss=compute_nll_loss,
             compiled=compiled,
         )
@@ -65,7 +74,7 @@ def forward(
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None
+        return *grads, None, None, None, None, None, None


 class LigerFusedLinearCPOLoss(torch.nn.Module):
@@ -78,6 +87,7 @@ def __init__(
         ignore_index: int = -100,
         beta: float = 0.1,
         alpha: float = 1.0,
+        label_smoothing: float = 0.0,
         compute_nll_loss: bool = True,
         compiled: bool = True,
     ):
@@ -90,6 +100,7 @@ def __init__(
         self.ignore_index = ignore_index
         self.beta = beta
         self.alpha = alpha
+        self.label_smoothing = label_smoothing
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled

@@ -102,6 +113,7 @@ def forward(self, lin_weight, _input, target, bias=None):
             self.ignore_index,
             self.beta,
             self.alpha,
+            self.label_smoothing,
             self.compute_nll_loss,
             self.compiled,
         )
diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py
index 5f1b17cf5..cf07e186e 100644
--- a/src/liger_kernel/chunked_loss/dpo_loss.py
+++ b/src/liger_kernel/chunked_loss/dpo_loss.py
@@ -64,7 +64,7 @@ def forward(
         ref_bias=None,
         ignore_index=-100,
         beta=0.1,
-        compute_nll_loss=True,
+        compute_nll_loss=False,
         compiled=True,
         use_ref_model=True,
     ):
@@ -100,7 +100,7 @@ def __init__(
         self,
         ignore_index: int = -100,
         beta: float = 0.1,
-        compute_nll_loss: bool = True,
+        compute_nll_loss: bool = False,
         compiled: bool = True,
         use_ref_model: bool = False,
     ):
diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py
index fff0791ec..4eb939a79 100644
--- a/src/liger_kernel/chunked_loss/fused_linear_preference.py
+++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -408,7 +408,7 @@ def _compute_loss(
         else:
             preference_loss, aux_outputs = preference_loss_outputs, []

-        loss = alpha * chosen_nll_loss - preference_loss
+        loss = alpha * chosen_nll_loss + preference_loss
         return_vars = (
             chosen_logps,
             rejected_logps,
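For reference, a standalone sketch (not part of the patch; the helper name and tensors are illustrative) of the per-pair term the patched CPO preference_loss_fn now computes. With label_smoothing=0 it reduces to -logsigmoid(beta * (chosen - rejected)), i.e. Equation 3 of the CPO paper, and per the fused_linear_preference.py change above the base class now adds this non-negative term to alpha * nll_loss instead of subtracting a log-sigmoid term:

import torch
import torch.nn.functional as F

def smoothed_cpo_term(chosen_logps, rejected_logps, beta=0.1, label_smoothing=0.0):
    # Same expression as the patched preference_loss_fn, kept per-pair here;
    # the kernel sums this and divides by the number of pairs in the batch.
    logits = beta * (chosen_logps - rejected_logps)
    return (
        -F.logsigmoid(logits) * (1 - label_smoothing)
        - F.logsigmoid(-logits) * label_smoothing
    )

# e.g. smoothed_cpo_term(torch.tensor([-0.8]), torch.tensor([-1.5]), label_smoothing=0.1)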
diff --git a/src/liger_kernel/chunked_loss/orpo_loss.py b/src/liger_kernel/chunked_loss/orpo_loss.py
index c860d4bd9..d615212c5 100644
--- a/src/liger_kernel/chunked_loss/orpo_loss.py
+++ b/src/liger_kernel/chunked_loss/orpo_loss.py
@@ -36,7 +36,7 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
             - torch.log1p(-torch.exp(rejected_logps))
         )
         ratio = F.logsigmoid(log_odds)
-        loss = beta * ratio.sum() / (full_target.shape[0] // 2)
+        loss = -beta * ratio.sum() / (full_target.shape[0] // 2)

         chosen_rewards = beta * chosen_logps
         rejected_rewards = beta * rejected_logps
diff --git a/src/liger_kernel/chunked_loss/simpo_loss.py b/src/liger_kernel/chunked_loss/simpo_loss.py
index 7efa0603d..5d5867252 100644
--- a/src/liger_kernel/chunked_loss/simpo_loss.py
+++ b/src/liger_kernel/chunked_loss/simpo_loss.py
@@ -10,7 +10,12 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):

     @staticmethod
     def preference_loss_fn(
-        chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
+        chosen_logps,
+        rejected_logps,
+        full_target,
+        beta=0.1,
+        gamma=0.5,
+        label_smoothing=0.0,
     ):
         """
         Paper: https://arxiv.org/pdf/2405.14734
@@ -33,9 +38,14 @@ def preference_loss_fn(
             full_target: Non chunked full target tensor
             beta (float): beta weight
             gamma (float): gemma margin term
+            label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
         """
         logits = beta * (chosen_logps - rejected_logps) - gamma
-        loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
+        loss = (
+            - F.logsigmoid(logits) * (1 - label_smoothing)
+            - F.logsigmoid(-logits) * label_smoothing
+        ).sum() / (full_target.shape[0] // 2)
+
         return loss

     @staticmethod
@@ -48,6 +58,7 @@ def forward(
         ignore_index=-100,
         beta=0.1,
         alpha=1.0,
+        label_smoothing=0.0,
         compute_nll_loss=False,
         compiled=True,
         gamma=0.5,
@@ -63,6 +74,7 @@ def forward(
             ignore_index=ignore_index,
             alpha=alpha,
             beta=beta,
+            label_smoothing=label_smoothing,
             compiled=compiled,
             gamma=gamma,
         )
@@ -70,7 +82,7 @@ def forward(
     @staticmethod
     def backward(ctx, *grad_output):
         grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4]
-        return *grads, None, None, None, None, None, None
+        return *grads, None, None, None, None, None, None, None


 class LigerFusedLinearSimPOLoss(torch.nn.Module):
@@ -83,6 +95,7 @@ def __init__(
         ignore_index: int = -100,
         beta: float = 0.1,
         alpha: float = 1.0,
+        label_smoothing: float = 0.0,
         compute_nll_loss: bool = True,
         compiled: bool = True,
         gamma: float = 0.5,
@@ -96,6 +109,7 @@ def __init__(
         self.ignore_index = ignore_index
         self.beta = beta
         self.alpha = alpha
+        self.label_smoothing = label_smoothing
         self.compute_nll_loss = compute_nll_loss
         self.compiled = compiled
         self.gamma = gamma
@@ -109,6 +123,7 @@ def forward(self, lin_weight, _input, target, bias=None):
             self.ignore_index,
             self.beta,
             self.alpha,
+            self.label_smoothing,
             self.compute_nll_loss,
             self.compiled,
             self.gamma,
diff --git a/src/liger_kernel/transformers/model/mixtral.py b/src/liger_kernel/transformers/model/mixtral.py
index 22fea53da..145bc78cd 100644
--- a/src/liger_kernel/transformers/model/mixtral.py
+++ b/src/liger_kernel/transformers/model/mixtral.py
@@ -38,7 +38,7 @@ def lce_forward_deprecated(
     cache_position: Optional[torch.LongTensor] = None,
 ) -> Union[Tuple, MoeCausalLMOutputWithPast]:
     r"""
-    Copy paste Mixtral's forward from transfomers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy
+    Copy paste Mixtral's forward from transformers v4.44.2 but replace torch cross entropy with liger fused linear cross entropy


     Args:
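The ORPO and SimPO hunks above flip the same sign convention as the base class: preference_loss_fn now returns a non-negative loss that _compute_loss adds, instead of a log-sigmoid quantity that it subtracted, so with label_smoothing=0 the total loss is numerically unchanged. A small check with toy values:

import torch
import torch.nn.functional as F

log_odds = torch.tensor([0.3, -1.2])
ratio = F.logsigmoid(log_odds)
beta, n_pairs = 0.1, 2

old_term = beta * ratio.sum() / n_pairs    # old convention: <= 0, subtracted by the base class
new_term = -beta * ratio.sum() / n_pairs   # new convention: >= 0, added by the base class
assert torch.isclose(-old_term, new_term)  # same contribution to the total loss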
diff --git a/src/liger_kernel/transformers/trainer/orpo_trainer.py b/src/liger_kernel/transformers/trainer/orpo_trainer.py
index 3605b9f1b..04391fa5f 100644
--- a/src/liger_kernel/transformers/trainer/orpo_trainer.py
+++ b/src/liger_kernel/transformers/trainer/orpo_trainer.py
@@ -17,7 +17,7 @@ class _FSDPForwardRedirection:
     This is needed in cases where we call a submodule of a FSDP module. For instance, when we want to call only
     the `LlamaModel` part out of a FSDP-wrapped `LlamaForCausalLM` to get the hidden states without involving
     GPU-memory-heavy `lm_head` and cross entropy computation, doing this directly (i.e. `model.model.forward()`)
-    will not work because the first `nn.Emebedding` layer is not independently wrapped as a FSDP module (because of
+    will not work because the first `nn.Embedding` layer is not independently wrapped as a FSDP module (because of
     the transformer-based wrapping policy), and not calling it through FSDP root module forward will not all-gather
     its parameter, thus resulting in "RuntimeError: 'weight' must be 2-D" error. Similarly, if we want to call just
     the `lm_head` part of a model, we need this trick too to properly get its params all-gathered.
@@ -125,6 +125,10 @@ def orpo_partial(lm_head, last_hidden_state, concatenated_labels):
             outputs.last_hidden_state,
             concatenated_batch["concatenated_labels"],
         )
+        # if aux_loss_enabled, add the aux_loss to the orpo_loss
+        if self.aux_loss_enabled:
+            orpo_loss += self.aux_loss_coef * outputs.aux_loss
+
         return orpo_loss, aux_outputs

     def get_batch_loss_metrics(
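The trainer change folds a MoE router auxiliary loss into the ORPO objective whenever the wrapped model reports one. The composition is simply the following (coefficient and tensor values are made-up illustrations; the attribute names are the ones used in the hunk above):

import torch

orpo_loss = torch.tensor(1.25)   # alignment (+ optional NLL) loss from the fused kernel
aux_loss = torch.tensor(0.08)    # router load-balancing loss returned by the MoE model
aux_loss_coef = 0.001            # self.aux_loss_coef configured on the trainer
total_loss = orpo_loss + aux_loss_coef * aux_loss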
diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py
index f0fef7734..4090db795 100644
--- a/test/chunked_loss/test_cpo_loss.py
+++ b/test/chunked_loss/test_cpo_loss.py
@@ -60,14 +60,14 @@ def alignment_loss(
         if self.loss_type == "sigmoid":
             # This reduces to Equation 3 from the CPO paper when label_smoothing -> 0.
             losses = (
-                F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
-                + F.logsigmoid(-self.beta * logits) * self.label_smoothing
+                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
             )
         elif self.loss_type == "simpo":
             logits = logits - (self.simpo_gamma / self.beta)
             losses = (
-                F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
-                + F.logsigmoid(-self.beta * logits) * self.label_smoothing
+                -F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
+                - F.logsigmoid(-self.beta * logits) * self.label_smoothing
             )
         else:
             raise ValueError(
@@ -86,6 +86,7 @@ def __init__(
         ignore_index: int = -100,
         beta: float = 0.1,
         alpha: float = 1.0,
+        label_smoothing: float = 0.0,
         loss_type: str = "sigmoid",
         simpo_gamma: float = 0.5,
     ):
@@ -97,6 +98,7 @@ def __init__(
             ignore_index=ignore_index,
             beta=beta,
             loss_type=loss_type,
+            label_smoothing=label_smoothing,
             simpo_gamma=simpo_gamma,
         ).get_batch_loss_metrics

@@ -114,13 +116,17 @@ def __init__(
         ignore_index: int = -100,
         beta: float = 0.1,
         alpha: float = 1.0,
+        label_smoothing: float = 0.0,
     ):
         super().__init__()
         self.lin = torch.nn.Linear(
             in_features=H, out_features=V, bias=bias, dtype=dtype
         )
         self.cpo_loss = LigerFusedLinearCPOLoss(
-            ignore_index=ignore_index, beta=beta, alpha=alpha
+            ignore_index=ignore_index,
+            beta=beta,
+            alpha=alpha,
+            label_smoothing=label_smoothing,
         )

     def forward(self, x, y):
@@ -145,8 +151,21 @@
 @pytest.mark.parametrize(
     "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)]
 )
+@pytest.mark.parametrize("label_smoothing", [0.0, 0.1])
 def test_correctness(
-    B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, alpha
+    B,
+    T,
+    H,
+    V,
+    scalar,
+    dtype,
+    atol,
+    rtol,
+    bias,
+    ignore_index,
+    beta,
+    alpha,
+    label_smoothing,
 ):
     B = 2 * B  # cpo loss requires B to be even

@@ -157,6 +176,7 @@
         bias=bias,
         ignore_index=ignore_index,
         beta=beta,
+        label_smoothing=label_smoothing,
     )
     liger_lm_head_cpo = LigerLMHeadCPO(
         H=H,
@@ -165,6 +185,7 @@
         bias=bias,
         ignore_index=ignore_index,
         beta=beta,
+        label_smoothing=label_smoothing,
     )

     torch_lm_head_cpo.lin.weight.data = liger_lm_head_cpo.lin.weight.data = torch.randn(
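The tests above exercise the new knob end to end. As a hedged usage sketch (shapes and the chosen/rejected batch layout mirror the test; the import path is taken from the file patched above, and the returned auxiliary outputs are left unpacked), the module is called with the lm_head weight, the hidden states, and the targets:

import torch
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss

B, T, H, V = 4, 16, 64, 128        # B must be even: first half chosen, second half rejected
lm_head_weight = torch.randn(V, H, requires_grad=True)
hidden_states = torch.randn(B, T, H, requires_grad=True)
target = torch.randint(0, V, (B, T))

cpo = LigerFusedLinearCPOLoss(beta=0.1, alpha=1.0, label_smoothing=0.1)
outputs = cpo(lm_head_weight, hidden_states, target)  # loss plus aggregated auxiliary outputs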
diff --git a/test/chunked_loss/test_dpo_loss.py b/test/chunked_loss/test_dpo_loss.py
index 0ac8faeb8..b73a69a57 100644
--- a/test/chunked_loss/test_dpo_loss.py
+++ b/test/chunked_loss/test_dpo_loss.py
@@ -23,10 +23,17 @@ class HFDPOLoss(HFAlignmentLoss):
     """

     def __init__(
-        self, ignore_index: int = -100, beta: float = 0.1, use_ref_model: bool = True
+        self,
+        ignore_index: int = -100,
+        beta: float = 0.1,
+        use_ref_model: bool = True,
+        compute_nll_loss: bool = False,
     ):
         super().__init__(
-            beta=beta, ignore_index=ignore_index, use_ref_model=use_ref_model
+            beta=beta,
+            ignore_index=ignore_index,
+            use_ref_model=use_ref_model,
+            compute_nll_loss=compute_nll_loss,
         )

     def alignment_loss(
@@ -61,6 +68,7 @@ def __init__(
         dtype: torch.dtype,
         bias: bool = False,
         ref_bias: bool = False,
+        compute_nll_loss: bool = False,
         ignore_index: int = -100,
         beta: float = 0.1,
     ):
@@ -72,7 +80,10 @@ def __init__(
             in_features=H, out_features=V, bias=ref_bias, dtype=dtype
         )
         self.dpo_loss = HFDPOLoss(
-            ignore_index=ignore_index, beta=beta, use_ref_model=True
+            ignore_index=ignore_index,
+            beta=beta,
+            use_ref_model=True,
+            compute_nll_loss=compute_nll_loss,
         ).get_batch_loss_metrics

     def forward(self, x, ref_x, y):
@@ -95,6 +106,7 @@ def __init__(
         dtype: torch.dtype,
         bias: bool = False,
         ref_bias: bool = False,
+        compute_nll_loss: bool = False,
         ignore_index: int = -100,
         beta: float = 0.1,
     ):
@@ -106,7 +118,10 @@ def __init__(
             in_features=H, out_features=V, bias=ref_bias, dtype=dtype
         )
         self.dpo_loss = LigerFusedLinearDPOLoss(
-            ignore_index=ignore_index, beta=beta, use_ref_model=True
+            ignore_index=ignore_index,
+            beta=beta,
+            use_ref_model=True,
+            compute_nll_loss=compute_nll_loss,
         )

     def forward(self, x, ref_x, y):
@@ -132,14 +147,27 @@
     "scalar, dtype, atol, rtol",
     [
         (1.0, torch.bfloat16, 5e-2, 5e-1),
-        (1.0, torch.float32, 2e-2, 5e-1),
+        (1.0, torch.float32, 1e-5, 5e-4),
     ],
 )
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("ref_bias", [True, False])
+@pytest.mark.parametrize("compute_nll_loss", [True, False])
 @pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)])
 def test_correctness(
-    B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, ignore_index, beta
+    B,
+    T,
+    H,
+    V,
+    scalar,
+    dtype,
+    atol,
+    rtol,
+    bias,
+    ref_bias,
+    compute_nll_loss,
+    ignore_index,
+    beta,
 ):
     B = 2 * B  # dpo loss requires B to be even

@@ -149,6 +177,7 @@
         dtype=dtype,
         bias=bias,
         ref_bias=ref_bias,
+        compute_nll_loss=compute_nll_loss,
         ignore_index=ignore_index,
         beta=beta,
     )
@@ -158,6 +187,7 @@
         dtype=dtype,
         bias=bias,
         ref_bias=ref_bias,
+        compute_nll_loss=compute_nll_loss,
         ignore_index=ignore_index,
         beta=beta,
     )
@@ -251,7 +281,10 @@
 )
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("ref_bias", [True, False])
-def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias):
+@pytest.mark.parametrize("compute_nll_loss", [True, False])
+def test_correctness_functional(
+    B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, compute_nll_loss
+):
     B = 2 * B

     _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
@@ -290,10 +323,28 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref
     ref_bias2 = _ref_bias.detach().clone().requires_grad_(True) if ref_bias else None

     loss1, aggregated_aux_outputs1 = LigerFusedLinearDPOFunction.apply(
-        input1, weight1, target, bias1, ref_input, ref_weight1, ref_bias1
+        input1,
+        weight1,
+        target,
+        bias1,
+        ref_input,
+        ref_weight1,
+        ref_bias1,
+        -100,
+        0.1,
+        compute_nll_loss,
     )
     loss2, aggregated_aux_outputs2 = liger_fused_linear_dpo(
-        input2, weight2, target, bias2, ref_input, ref_weight2, ref_bias2
+        input2,
+        weight2,
+        target,
+        bias2,
+        ref_input,
+        ref_weight2,
+        ref_bias2,
+        -100,
+        0.1,
+        compute_nll_loss,
     )

     assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol)
diff --git a/test/chunked_loss/test_orpo_loss.py b/test/chunked_loss/test_orpo_loss.py
index 9f5d81b18..112d4f05c 100644
--- a/test/chunked_loss/test_orpo_loss.py
+++ b/test/chunked_loss/test_orpo_loss.py
@@ -57,7 +57,7 @@ def alignment_loss(
             - torch.log1p(-torch.exp(policy_rejected_logps))
         )
         ratio = F.logsigmoid(log_odds)
-        losses = self.beta * ratio
+        losses = -self.beta * ratio

         chosen_rewards = self.beta * policy_chosen_logps
         rejected_rewards = self.beta * policy_rejected_logps
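On the DPO side the tests now thread compute_nll_loss through both the reference and the Liger module; note that dpo_loss.py above switches the library default to False, so the chosen-sequence NLL term becomes opt-in. A minimal sketch (import path taken from the patched file):

from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss

# New default: preference term only (compute_nll_loss=False).
dpo = LigerFusedLinearDPOLoss(beta=0.1, use_ref_model=True)

# Previous behaviour, now explicit: also add the chosen-sequence NLL term.
dpo_with_nll = LigerFusedLinearDPOLoss(beta=0.1, use_ref_model=True, compute_nll_loss=True)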
diff --git a/test/chunked_loss/test_simpo_loss.py b/test/chunked_loss/test_simpo_loss.py
index 3d0937c27..eede598fe 100644
--- a/test/chunked_loss/test_simpo_loss.py
+++ b/test/chunked_loss/test_simpo_loss.py
@@ -25,6 +25,7 @@ def __init__(
         ignore_index: int = -100,
         beta: float = 0.1,
         alpha: float = 1.0,
+        label_smoothing: float = 0.0,
         gamma: float = 0.5,
     ):
         super().__init__()
@@ -32,7 +33,11 @@ def __init__(
             in_features=H, out_features=V, bias=bias, dtype=dtype
         )
         self.simpo_loss = LigerFusedLinearSimPOLoss(
-            ignore_index=ignore_index, beta=beta, alpha=alpha, gamma=gamma
+            ignore_index=ignore_index,
+            beta=beta,
+            alpha=alpha,
+            gamma=gamma,
+            label_smoothing=label_smoothing,
         )

     def forward(self, x, y):
@@ -57,8 +62,21 @@
 @pytest.mark.parametrize(
     "ignore_index, beta, gamma", [(-100, 0.1, 0.5), (42, 0.2, 0.85)]
 )
+@pytest.mark.parametrize("label_smoothing", [0.0, 0.1])
 def test_correctness(
-    B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, beta, gamma
+    B,
+    T,
+    H,
+    V,
+    scalar,
+    dtype,
+    atol,
+    rtol,
+    bias,
+    ignore_index,
+    beta,
+    gamma,
+    label_smoothing,
 ):
     B = 2 * B  # SimPO loss requires B to be even

@@ -70,6 +88,7 @@
         ignore_index=ignore_index,
         beta=beta,
         loss_type="simpo",
+        label_smoothing=label_smoothing,
         simpo_gamma=gamma,
     )
     liger_lm_head_simpo = LigerLMHeadSimPO(
@@ -79,6 +98,7 @@
         bias=bias,
         ignore_index=ignore_index,
         beta=beta,
+        label_smoothing=label_smoothing,
         gamma=gamma,
     )

diff --git a/test/utils.py b/test/utils.py
index 3d3799ad0..3d08c4ae3 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -350,11 +350,13 @@ def __init__(
         beta: float = 0.1,
         ignore_index: int = -100,
         use_ref_model: bool = False,
+        compute_nll_loss: bool = True,
     ):
         self.alpha = alpha
         self.beta = beta
         self.ignore_index = ignore_index
         self.use_ref_model = use_ref_model
+        self.compute_nll_loss = compute_nll_loss

     @abstractmethod
     def alignment_loss(self):
@@ -448,9 +450,11 @@ def cross_entropy_loss(logits, labels):
             return loss

         labels = target
-        chosen_nll_loss = cross_entropy_loss(
-            all_logits[:len_chosen], labels[:len_chosen]
-        )
+        chosen_nll_loss = torch.tensor(0.0, device=all_logits.device)
+        if self.compute_nll_loss:
+            chosen_nll_loss = cross_entropy_loss(
+                all_logits[:len_chosen], labels[:len_chosen]
+            )

         all_logps = self.get_batch_logps(
             all_logits,
@@ -511,7 +515,7 @@ def get_batch_loss_metrics(
         else:
             losses, aggregated_aux_outputs = alignment_loss_outputs, []
         # full loss
-        loss = policy_nll_loss * self.alpha - losses.mean()
+        loss = policy_nll_loss * self.alpha + losses.mean()
         return_vars = (
             policy_chosen_logps,
             policy_rejected_logps,