diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py
index d8d3a5315..e9cc6ee24 100644
--- a/src/liger_kernel/chunked_loss/fused_linear_distillation.py
+++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -71,10 +71,11 @@ def _compute_loss(
         weight_hard_loss=0.5,
         weight_soft_loss=0.5,
         compute_ce_loss=True,
+        temperature=1,
         **loss_kwargs,
     ):
         """
-        Compute the total loss for a chunk of input and target, while using an knowleedge distillation loss function.
+        Compute the total loss for a chunk of input and target, while using a knowledge distillation loss function.
         Args:
             distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
             student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
@@ -84,11 +85,12 @@ def _compute_loss(
             target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
             student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
             teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
-            full_target (torch.Tensor): Full target tensor. Shape: (chunk_size,).
+            full_target (torch.Tensor): Full target tensor. Shape: (batch_size * sequence_length,).
             ignore_index (int): Index to ignore for loss computation.
             weight_hard_loss (float): Weight for hard loss.
             weight_soft_loss (float): Weight for soft loss.
             compute_ce_loss (bool): Whether to compute CE loss.
+            temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scaling).
             loss_kwargs (dict): Additional arguments for the loss function.
         """
         (
@@ -107,6 +109,9 @@ def _compute_loss(
             compute_ce_loss=compute_ce_loss,
         )
 
+        student_logits_chunk /= temperature
+        teacher_logits_chunk /= temperature
+
         hard_loss /= full_target.shape[0]
 
         soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
@@ -171,6 +176,7 @@ def forward(
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
             compute_ce_loss=compute_ce_loss,
+            temperature=temperature,
             **loss_kwargs,
         )
 
@@ -225,9 +231,6 @@ def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
         if compiled:
             accumulate_chunk = torch.compile(accumulate_chunk)
 
-        student_input /= temperature
-        teacher_input /= temperature
-
        num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
        _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
        _teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
diff --git a/src/liger_kernel/chunked_loss/jsd_loss.py b/src/liger_kernel/chunked_loss/jsd_loss.py
index d3500b713..931514a80 100644
--- a/src/liger_kernel/chunked_loss/jsd_loss.py
+++ b/src/liger_kernel/chunked_loss/jsd_loss.py
@@ -1,26 +1,20 @@
 import torch
 import torch.nn.functional as F
 
-from liger_kernel.chunked_loss.fused_linear_distillation import (
-    LigerFusedLinearDistillationBase,
-)
+from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase
 
 
 class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
     @staticmethod
-    def distillation_loss_fn(student_logits, teacher_logits, temperature):
+    def distillation_loss_fn(student_logits, teacher_logits):
         """
         Compute JSD loss (Jensen-Shannon Divergence Loss).
         Args:
-            student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size,).
-            teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size,).
-            temperature (float): Temperature for softening probability distributions
+            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
         Returns:
             torch.Tensor: Jensen-Shannon Divergence loss
         """
-        # Scale logits by temperature
-        student_logits = student_logits / temperature
-        teacher_logits = teacher_logits / temperature
         # Convert to probabilities
         student_probs = F.softmax(student_logits, dim=-1)
         teacher_probs = F.softmax(teacher_logits, dim=-1)
@@ -30,13 +24,13 @@ def distillation_loss_fn(student_logits, teacher_logits, temperature):
         student_kl = F.kl_div(
             log_mean_probs,
             torch.log(student_probs),
-            reduction="batchmean",
+            reduction="sum",
             log_target=True,
         )
         teacher_kl = F.kl_div(
             log_mean_probs,
             torch.log(teacher_probs),
-            reduction="batchmean",
+            reduction="sum",
             log_target=True,
         )
 
@@ -69,7 +63,7 @@ def forward(
             weight_hard_loss (float): Weight for hard loss.
             weight_soft_loss (float): Weight for soft loss.
             ignore_index (int): Index to ignore in loss computation
-            temperature (float): Temperature for softening distributions
+            temperature (float): Temperature for softening/sharpening distributions
             compiled (bool): Whether to use torch compile
         Returns:
             torch.Tensor: Computed loss
diff --git a/test/chunked_loss/test_jsd_loss.py b/test/chunked_loss/test_jsd_loss.py
index 4fd0cff35..96644a482 100644
--- a/test/chunked_loss/test_jsd_loss.py
+++ b/test/chunked_loss/test_jsd_loss.py
@@ -1,5 +1,3 @@
-from test.utils import HFDistillationLoss, assert_verbose_allclose, set_seed
-
 import pytest
 import torch
 import torch.nn.functional as F
@@ -8,6 +6,9 @@
 from liger_kernel.chunked_loss.functional import liger_fused_linear_jsd
 from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
 from liger_kernel.utils import infer_device
+from test.utils import HFDistillationLoss
+from test.utils import assert_verbose_allclose
+from test.utils import set_seed
 
 
 device = infer_device()
@@ -31,21 +32,18 @@ def __init__(
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
+            temperature=temperature,
         )
-        self.temperature = temperature
 
     def distillation_loss(self, student_logits, teacher_logits):
         """
         Compute JSD loss (Jensen-Shannon Divergence Loss).
         Args:
-            student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len,).
-            teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len,).
+            student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
+            teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
         Returns:
             torch.Tensor: Jensen-Shannon Divergence loss
         """
-        # Scale logits by temperature
-        student_logits = student_logits / self.temperature
-        teacher_logits = teacher_logits / self.temperature
         # Convert to probabilities
         student_probs = F.softmax(student_logits, dim=-1)
         teacher_probs = F.softmax(teacher_logits, dim=-1)
@@ -92,21 +90,16 @@ def __init__(
     ):
         super().__init__()
         # smaller student model weights
-        self.student_lin = torch.nn.Linear(
-            in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device
-        )
-        self.teacher_lin = torch.nn.Linear(
-            in_features=H, out_features=V, bias=False, dtype=dtype, device=device
-        )
+        self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device)
+        self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
         self.jsd = HFJSDLoss(
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
+            temperature=temperature,
         ).get_batch_loss_metrics
-        self.temperature = temperature
 
     def forward(self, student_input, teacher_input, target):
-
         jsd_loss = self.jsd(
             student_input,
             self.student_lin.weight,
@@ -131,18 +124,14 @@ def __init__(
     ):
         super().__init__()
         # smaller student model weights
-        self.student_lin = torch.nn.Linear(
-            in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device
-        )
-        self.teacher_lin = torch.nn.Linear(
-            in_features=H, out_features=V, bias=False, dtype=dtype, device=device
-        )
+        self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device)
+        self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
         self.chunked_jsd = LigerFusedLinearJSDLoss(
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
             ignore_index=ignore_index,
+            temperature=temperature,
         )
-        self.temperature = temperature
 
     def forward(self, student_input, teacher_input, target):
         return self.chunked_jsd(
@@ -169,15 +158,16 @@ def forward(self, student_input, teacher_input, target):
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
-        (1.0, torch.bfloat16, 5e-2, 5e-2),
-        (1.0, torch.float32, 1e-4, 5e-3),
+        (1.0, torch.bfloat16, 5e-3, 5e-3),
+        (1.0, torch.float32, 1e-5, 5e-4),
     ],
 )
 @pytest.mark.parametrize(
     "temperature, weight_hard_loss, weight_soft_loss",
     [
         (1.0, 0.5, 0.5),
-        (2.0, 0.1, 0.9),
+        (2.0, 0.5, 0.5),
+        (0.7, 0.5, 0.5),
         (1.0, 0.0, 1.0),
         (1.0, 1.0, 0.0),
     ],
@@ -189,11 +179,11 @@ def test_correctness(
     V,
     scalar,
     dtype,
-    weight_hard_loss,
-    weight_soft_loss,
-    temperature,
     atol,
     rtol,
+    temperature,
+    weight_hard_loss,
+    weight_soft_loss,
 ):
     torch_lm_head_jsd = TorchLMHeadJSD(
         H=H,
@@ -214,12 +204,12 @@ def test_correctness(
         weight_soft_loss=weight_soft_loss,
     )
 
-    torch_lm_head_jsd.student_lin.weight.data = (
-        liger_lm_head_jsd.student_lin.weight.data
-    ) = torch.rand(V, H // 2, device=device, dtype=dtype)
-    torch_lm_head_jsd.teacher_lin.weight.data = (
-        liger_lm_head_jsd.teacher_lin.weight.data
-    ) = torch.rand(V, H, device=device, dtype=dtype)
+    torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand(
+        V, H // 2, device=device, dtype=dtype
+    )
+    torch_lm_head_jsd.teacher_lin.weight.data = liger_lm_head_jsd.teacher_lin.weight.data = torch.rand(
+        V, H, device=device, dtype=dtype
+    )
 
     _tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar
     student_input1 = _tensor.detach().clone().requires_grad_(True)
@@ -236,9 +226,7 @@ def test_correctness(
     loss1.backward()
     loss2.backward()
 
-    assert_verbose_allclose(
-        student_input1.grad, student_input2.grad, atol=atol, rtol=rtol
-    )
+    assert_verbose_allclose(student_input1.grad, student_input2.grad, atol=atol, rtol=rtol)
 
     assert_verbose_allclose(
         torch_lm_head_jsd.student_lin.weight.grad,
@@ -320,10 +308,6 @@ def test_correctness_functional(
     output1.backward()
     output2.backward()
 
-    assert_verbose_allclose(
-        student_input1.grad, student_input2.grad, atol=atol, rtol=rtol
-    )
+    assert_verbose_allclose(student_input1.grad, student_input2.grad, atol=atol, rtol=rtol)
 
-    assert_verbose_allclose(
-        student_weight1.grad, student_weight2.grad, atol=atol, rtol=rtol
-    )
+    assert_verbose_allclose(student_weight1.grad, student_weight2.grad, atol=atol, rtol=rtol)
diff --git a/test/utils.py b/test/utils.py
index ec4abd5a8..a1fd22150 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -605,7 +605,10 @@ def get_batch_loss_metrics(
             hard_loss,
         ) = forward_output
 
+        student_logits /= self.temperature
+        teacher_logits /= self.temperature
+
         soft_loss = self.distillation_loss(student_logits, teacher_logits)
         # full loss
-        loss = self.weight_hard_loss * hard_loss + self.weight_soft_loss * soft_loss.mean()
+        loss = self.weight_hard_loss * hard_loss + self.weight_soft_loss * soft_loss
         return loss
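Reviewer note: after this patch the temperature division happens once per chunk inside _compute_loss (on student_logits_chunk / teacher_logits_chunk), distillation_loss_fn no longer takes a temperature argument, and its KL terms are sum-reduced rather than batch-mean-reduced. The sketch below is a minimal standalone restatement of that behaviour for a single chunk of logits; it is not the fused kernel. The helper name jsd_reference is made up for illustration, and the final 0.5 * (student_kl + teacher_kl) combination is assumed, since the return statement of distillation_loss_fn is outside the hunks shown above.

import torch
import torch.nn.functional as F


def jsd_reference(student_logits: torch.Tensor, teacher_logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
    # Temperature scaling is applied to the raw logits, mirroring the new
    # `student_logits_chunk /= temperature` lines in _compute_loss.
    student_logits = student_logits / temperature
    teacher_logits = teacher_logits / temperature

    # Convert to probabilities and form the mean distribution M = (P + Q) / 2.
    student_probs = F.softmax(student_logits, dim=-1)
    teacher_probs = F.softmax(teacher_logits, dim=-1)
    log_mean_probs = torch.log((student_probs + teacher_probs) / 2)

    # Sum-reduced KL terms, matching reduction="sum" in the patched loss; with
    # log_target=True these are KL(student || M) and KL(teacher || M).
    student_kl = F.kl_div(log_mean_probs, torch.log(student_probs), reduction="sum", log_target=True)
    teacher_kl = F.kl_div(log_mean_probs, torch.log(teacher_probs), reduction="sum", log_target=True)

    # Assumed combination: JSD = 0.5 * (KL(P || M) + KL(Q || M)).
    return 0.5 * (student_kl + teacher_kl)

Switching from reduction="batchmean" to reduction="sum" presumably lets the per-chunk soft loss be accumulated across chunks and normalized once by the total token count (full_target.shape[0]), the same way hard_loss is normalized in _compute_loss; the matching test/utils.py change drops the extra .mean() on soft_loss for the same reason.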