Fix/liger fused linear cross entropy function does not support reduction=none (#496)

## Summary
Fixes #488: support `reduction="none"` in the fused linear cross entropy function.
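
For reference, here is a minimal unfused PyTorch sketch (illustrative only, not the library's exact API; shapes are made up) of the behaviour that `reduction="none"` should match: a per-token loss tensor is returned instead of a single scalar.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes: BT = batch * sequence length, H = hidden size, V = vocab size.
BT, H, V = 8, 16, 32
x = torch.randn(BT, H)
weight = torch.randn(V, H)
target = torch.randint(0, V, (BT,))

logits = x @ weight.T
loss_per_token = F.cross_entropy(logits, target, reduction="none")  # shape (BT,), one loss per token
loss_sum = F.cross_entropy(logits, target, reduction="sum")         # scalar
assert torch.allclose(loss_per_token.sum(), loss_sum)
```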

## Testing Done

- Hardware Type: A100-40G-PCIe
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
ryankert01 authored Dec 28, 2024
1 parent 77ff1a9 commit 9875488
Showing 3 changed files with 9 additions and 5 deletions.
1 change: 0 additions & 1 deletion src/liger_kernel/chunked_loss/cpo_loss.py
@@ -33,7 +33,6 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1, labe
loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
full_target.shape[0] // 2
)

return loss

@staticmethod
7 changes: 5 additions & 2 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -127,13 +127,16 @@ def fused_linear_cross_entropy_forward(
alpha=alpha,
)

- loss = torch.sum(loss_1d)
+ if reduction == "none":
+     loss = loss_1d
+ else:
+     loss = torch.sum(loss_1d)
return loss, grad_input, grad_weight, grad_bias


def fused_linear_cross_entropy_backward(grad_output, grad_input, grad_weight, grad_bias):
# If cross entropy is the last layer, grad_output is 1.0. Skip the mul to save time
- if torch.ne(grad_output, torch.tensor(1.0, device=grad_output.device)):
+ if not torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):
# We use a Triton kernel instead of a PyTorch operation because modifying inputs in-place
# for gradient storage and backward multiple times causes anomalies with PyTorch but not with Triton.
BT, H = grad_input.shape
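
The comparison change matters once `reduction="none"` is allowed: `grad_output` can then be a tensor with more than one element, and `torch.ne` returns an element-wise boolean tensor that cannot be used directly in an `if`, whereas `torch.equal` returns a plain Python bool. A small standalone sketch of the difference (plain PyTorch, not the Liger kernel code):

```python
import torch

grad_output = torch.ones(4)  # e.g. upstream gradient for a per-token (reduction="none") loss
one = torch.tensor(1.0)

# torch.equal returns a Python bool: True only if shape and values match exactly.
print(torch.equal(grad_output, one))  # False (shapes differ)

# torch.ne returns an element-wise boolean tensor; using it in `if` with more than
# one element raises "Boolean value of Tensor with more than one element is ambiguous".
try:
    if torch.ne(grad_output, one):
        pass
except RuntimeError as e:
    print(e)
```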
6 changes: 4 additions & 2 deletions test/transformers/test_fused_linear_cross_entropy.py
@@ -108,6 +108,8 @@ def forward(self, x, y):
("mean", 1.0, torch.float32, 1e-5, 5e-4),
("sum", 1.0, torch.bfloat16, 5e-0, 5e1),
("sum", 1.0, torch.float32, 1e-3, 5e-2),
("none", 1.0, torch.bfloat16, 5e-0, 5e1),
("none", 1.0, torch.float32, 1e-3, 5e-2),
],
)
@pytest.mark.parametrize("bias", [True, False])
@@ -185,8 +187,8 @@ def test_correctness(

assert_verbose_allclose(output1, output2, atol=atol, rtol=rtol)

- output1.backward()
- output2.backward()
+ output1.backward(gradient=torch.ones_like(output1))
+ output2.backward(gradient=torch.ones_like(output2))

assert_verbose_allclose(_input1.grad, _input2.grad, atol=atol, rtol=rtol)

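
The test change follows from PyTorch autograd semantics rather than anything Liger-specific: `.backward()` with no arguments only works for scalar outputs, so the `reduction="none"` cases need an explicit `gradient` argument. Passing `torch.ones_like(output)` keeps the scalar (`sum`/`mean`) cases equivalent to the old `output.backward()`, since a scalar backward implicitly uses a gradient of 1.0. A minimal illustration (toy example, not the test code):

```python
import torch

x = torch.randn(4, requires_grad=True)
loss = x * 2.0  # non-scalar "loss", e.g. per-token losses from reduction="none"

# loss.backward() would raise "grad can be implicitly created only for scalar outputs".
loss.backward(gradient=torch.ones_like(loss))
print(x.grad)  # tensor([2., 2., 2., 2.])
```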
