diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py
index e9cc6ee24..f157ffd51 100644
--- a/src/liger_kernel/chunked_loss/fused_linear_distillation.py
+++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -135,6 +135,7 @@ def forward(
         ignore_index=-100,
         weight_hard_loss=0.5,
         weight_soft_loss=0.5,
+        beta=0.5,
         compute_ce_loss=True,
         temperature=1.0,
         compiled=True,
@@ -157,6 +158,7 @@ def forward(
             ignore_index (int): Index to ignore for loss computation.
             weight_hard_loss (float): Weight for hard/task loss.
             weight_soft_loss (float): Weight for soft/distillation loss.
+            beta (float): Interpolation coefficient between 0 and 1 (default: 0.5).
             compute_ce_loss (bool): Whether to compute CE loss.
             temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
             compiled (bool): Whether to use torch compile for chunk accumulation.
@@ -175,6 +177,7 @@ def forward(
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
+            beta=beta,
             compute_ce_loss=compute_ce_loss,
             temperature=temperature,
             **loss_kwargs,
diff --git a/src/liger_kernel/chunked_loss/jsd_loss.py b/src/liger_kernel/chunked_loss/jsd_loss.py
index e7c2d9dbd..92cbb4777 100644
--- a/src/liger_kernel/chunked_loss/jsd_loss.py
+++ b/src/liger_kernel/chunked_loss/jsd_loss.py
@@ -40,6 +40,7 @@ def forward(
         true_labels: torch.LongTensor,
         weight_hard_loss: float = 0.5,
         weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
         ignore_index: int = -100,
         temperature: float = 1.0,
         compiled: bool = True,
@@ -54,6 +55,7 @@ def forward(
             true_labels (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
             weight_hard_loss (float): Weight for hard loss.
             weight_soft_loss (float): Weight for soft loss.
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
             ignore_index (int): Index to ignore in loss computation
             temperature (float): Temperature for softening/sharpening distributions
             compiled (bool): Whether to use torch compile
@@ -71,6 +73,7 @@ def forward(
             chunk_size=1,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
+            beta=beta,
             ignore_index=ignore_index,
             temperature=temperature,
             compiled=compiled,
@@ -80,7 +83,7 @@ def forward(
     def backward(ctx, grad_output):
         grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:4]
 
-        return (*grads, None, None, None, None, None, None)
+        return (*grads, None, None, None, None, None, None, None)
 
 
 class LigerFusedLinearJSDLoss(torch.nn.Module):
@@ -92,6 +95,7 @@ def __init__(
         self,
         weight_hard_loss: float = 0.5,
         weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
         ignore_index: int = -100,
         temperature: float = 1.0,
         compiled: bool = True,
@@ -103,6 +107,7 @@ def __init__(
             ignore_index (int): Index to ignore in the loss
             temperature (float): Temperature for softening distributions
             compiled (bool): Whether to use torch compile
+            beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
         """
         super().__init__()
         assert temperature != 0, "Temperature cannot be 0."
@@ -111,6 +116,7 @@ def __init__(
         self.ignore_index = ignore_index
         self.temperature = temperature
         self.compiled = compiled
+        self.beta = beta
 
     def forward(
         self,
@@ -141,6 +147,7 @@ def forward(
             true_labels,
             self.weight_hard_loss,
             self.weight_soft_loss,
+            self.beta,
             self.ignore_index,
             self.temperature,
             self.compiled,
diff --git a/test/chunked_loss/test_jsd_loss.py b/test/chunked_loss/test_jsd_loss.py
index 92cd44c8a..5d51cd5b4 100644
--- a/test/chunked_loss/test_jsd_loss.py
+++ b/test/chunked_loss/test_jsd_loss.py
@@ -27,6 +27,7 @@ def __init__(
         ignore_index: int = -100,
         weight_hard_loss: float = 0.5,
         weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
     ):
         super().__init__(
             ignore_index=ignore_index,
@@ -34,6 +35,7 @@ def __init__(
             weight_soft_loss=weight_soft_loss,
             temperature=temperature,
         )
+        self.beta = beta
 
     def distillation_loss(self, student_logits, teacher_logits, beta=0.5):
         """
@@ -77,6 +79,7 @@ def __init__(
         device: torch.device,
         weight_hard_loss: float = 0.5,
         weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
         ignore_index: int = -100,
         temperature: float = 1.0,
     ):
@@ -89,6 +92,7 @@ def __init__(
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
             temperature=temperature,
+            beta=beta,
         ).get_batch_loss_metrics
 
     def forward(self, student_input, teacher_input, target):
@@ -111,6 +115,7 @@ def __init__(
         device: torch.device,
         weight_hard_loss: float = 0.5,
         weight_soft_loss: float = 0.5,
+        beta: float = 0.5,
         ignore_index: int = -100,
         temperature: float = 1.0,
     ):
@@ -155,13 +160,13 @@ def forward(self, student_input, teacher_input, target):
     ],
 )
 @pytest.mark.parametrize(
-    "temperature, weight_hard_loss, weight_soft_loss",
+    "temperature, weight_hard_loss, weight_soft_loss, beta",
     [
-        (1.0, 0.5, 0.5),
-        (2.0, 0.5, 0.5),
-        (0.5, 0.5, 0.5),
-        (1.0, 0.0, 1.0),
-        (1.0, 1.0, 0.0),
+        (1.0, 0.5, 0.5, 0.5),
+        (2.0, 0.5, 0.5, 0.5),
+        (0.5, 0.5, 0.5, 0.5),
+        (1.0, 0.0, 1.0, 0.5),
+        (1.0, 1.0, 0.0, 0.5),
     ],
 )
 def test_correctness(
@@ -176,6 +181,7 @@ def test_correctness(
     temperature,
     weight_hard_loss,
     weight_soft_loss,
+    beta,
 ):
     torch_lm_head_jsd = TorchLMHeadJSD(
         H=H,
@@ -185,6 +191,7 @@ def test_correctness(
         temperature=temperature,
         weight_hard_loss=weight_hard_loss,
         weight_soft_loss=weight_soft_loss,
+        beta=beta,
     )
     liger_lm_head_jsd = LigerLMHeadJSD(
         H=H,
@@ -194,6 +201,7 @@ def test_correctness(
         temperature=temperature,
         weight_hard_loss=weight_hard_loss,
         weight_soft_loss=weight_soft_loss,
+        beta=beta,
     )
 
     torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand(
@@ -243,8 +251,8 @@ def test_correctness(
     ],
 )
 @pytest.mark.parametrize(
-    "temperature, weight_hard_loss, weight_soft_loss, ignore_index",
-    [(1.0, 0.5, 0.5, -100), (2.0, 0.1, 0.9, 42)],
+    "temperature, weight_hard_loss, weight_soft_loss, beta, ignore_index",
+    [(1.0, 0.5, 0.5, 0.5, -100), (2.0, 0.1, 0.9, 0.5, 42)],
 )
 def test_correctness_functional(
     B,
@@ -255,6 +263,7 @@ def test_correctness_functional(
     dtype,
     weight_hard_loss,
     weight_soft_loss,
+    beta,
     ignore_index,
     temperature,
     atol,
@@ -280,6 +289,7 @@ def test_correctness_functional(
         label,
         weight_hard_loss,
         weight_soft_loss,
+        beta,
         ignore_index,
         temperature,
     )
@@ -291,6 +301,7 @@ def test_correctness_functional(
         label,
         weight_hard_loss,
         weight_soft_loss,
+        beta,
         ignore_index,
         temperature,
     )
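Reviewer note: to make the new knob concrete, here is a minimal non-fused sketch of the quantity `beta` controls. It assumes the DistiLLM-style convention for the generalized JSD, `JSD_beta(P || Q) = beta * KL(P || M) + (1 - beta) * KL(Q || M)` with `M = beta * P + (1 - beta) * Q`, where `P` is the student distribution and `Q` the teacher's; under that convention `beta = 0.5` recovers the symmetric JSD this loss computed before the change, which is why `0.5` is the default everywhere above. The authoritative weighting is whatever the reference `distillation_loss` in `test/chunked_loss/test_jsd_loss.py` implements, so treat this as an illustration rather than a spec.

```python
import torch
import torch.nn.functional as F


def generalized_jsd_reference(student_logits, teacher_logits, beta=0.5):
    # P = softmax(student_logits), Q = softmax(teacher_logits).
    log_p = F.log_softmax(student_logits, dim=-1)
    log_q = F.log_softmax(teacher_logits, dim=-1)
    # log M = log(beta * P + (1 - beta) * Q), formed in log-space for stability.
    log_m = torch.logsumexp(
        torch.stack(
            [
                log_p + torch.log(torch.tensor(beta)),
                log_q + torch.log(torch.tensor(1.0 - beta)),
            ]
        ),
        dim=0,
    )
    # F.kl_div(input, target, log_target=True) computes KL(target || input).
    kl_p_m = F.kl_div(log_m, log_p, reduction="batchmean", log_target=True)
    kl_q_m = F.kl_div(log_m, log_q, reduction="batchmean", log_target=True)
    # Under this convention the endpoints degenerate (JSD_0 = JSD_1 = 0), which
    # is why implementations often special-case beta == 0 / beta == 1 as pure KL.
    return beta * kl_p_m + (1.0 - beta) * kl_q_m
```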
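And a quick smoke test showing the new argument flowing through the module API. Shapes are made up, and the positional order (student input/weight, teacher input/weight, labels) is inferred from the `LigerLMHeadJSD` wrapper in the test file, so double-check it against `LigerFusedLinearJSDLoss.forward`:

```python
import torch

from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss

BT, H, V = 8, 16, 32  # (batch * seq_len, hidden, vocab) -- illustrative only

loss_fn = LigerFusedLinearJSDLoss(
    weight_hard_loss=0.5,
    weight_soft_loss=0.5,
    beta=0.3,  # anything other than the 0.5 default exercises the new path
    temperature=2.0,
)

student_input = torch.randn(BT, H, requires_grad=True)
student_weight = torch.randn(V, H, requires_grad=True)
teacher_input = torch.randn(BT, H)
teacher_weight = torch.randn(V, H)
labels = torch.randint(0, V, (BT,))

loss = loss_fn(student_input, student_weight, teacher_input, teacher_weight, labels)
loss.backward()  # gradients flow to the student tensors only
```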