refactor temperature scaling and fix chunked distillation loss norm reduction

Signed-off-by: Austin Liu <[email protected]>
austin362667 committed Jan 7, 2025
1 parent cf84632 commit 87a7f8c
Showing 4 changed files with 47 additions and 63 deletions.
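
In brief: temperature scaling moves from the full student/teacher inputs (previously divided once before chunking) into the per-chunk loss computation, and the chunked JSD switches from "batchmean" to "sum" reduction, which appears to be what the commit title means by fixing the norm reduction. As a quick reminder of what the temperature parameter does, here is a toy sketch (illustrative only, not code from this repository):

import torch
import torch.nn.functional as F

# Dividing logits by T > 1 softens the softmax distribution, T < 1 sharpens it,
# and T = 1 leaves it unchanged, matching the docstrings added below.
logits = torch.tensor([2.0, 1.0, 0.1])
for temperature in (0.7, 1.0, 2.0):
    probs = F.softmax(logits / temperature, dim=-1)
    print(f"T={temperature}: {[round(p, 3) for p in probs.tolist()]}")
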
13 changes: 8 additions & 5 deletions src/liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -71,10 +71,11 @@ def _compute_loss(
weight_hard_loss=0.5,
weight_soft_loss=0.5,
compute_ce_loss=True,
temperature=1,
**loss_kwargs,
):
"""
Compute the total loss for a chunk of input and target, while using an knowleedge distillation loss function.
Compute the total loss for a chunk of input and target, while using an knowledge distillation loss function.
Args:
distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
@@ -84,11 +85,12 @@ def _compute_loss(
target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
full_target (torch.Tensor): Full target tensor. Shape: (chunk_size,).
full_target (torch.Tensor): Full target tensor. Shape: (batch_size * sequence_length,).
ignore_index (int): Index to ignore for loss computation.
weight_hard_loss (float): Weight for hard loss.
weight_soft_loss (float): Weight for soft loss.
compute_ce_loss (bool): Whether to compute CE loss.
temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
loss_kwargs (dict): Additional arguments for the loss function.
"""
(
@@ -107,6 +109,9 @@
compute_ce_loss=compute_ce_loss,
)

student_logits_chunk /= temperature
teacher_logits_chunk /= temperature

hard_loss /= full_target.shape[0]

soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
@@ -171,6 +176,7 @@ def forward(
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
compute_ce_loss=compute_ce_loss,
temperature=temperature,
**loss_kwargs,
)

@@ -225,9 +231,6 @@ def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk):
if compiled:
accumulate_chunk = torch.compile(accumulate_chunk)

student_input /= temperature
teacher_input /= temperature

num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE)
_student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0)
_teacher_input_chunks = torch.chunk(teacher_input, chunks=num_chunks, dim=0)
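
Reading the hunks above together: each chunk's student and teacher logits are now divided by the temperature immediately before the soft loss, while the hard (cross-entropy) loss keeps its normalization by the full target length. A minimal sketch of the resulting per-chunk flow, with a hypothetical helper name and the linear projection step omitted:

import torch

def per_chunk_losses(student_logits_chunk: torch.Tensor,
                     teacher_logits_chunk: torch.Tensor,
                     hard_loss: torch.Tensor,
                     full_target: torch.Tensor,
                     distillation_loss_fn,
                     temperature: float = 1.0):
    # Temperature scaling now happens per chunk, after the projection,
    # instead of on the full student/teacher inputs before chunking.
    student_logits_chunk = student_logits_chunk / temperature
    teacher_logits_chunk = teacher_logits_chunk / temperature

    # The hard loss is divided by the total number of target tokens, so that
    # accumulating per-chunk losses reproduces the batch-level average.
    hard_loss = hard_loss / full_target.shape[0]

    # With a sum-reduced distillation_loss_fn (see jsd_loss.py below), the soft
    # loss is presumably normalized the same way further down, outside the
    # visible hunk.
    soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
    return hard_loss, soft_loss
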
20 changes: 7 additions & 13 deletions src/liger_kernel/chunked_loss/jsd_loss.py
@@ -1,26 +1,20 @@
import torch
import torch.nn.functional as F

from liger_kernel.chunked_loss.fused_linear_distillation import (
LigerFusedLinearDistillationBase,
)
from liger_kernel.chunked_loss.fused_linear_distillation import LigerFusedLinearDistillationBase


class LigerFusedLinearJSDFunction(LigerFusedLinearDistillationBase):
@staticmethod
def distillation_loss_fn(student_logits, teacher_logits, temperature):
def distillation_loss_fn(student_logits, teacher_logits):
"""
Compute JSD loss (Jensen-Shannon Divergence Loss).
Args:
student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size,).
teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size,).
temperature (float): Temperature for softening probability distributions
student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
Returns:
torch.Tensor: Jensen-Shannon Divergence loss
"""
# Scale logits by temperature
student_logits = student_logits / temperature
teacher_logits = teacher_logits / temperature
# Convert to probabilities
student_probs = F.softmax(student_logits, dim=-1)
teacher_probs = F.softmax(teacher_logits, dim=-1)
@@ -30,13 +24,13 @@ def distillation_loss_fn(student_logits, teacher_logits, temperature):
student_kl = F.kl_div(
log_mean_probs,
torch.log(student_probs),
reduction="batchmean",
reduction="sum",
log_target=True,
)
teacher_kl = F.kl_div(
log_mean_probs,
torch.log(teacher_probs),
reduction="batchmean",
reduction="sum",
log_target=True,
)

@@ -69,7 +63,7 @@ def forward(
weight_hard_loss (float): Weight for hard loss.
weight_soft_loss (float): Weight for soft loss.
ignore_index (int): Index to ignore in loss computation
temperature (float): Temperature for softening distributions
temperature (float): Temperature for softening/sharpening distributions
compiled (bool): Whether to use torch compile
Returns:
torch.Tensor: Computed loss
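
For reference, the sum-reduced Jensen-Shannon divergence computed above can be written as a standalone sketch (the final averaging of the two KL terms is assumed, since the visible hunk ends before it; the function name is hypothetical):

import torch
import torch.nn.functional as F

def jsd_loss_sum(student_logits: torch.Tensor, teacher_logits: torch.Tensor) -> torch.Tensor:
    # JSD(P, Q) = 0.5 * KL(P || M) + 0.5 * KL(Q || M), with M = 0.5 * (P + Q).
    student_probs = F.softmax(student_logits, dim=-1)
    teacher_probs = F.softmax(teacher_logits, dim=-1)
    log_mean_probs = (0.5 * (student_probs + teacher_probs)).log()

    # F.kl_div(input=log M, target=log P, log_target=True) computes KL(P || M).
    # reduction="sum" returns the total over all tokens and leaves the division
    # to the caller; "batchmean" would divide by each chunk's row count, which
    # appears to be what the commit title calls the norm-reduction fix.
    student_kl = F.kl_div(log_mean_probs, student_probs.log(),
                          reduction="sum", log_target=True)
    teacher_kl = F.kl_div(log_mean_probs, teacher_probs.log(),
                          reduction="sum", log_target=True)
    return 0.5 * (student_kl + teacher_kl)
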
72 changes: 28 additions & 44 deletions test/chunked_loss/test_jsd_loss.py
@@ -1,5 +1,3 @@
from test.utils import HFDistillationLoss, assert_verbose_allclose, set_seed

import pytest
import torch
import torch.nn.functional as F
@@ -8,6 +6,9 @@
from liger_kernel.chunked_loss.functional import liger_fused_linear_jsd
from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDFunction
from liger_kernel.utils import infer_device
from test.utils import HFDistillationLoss
from test.utils import assert_verbose_allclose
from test.utils import set_seed

device = infer_device()

@@ -31,21 +32,18 @@ def __init__(
ignore_index=ignore_index,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
temperature=temperature,
)
self.temperature = temperature

def distillation_loss(self, student_logits, teacher_logits):
"""
Compute JSD loss (Jensen-Shannon Divergence Loss).
Args:
student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len,).
teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len,).
student_logits (torch.Tensor): Logits of student tokens. Shape: (batch_size * seq_len,).
teacher_logits (torch.Tensor): Logits of teacher tokens. Shape: (batch_size * seq_len,).
Returns:
torch.Tensor: Jensen-Shannon Divergence loss
"""
# Scale logits by temperature
student_logits = student_logits / self.temperature
teacher_logits = teacher_logits / self.temperature
# Convert to probabilities
student_probs = F.softmax(student_logits, dim=-1)
teacher_probs = F.softmax(teacher_logits, dim=-1)
@@ -92,21 +90,16 @@ def __init__(
):
super().__init__()
# smaller student model weights
self.student_lin = torch.nn.Linear(
in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device
)
self.teacher_lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype, device=device
)
self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device)
self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
self.jsd = HFJSDLoss(
ignore_index=ignore_index,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
temperature=temperature,
).get_batch_loss_metrics
self.temperature = temperature

def forward(self, student_input, teacher_input, target):

jsd_loss = self.jsd(
student_input,
self.student_lin.weight,
@@ -131,18 +124,14 @@ def __init__(
):
super().__init__()
# smaller student model weights
self.student_lin = torch.nn.Linear(
in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device
)
self.teacher_lin = torch.nn.Linear(
in_features=H, out_features=V, bias=False, dtype=dtype, device=device
)
self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=False, dtype=dtype, device=device)
self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=False, dtype=dtype, device=device)
self.chunked_jsd = LigerFusedLinearJSDLoss(
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
ignore_index=ignore_index,
temperature=temperature,
)
self.temperature = temperature

def forward(self, student_input, teacher_input, target):
return self.chunked_jsd(
@@ -169,15 +158,16 @@ def forward(self, student_input, teacher_input, target):
@pytest.mark.parametrize(
"scalar, dtype, atol, rtol",
[
(1.0, torch.bfloat16, 5e-2, 5e-2),
(1.0, torch.float32, 1e-4, 5e-3),
(1.0, torch.bfloat16, 5e-3, 5e-3),
(1.0, torch.float32, 1e-5, 5e-4),
],
)
@pytest.mark.parametrize(
"temperature, weight_hard_loss, weight_soft_loss",
[
(1.0, 0.5, 0.5),
(2.0, 0.1, 0.9),
(2.0, 0.5, 0.5),
(0.7, 0.5, 0.5),
(1.0, 0.0, 1.0),
(1.0, 1.0, 0.0),
],
@@ -189,11 +179,11 @@ def test_correctness(
V,
scalar,
dtype,
weight_hard_loss,
weight_soft_loss,
temperature,
atol,
rtol,
temperature,
weight_hard_loss,
weight_soft_loss,
):
torch_lm_head_jsd = TorchLMHeadJSD(
H=H,
@@ -214,12 +204,12 @@
weight_soft_loss=weight_soft_loss,
)

torch_lm_head_jsd.student_lin.weight.data = (
liger_lm_head_jsd.student_lin.weight.data
) = torch.rand(V, H // 2, device=device, dtype=dtype)
torch_lm_head_jsd.teacher_lin.weight.data = (
liger_lm_head_jsd.teacher_lin.weight.data
) = torch.rand(V, H, device=device, dtype=dtype)
torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand(
V, H // 2, device=device, dtype=dtype
)
torch_lm_head_jsd.teacher_lin.weight.data = liger_lm_head_jsd.teacher_lin.weight.data = torch.rand(
V, H, device=device, dtype=dtype
)

_tensor = torch.rand(B * T, H // 2, device=device, dtype=dtype) * scalar
student_input1 = _tensor.detach().clone().requires_grad_(True)
@@ -236,9 +226,7 @@
loss1.backward()
loss2.backward()

assert_verbose_allclose(
student_input1.grad, student_input2.grad, atol=atol, rtol=rtol
)
assert_verbose_allclose(student_input1.grad, student_input2.grad, atol=atol, rtol=rtol)

assert_verbose_allclose(
torch_lm_head_jsd.student_lin.weight.grad,
@@ -320,10 +308,6 @@ def test_correctness_functional(
output1.backward()
output2.backward()

assert_verbose_allclose(
student_input1.grad, student_input2.grad, atol=atol, rtol=rtol
)
assert_verbose_allclose(student_input1.grad, student_input2.grad, atol=atol, rtol=rtol)

assert_verbose_allclose(
student_weight1.grad, student_weight2.grad, atol=atol, rtol=rtol
)
assert_verbose_allclose(student_weight1.grad, student_weight2.grad, atol=atol, rtol=rtol)
5 changes: 4 additions & 1 deletion test/utils.py
@@ -605,7 +605,10 @@ def get_batch_loss_metrics(
hard_loss,
) = forward_output

student_logits /= self.temperature
teacher_logits /= self.temperature

soft_loss = self.distillation_loss(student_logits, teacher_logits)
# full loss
loss = self.weight_hard_loss * hard_loss + self.weight_soft_loss * soft_loss.mean()
loss = self.weight_hard_loss * hard_loss + self.weight_soft_loss * soft_loss
return loss
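
The reference implementation in test/utils.py now mirrors the kernel: scale both logits by the temperature, compute the soft loss, and take the weighted sum with the hard loss, dropping the trailing .mean() (consistent with the distillation loss already returning a reduced scalar). A simplified sketch of that reference path, with a hypothetical function name:

import torch

def reference_distillation_loss(student_logits: torch.Tensor,
                                teacher_logits: torch.Tensor,
                                hard_loss: torch.Tensor,
                                distillation_loss,
                                weight_hard_loss: float = 0.5,
                                weight_soft_loss: float = 0.5,
                                temperature: float = 1.0) -> torch.Tensor:
    # Unchunked reference path used by the tests: scale once, compute the
    # soft loss, then mix the hard and soft losses by their weights.
    student_logits = student_logits / temperature
    teacher_logits = teacher_logits / temperature
    soft_loss = distillation_loss(student_logits, teacher_logits)
    return weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
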
