
Add KTO Loss #475

Open

wants to merge 27 commits into main
Conversation

@hebiao064 (Collaborator) commented on Dec 13, 2024

Summary

Closes the KTO item of the roadmap: #371

Implements the Kahneman-Tversky Optimization (KTO) loss function.

KTO Loss Function

For a policy π compared to a reference policy π₀:

When y is chosen:

$L_{KTO} = 1 - \sigma\left(\beta \cdot \left(\log\frac{\pi(y|x)}{\pi_0(y|x)} - \mathrm{KL}(\pi\,\|\,\pi_0)\right)\right)$

When y is rejected:

$L_{KTO} = 1 - \sigma\left(\beta \cdot \left(\mathrm{KL}(\pi\,\|\,\pi_0) - \log\frac{\pi(y|x)}{\pi_0(y|x)}\right)\right)$

where:

  • σ is the sigmoid function
  • β is a temperature parameter
  • KL(π‖π₀) is a KL-divergence estimate that acts as the reference point against which the log-ratio of the completion y is compared
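To make the formulas concrete, here is a rough plain-PyTorch sketch of the per-example loss. This is only an illustration, not the fused/chunked kernel added in this PR; the function and tensor names, the summed-log-prob inputs, and the precomputed scalar `kl` are assumptions for the example.

```python
import torch

def kto_loss_sketch(policy_logps, ref_logps, preference_labels, kl, beta=0.1):
    """Illustrative per-example KTO loss following the formulas above.

    policy_logps / ref_logps: (B,) summed log-probs of each completion under
        the policy and the reference model.
    preference_labels: (B,) bool tensor, True where the completion is "chosen".
    kl: scalar KL(pi || pi_0) estimate used as the reference point.
    beta: temperature parameter.
    """
    log_ratio = policy_logps - ref_logps                      # log[pi(y|x) / pi_0(y|x)]
    chosen_losses = 1 - torch.sigmoid(beta * (log_ratio - kl))
    rejected_losses = 1 - torch.sigmoid(beta * (kl - log_ratio))
    losses = torch.where(preference_labels, chosen_losses, rejected_losses)
    return losses.mean()                                      # reduce to a scalar loss
```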

Intuition

KTO loss is inspired by prospect theory from behavioral economics, which models how humans make decisions under uncertainty.

The loss function is asymmetric, treating gains and losses differently, similar to
human decision-making patterns.


Credit: https://www.youtube.com/watch?v=nSrj1J6ODoM&t=422s

Benchmark Result

Special thanks to @shivam15s for the optimization PR #491; without it, my implementation would not have reached the speeds listed below.

Memory:

(benchmark chart)

Speed:

(benchmark chart)

Notable learnings from optimizing the speed (see the sketch below):

  • [Culprit] The KL term was being recomputed for every one of the N chunks after splitting the batch
  • [Good to have] Remove unnecessary variable computations such as aux_outputs
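As an illustration of the first point, here is a minimal sketch (hypothetical function and variable names, not the actual fused implementation) of computing the KL reference term once and reusing it across chunks, instead of recomputing it inside the per-chunk loop:

```python
import torch

def chunked_kto_loss(policy_logps, ref_logps, labels,
                     kl_policy_logps, kl_ref_logps, beta=0.1, num_chunks=4):
    # Compute the KL reference point ONCE, outside the chunk loop;
    # recomputing it per chunk was the main slowdown.
    kl = (kl_policy_logps - kl_ref_logps).mean().clamp(min=0).detach()

    total_loss = policy_logps.new_zeros(())
    for p_chunk, r_chunk, l_chunk in zip(
        policy_logps.chunk(num_chunks),
        ref_logps.chunk(num_chunks),
        labels.chunk(num_chunks),
    ):
        # Per-chunk KTO loss, reusing the shared KL estimate.
        log_ratio = p_chunk - r_chunk
        chosen = 1 - torch.sigmoid(beta * (log_ratio - kl))
        rejected = 1 - torch.sigmoid(beta * (kl - log_ratio))
        total_loss = total_loss + torch.where(l_chunk, chosen, rejected).sum()

    return total_loss / policy_logps.numel()  # mean over all examples
```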

Key Changes

  • Implemented the LigerFusedLinearKTOLoss class
  • Added LigerFusedLinearKTOFunction for the core KTO computation
  • Created a comprehensive test suite in test_kto_loss.py
  • Added a reference implementation (HFKTOLoss) adapted from Hugging Face TRL

Testing Done

Tests are passing now:

`pytest test/chunked_loss/test_kto_loss.py`

  • Parameterized tests covering various configurations (a minimal sketch of the parameterization pattern follows this list):
    • Different batch sizes, sequence lengths, hidden dims, and vocab sizes
    • Multiple data types (bfloat16, float32)
    • Bias and reference-bias variations
    • Different ignore indices and beta values
  • Correctness tests comparing against the reference implementation
  • Gradient checking and backward-pass verification
  • Hardware Type:
  • run `make test` to ensure correctness
  • run `make checkstyle` to ensure code style
  • run `make test-convergence` to ensure convergence
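A minimal sketch of the parameterization pattern (the parameter values and tolerances here are illustrative, not the ones in test_kto_loss.py):

```python
import pytest
import torch

@pytest.mark.parametrize("B, T, H, V", [(2, 128, 512, 1024), (4, 47, 31, 123)])
@pytest.mark.parametrize("dtype, atol, rtol", [(torch.bfloat16, 5e-2, 5e-1), (torch.float32, 1e-5, 5e-4)])
@pytest.mark.parametrize("bias, ref_bias", [(True, True), (False, False)])
@pytest.mark.parametrize("ignore_index, beta", [(-100, 0.1), (42, 0.2)])
def test_kto_loss_correctness(B, T, H, V, dtype, atol, rtol, bias, ref_bias, ignore_index, beta):
    # Build random inputs, run the Liger KTO loss and the HF reference,
    # then torch.testing.assert_close on the losses and gradients.
    ...
```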

@hebiao064 hebiao064 marked this pull request as ready for review December 13, 2024 01:41
@ByronHsu (Collaborator) left a comment:

Took a brief look. I am not very familiar with the KTO math, but why don't we have KL_log_probs when the original HF implementation does: https://github.com/huggingface/trl/blob/cd7156fb34ddf9a8c04fcd640a4067933461d44e/trl/trainer/kto_trainer.py#L1121? We also need to be careful about scaling: in the original HF code, kto_loss returns an unreduced version, but we probably need to reduce it as a mean. cc @shivam15s

@hebiao064 (Collaborator, Author) replied:

> Took a brief look. I am not very familiar with the KTO math, but why don't we have KL_log_probs when the original HF implementation does: https://github.com/huggingface/trl/blob/cd7156fb34ddf9a8c04fcd640a4067933461d44e/trl/trainer/kto_trainer.py#L1121? We also need to be careful about scaling: in the original HF code, kto_loss returns an unreduced version, but we probably need to reduce it as a mean. cc @shivam15s

About the KL term: I'll take a further look at trl to see how to support that.

About the reduction: HF does average it here: `loss = losses.nanmean()`
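For context on the reduction point, a tiny standalone illustration (not the PR code) of how `nanmean` reduces an unreduced per-example loss while ignoring NaN entries:

```python
import torch

# One unreduced loss per example; NaN entries (e.g. examples with nothing to score) are skipped.
per_example_losses = torch.tensor([0.42, 0.17, float("nan"), 0.98])
loss = per_example_losses.nanmean()  # (0.42 + 0.17 + 0.98) / 3 -> tensor(0.5233)
```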

hebiao064 and others added 11 commits December 16, 2024 21:34
## Summary

### KTO LOSS

#### Memory

![image](https://github.com/user-attachments/assets/bd8fe4f6-0c18-4cf3-a79a-fc8634dcb492)

#### Speed

![image](https://github.com/user-attachments/assets/256cf0c3-3943-4f46-b256-38a577323a03)
@hebiao064 hebiao064 enabled auto-merge (squash) December 21, 2024 06:50
@hebiao064 (Collaborator, Author) commented:

The AMD test failed because no GPU was available; it is not related to this PR: FAILED test/transformers/test_swiglu.py::test_correctness_functional[dtype1-10000.0-0.01-9-7-41] - RuntimeError: No HIP GPUs are available

@kvignesh1420 (Collaborator) left a comment:

Thanks for the contribution. Did a first pass over the functionality and left some comments.

Review comments on benchmark/scripts/benchmark_kto_loss.py (outdated; resolved)
preference_labels_chunk=None,
ref_input_chunk=None,
):
(chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss) = fused_fwd_bwd(
A collaborator asked:

Why is it `*chunk_grad_bias` and not `chunk_grad_bias` like the other gradients?

@hebiao064 (Collaborator, Author) replied:
The `*chunk_grad_bias` syntax is used here because of how Python handles unpacking in this situation (a minimal illustration follows this list):

  1. The fused_fwd_bwd function returns a tuple of two elements:

    • First element: a tuple of gradients (either 2 or 3 elements, depending on whether bias is used)
    • Second element: the loss value

  2. When bias is None, fused_fwd_bwd returns only two gradients: (chunk_grad_input, chunk_grad_weight).

    When bias is not None, it returns three: (chunk_grad_input, chunk_grad_weight, chunk_grad_bias).

  3. The * operator handles this variability by collecting any remaining elements into a list:

    • If bias is None, *chunk_grad_bias becomes an empty list []
    • If bias exists, *chunk_grad_bias becomes a one-element list [chunk_grad_bias]
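A minimal standalone illustration of that unpacking behavior (generic names, not the actual kernel code):

```python
def fwd_bwd_with_bias():
    # Returns (gradients_tuple, loss) with three gradients.
    return ("grad_input", "grad_weight", "grad_bias"), 0.5

def fwd_bwd_without_bias():
    # Returns (gradients_tuple, loss) with only two gradients.
    return ("grad_input", "grad_weight"), 0.5

(grad_input, grad_weight, *grad_bias), loss = fwd_bwd_with_bias()
print(grad_bias)   # ['grad_bias']  -> one-element list

(grad_input, grad_weight, *grad_bias), loss = fwd_bwd_without_bias()
print(grad_bias)   # []             -> empty list, no unpacking error
```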

"""
Compute the total loss for a chunk of input and target, while using an alignment/preference loss function.
Args:
preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
A collaborator asked:
It seems this class already has a staticmethod for preference_loss_fn. Why do we need an extra arg here?

@hebiao064 (Collaborator, Author) replied on Jan 15, 2025:

  1. The abstract preference_loss_fn method defined at the class level is just a placeholder that defines the interface: it is marked with @abstractmethod and raises NotImplementedError, and it is meant to be implemented by subclasses.

  2. The preference_loss_fn parameter in methods like forward() and _compute_loss() is the actual function that will be used to compute the loss; it is passed in at runtime.

This is a common pattern in Python: the base class defines an interface (an abstract method) but allows runtime flexibility by accepting the actual implementation as a parameter. As a result you can (see the sketch below):

  1. Create subclasses that implement a default preference loss function by overriding the abstract method
  2. Override that default at runtime by passing in a different loss function as a parameter
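A minimal sketch of that pattern (generic class and function names, not the actual Liger classes):

```python
from abc import ABC, abstractmethod

import torch
import torch.nn.functional as F


class PreferenceLossBase(ABC):
    @staticmethod
    @abstractmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
        """Interface only: subclasses supply a default implementation."""
        raise NotImplementedError

    @classmethod
    def compute_loss(cls, chosen_logps, rejected_logps, preference_loss_fn=None, beta=0.1):
        # A runtime override wins; otherwise fall back to the subclass's default.
        loss_fn = preference_loss_fn or cls.preference_loss_fn
        return loss_fn(chosen_logps, rejected_logps, beta=beta)


class SigmoidPreferenceLoss(PreferenceLossBase):
    @staticmethod
    def preference_loss_fn(chosen_logps, rejected_logps, beta=0.1):
        # Default: a simple sigmoid preference loss.
        return -F.logsigmoid(beta * (chosen_logps - rejected_logps)).mean()


# Usage: default from the subclass, or an override passed at call time.
chosen, rejected = torch.tensor([1.0, 0.5]), torch.tensor([0.2, 0.1])
print(SigmoidPreferenceLoss.compute_loss(chosen, rejected))
print(SigmoidPreferenceLoss.compute_loss(
    chosen, rejected, preference_loss_fn=lambda c, r, beta=0.1: (r - c).mean()))
```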
