Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/liger fused linear cross entropy function does not support reduction=none #496

Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/liger_kernel/chunked_loss/cpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def preference_loss_fn(
"""
logits = beta * (chosen_logps - rejected_logps)
loss = (
- F.logsigmoid(logits) * (1 - label_smoothing)
-F.logsigmoid(logits) * (1 - label_smoothing)
- F.logsigmoid(-logits) * label_smoothing
).sum() / (full_target.shape[0] // 2)

Expand Down
2 changes: 1 addition & 1 deletion src/liger_kernel/chunked_loss/simpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def preference_loss_fn(
"""
logits = beta * (chosen_logps - rejected_logps) - gamma
loss = (
- F.logsigmoid(logits) * (1 - label_smoothing)
-F.logsigmoid(logits) * (1 - label_smoothing)
- F.logsigmoid(-logits) * label_smoothing
).sum() / (full_target.shape[0] // 2)

Expand Down
7 changes: 5 additions & 2 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,15 +133,18 @@ def fused_linear_cross_entropy_forward(
alpha=alpha,
)

loss = torch.sum(loss_1d)
if reduction == "none":
loss = loss_1d
else:
loss = torch.sum(loss_1d)
return loss, grad_input, grad_weight, grad_bias


def fused_linear_cross_entropy_backward(
grad_output, grad_input, grad_weight, grad_bias
):
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
BT, H = grad_input.shape
Expand Down
6 changes: 4 additions & 2 deletions test/transformers/test_fused_linear_cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ def forward(self, x, y):
("mean", 1.0, torch.float32, 1e-5, 5e-4),
("sum", 1.0, torch.bfloat16, 5e-0, 5e1),
("sum", 1.0, torch.float32, 1e-3, 5e-2),
("none", 1.0, torch.bfloat16, 5e-0, 5e1),
("none", 1.0, torch.float32, 1e-3, 5e-2),
],
)
@pytest.mark.parametrize("bias", [True, False])
Expand Down Expand Up @@ -197,8 +199,8 @@ def test_correctness(

assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol)

output1.backward()
output2.backward()
output1.backward(gradient=torch.ones_like(output1))
output2.backward(gradient=torch.ones_like(output2))

assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol)

Expand Down
Loading