Commit a79a814: expose beta
Signed-off-by: Austin Liu <[email protected]>
austin362667 committed Jan 16, 2025
1 parent 3a1ba6b commit a79a814
Showing 3 changed files with 30 additions and 9 deletions.
src/liger_kernel/chunked_loss/fused_linear_distillation.py (3 additions, 0 deletions)
@@ -135,6 +135,7 @@ def forward(
ignore_index=-100,
weight_hard_loss=0.5,
weight_soft_loss=0.5,
+beta=0.5,
compute_ce_loss=True,
temperature=1.0,
compiled=True,
@@ -157,6 +158,7 @@ def forward(
ignore_index (int): Index to ignore for loss computation.
weight_hard_loss (float): Weight for hard/task loss.
weight_soft_loss (float): Weight for soft/distillation loss.
+beta (float): Interpolation coefficient between 0 and 1 (default: 0.5).
compute_ce_loss (bool): Whether to compute CE loss.
temperature (float): Temperature to control the input probability distribution. Default: `1.0` (i.e. no scale)
compiled (bool): Whether to use torch compile for chunk accumulation.
@@ -175,6 +177,7 @@ def forward(
ignore_index=ignore_index,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
+beta=beta,
compute_ce_loss=compute_ce_loss,
temperature=temperature,
**loss_kwargs,
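Note: `beta` is threaded through `loss_kwargs` into the per-chunk `distillation_loss`, where it acts as the interpolation coefficient of the generalized Jensen-Shannon divergence. For readers, a minimal PyTorch sketch of one common convention (the authoritative form is `distillation_loss` in jsd_loss.py; `generalized_jsd` is an illustrative name, not part of this commit):

    import torch
    import torch.nn.functional as F

    def generalized_jsd(student_logits, teacher_logits, beta=0.5):
        # Log-probabilities of the student and teacher distributions.
        s = F.log_softmax(student_logits, dim=-1)
        t = F.log_softmax(teacher_logits, dim=-1)
        # Endpoints are special-cased to plain KL divergences: the beta -> 0/1
        # limits of the mixture form collapse to 0 and are useless as losses.
        if beta == 0:
            return F.kl_div(s, t, reduction="sum", log_target=True)  # KL(teacher || student)
        if beta == 1:
            return F.kl_div(t, s, reduction="sum", log_target=True)  # KL(student || teacher)
        # Log of the mixture M = (1 - beta) * P_student + beta * P_teacher.
        m = ((1 - beta) * s.exp() + beta * t.exp()).log()
        # Weighted sum of each distribution's KL against the mixture.
        jsd = beta * F.kl_div(m, t, reduction="sum", log_target=True)
        jsd += (1 - beta) * F.kl_div(m, s, reduction="sum", log_target=True)
        return jsd

Under this convention, beta=0.5 recovers the familiar symmetric JSD, while the endpoints reduce to the two KL directions.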
src/liger_kernel/chunked_loss/jsd_loss.py (8 additions, 1 deletion)
@@ -40,6 +40,7 @@ def forward(
true_labels: torch.LongTensor,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
+beta: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
compiled: bool = True,
@@ -54,6 +55,7 @@
true_labels (torch.LongTensor): Target tensor. Shape: (batch_size * seq_len,)
weight_hard_loss (float): Weight for hard loss.
weight_soft_loss (float): Weight for soft loss.
+beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
ignore_index (int): Index to ignore in loss computation
temperature (float): Temperature for softening/sharpening distributions
compiled (bool): Whether to use torch compile
@@ -71,6 +73,7 @@
chunk_size=1,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
+beta=beta,
ignore_index=ignore_index,
temperature=temperature,
compiled=compiled,
@@ -80,7 +83,7 @@
def backward(ctx, grad_output):
grads = LigerFusedLinearDistillationBase.backward(ctx, grad_output)[:4]

-return (*grads, None, None, None, None, None, None)
+return (*grads, None, None, None, None, None, None, None)


class LigerFusedLinearJSDLoss(torch.nn.Module):
@@ -92,6 +95,7 @@ def __init__(
self,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
+beta: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
compiled: bool = True,
@@ -103,6 +107,7 @@ def __init__(
ignore_index (int): Index to ignore in the loss
temperature (float): Temperature for softening distributions
compiled (bool): Whether to use torch compile
+beta (float): Coefficient beta of generalized JSD in the interval [0, 1]. Default: `0.5`.
"""
super().__init__()
assert temperature != 0, "Temperature cannot be 0."
@@ -111,6 +116,7 @@ def __init__(
self.ignore_index = ignore_index
self.temperature = temperature
self.compiled = compiled
+self.beta = beta

def forward(
self,
@@ -141,6 +147,7 @@ def forward(
true_labels,
self.weight_hard_loss,
self.weight_soft_loss,
+self.beta,
self.ignore_index,
self.temperature,
self.compiled,
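Two notes on the wiring above. First, torch.autograd requires `backward` to return one gradient per `forward` input, so the added trailing `None` is the placeholder for the new `beta` argument. Second, the module stores `beta` and passes it positionally to `LigerFusedLinearJSDFunction.apply`, between `weight_soft_loss` and `ignore_index`. A hypothetical usage sketch (tensor shapes, dtypes, and the four leading call arguments are assumptions based on the surrounding signatures, not shown in this diff):

    import torch
    from liger_kernel.chunked_loss.jsd_loss import LigerFusedLinearJSDLoss

    # Illustrative sizes: N = batch * seq_len, H = hidden dim, V = vocab size.
    N, H, V = 8, 64, 128
    loss_fn = LigerFusedLinearJSDLoss(
        weight_hard_loss=0.5,
        weight_soft_loss=0.5,
        beta=0.5,  # newly exposed generalized-JSD coefficient
        temperature=2.0,
    )

    student_input = torch.randn(N, H, requires_grad=True)
    student_weight = torch.randn(V, H, requires_grad=True)
    teacher_input = torch.randn(N, H)
    teacher_weight = torch.randn(V, H)
    true_labels = torch.randint(0, V, (N,))

    loss = loss_fn(student_input, student_weight, teacher_input, teacher_weight, true_labels)
    loss.backward()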
test/chunked_loss/test_jsd_loss.py (19 additions, 8 deletions)
@@ -27,13 +27,15 @@ def __init__(
ignore_index: int = -100,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
+beta: float = 0.5,
):
super().__init__(
ignore_index=ignore_index,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
temperature=temperature,
)
+self.beta = beta

def distillation_loss(self, student_logits, teacher_logits, beta=0.5):
"""
@@ -77,6 +79,7 @@ def __init__(
device: torch.device,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
+beta: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
):
@@ -89,6 +92,7 @@ def __init__(
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
temperature=temperature,
+beta=beta,
).get_batch_loss_metrics

def forward(self, student_input, teacher_input, target):
@@ -111,6 +115,7 @@ def __init__(
device: torch.device,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
+beta: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
):
@@ -155,13 +160,13 @@ def forward(self, student_input, teacher_input, target):
],
)
@pytest.mark.parametrize(
"temperature, weight_hard_loss, weight_soft_loss",
"temperature, weight_hard_loss, weight_soft_loss, beta",
[
-(1.0, 0.5, 0.5),
-(2.0, 0.5, 0.5),
-(0.5, 0.5, 0.5),
-(1.0, 0.0, 1.0),
-(1.0, 1.0, 0.0),
+(1.0, 0.5, 0.5, 0.5),
+(2.0, 0.5, 0.5, 0.5),
+(0.5, 0.5, 0.5, 0.5),
+(1.0, 0.0, 1.0, 0.5),
+(1.0, 1.0, 0.0, 0.5),
],
)
def test_correctness(
@@ -176,6 +181,7 @@ def test_correctness(
temperature,
weight_hard_loss,
weight_soft_loss,
+beta,
):
torch_lm_head_jsd = TorchLMHeadJSD(
H=H,
@@ -185,6 +191,7 @@
temperature=temperature,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
+beta=beta,
)
liger_lm_head_jsd = LigerLMHeadJSD(
H=H,
@@ -194,6 +201,7 @@
temperature=temperature,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
+beta=beta,
)

torch_lm_head_jsd.student_lin.weight.data = liger_lm_head_jsd.student_lin.weight.data = torch.rand(
@@ -243,8 +251,8 @@ def test_correctness(
],
)
@pytest.mark.parametrize(
"temperature, weight_hard_loss, weight_soft_loss, ignore_index",
[(1.0, 0.5, 0.5, -100), (2.0, 0.1, 0.9, 42)],
"temperature, weight_hard_loss, weight_soft_loss, beta, ignore_index",
[(1.0, 0.5, 0.5, 0.5, -100), (2.0, 0.1, 0.9, 0.5, 42)],
)
def test_correctness_functional(
B,
@@ -255,6 +263,7 @@
dtype,
weight_hard_loss,
weight_soft_loss,
+beta,
ignore_index,
temperature,
atol,
@@ -280,6 +289,7 @@
label,
weight_hard_loss,
weight_soft_loss,
+beta,
ignore_index,
temperature,
)
@@ -291,6 +301,7 @@
label,
weight_hard_loss,
weight_soft_loss,
+beta,
ignore_index,
temperature,
)
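One observation on coverage: every new parametrization pins `beta` to 0.5, which is also the default, so asymmetric interpolation is never exercised. A hypothetical follow-up case (values and test name are illustrative, not part of this commit) could widen it:

    import pytest

    @pytest.mark.parametrize(
        "temperature, weight_hard_loss, weight_soft_loss, beta",
        [
            (1.0, 0.5, 0.5, 0.1),  # skew the generalized JSD toward one KL direction
            (1.0, 0.5, 0.5, 0.9),  # and toward the other
        ],
    )
    def test_correctness_asymmetric_beta(temperature, weight_hard_loss, weight_soft_loss, beta):
        # Body would mirror test_correctness above, with the same model setup
        # and tolerance checks.
        ...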
