Add KTO Loss #475
base: main
Conversation
I took a brief look. I am not very familiar with the KTO math, but why do we not have `KL_log_probs` when the original HF implementation does: https://github.com/huggingface/trl/blob/cd7156fb34ddf9a8c04fcd640a4067933461d44e/trl/trainer/kto_trainer.py#L1121? We also need to be careful about scaling: it seems the original HF `kto_loss` returns an unreduced version, but we probably need to reduce with a mean. cc @shivam15s
About the KL term, I'll take a further look.
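To illustrate the scaling point, here is a toy sketch (not code from this PR or from TRL) of the difference between an unreduced per-example loss and a mean-reduced one:

```python
import torch

# Toy per-example losses, standing in for what HF's kto_loss returns
# (an unreduced tensor with one entry per example).
per_example_losses = torch.tensor([0.7, 0.3, 0.9, 0.5])

# Reducing with .mean() keeps the gradient scale independent of batch size;
# .sum() (or no reduction) would scale gradients with the number of examples.
loss = per_example_losses.mean()
print(loss)  # tensor(0.6000)
```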
## Summary

### KTO LOSS

#### Memory

![image](https://github.com/user-attachments/assets/bd8fe4f6-0c18-4cf3-a79a-fc8634dcb492)

#### Speed

![image](https://github.com/user-attachments/assets/256cf0c3-3943-4f46-b256-38a577323a03)

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
The AMD test failed because no GPU was available; it is not related to this PR.
Thanks for the contribution. Did a first pass over the functionality and left some comments.
```python
preference_labels_chunk=None,
ref_input_chunk=None,
):
    (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss) = fused_fwd_bwd(
```
Why is it `*chunk_grad_bias` and not `chunk_grad_bias` like the other gradients?
The `*chunk_grad_bias` syntax is used here because of how Python handles unpacking in this situation:

- The `fused_fwd_bwd` function returns a tuple of two elements: the first is a tuple of gradients (either 2 or 3 elements, depending on whether bias is used), and the second is the loss value.
- When `bias` is None, `fused_fwd_bwd` returns only two gradients: `(chunk_grad_input, chunk_grad_weight)`. When `bias` is not None, it returns three: `(chunk_grad_input, chunk_grad_weight, chunk_grad_bias)`.
- The `*` operator handles this variability by collecting any remaining elements into a list:
  - If bias is None, `chunk_grad_bias` becomes an empty list `[]`.
  - If bias exists, `chunk_grad_bias` becomes a one-element list `[chunk_grad_bias]`.
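A minimal standalone sketch of this starred-unpacking behavior (`fake_fused_fwd_bwd` below is a stand-in for illustration, not the actual kernel code):

```python
def fake_fused_fwd_bwd(use_bias):
    # Mimics fused_fwd_bwd's return shape: a gradient tuple of 2 or 3
    # elements, followed by the loss value.
    grad_input, grad_weight, grad_bias, loss = 1.0, 2.0, 3.0, 0.5
    if use_bias:
        return (grad_input, grad_weight, grad_bias), loss
    return (grad_input, grad_weight), loss

# With bias: the starred target collects the trailing element into a list.
(gi, gw, *gb), loss = fake_fused_fwd_bwd(use_bias=True)
assert gb == [3.0]

# Without bias: the starred target is simply an empty list.
(gi, gw, *gb), loss = fake_fused_fwd_bwd(use_bias=False)
assert gb == []
```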
""" | ||
Compute the total loss for a chunk of input and target, while using an alignment/preference loss function. | ||
Args: | ||
preference_loss_fn (callable): Loss function to compute the loss on a chunk of input/target. |
It seems like this class already has a staticmethod for `preference_loss_fn`. Why do we need an extra arg here?
- The abstract `preference_loss_fn` method defined at the class level is just a placeholder that defines the interface; notice it is marked with `@abstractmethod` and raises `NotImplementedError`. It is meant to be implemented by subclasses.
- The `preference_loss_fn` parameter in methods like `forward()` and `_compute_loss()` is the actual function that will be used to compute the loss; it is passed in at runtime.

This is a common pattern in Python where the base class defines an interface (abstract method) but allows for runtime flexibility by accepting the actual implementation as a parameter. This makes the class more flexible, as you can:

- create subclasses that implement a default preference loss function by overriding the abstract method, or
- override that default at runtime by passing in a different loss function as a parameter (see the sketch below).
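A minimal sketch of the pattern (simplified class and method names; not the actual chunked-loss base class):

```python
import math
from abc import ABC, abstractmethod


class PreferenceLossBase(ABC):
    @staticmethod
    @abstractmethod
    def preference_loss_fn(logits):
        """Interface only; subclasses supply a default implementation."""
        raise NotImplementedError

    @classmethod
    def compute_loss(cls, logits, preference_loss_fn=None):
        # Prefer the function passed at runtime; fall back to the
        # subclass's default implementation of the abstract method.
        fn = preference_loss_fn if preference_loss_fn is not None else cls.preference_loss_fn
        return fn(logits)


class SigmoidPreferenceLoss(PreferenceLossBase):
    @staticmethod
    def preference_loss_fn(logits):
        # Default loss: -log(sigmoid(logits)).
        return -math.log(1.0 / (1.0 + math.exp(-logits)))


print(SigmoidPreferenceLoss.compute_loss(0.0))                  # default fn: ~0.693
print(SigmoidPreferenceLoss.compute_loss(0.0, lambda x: 2 * x)) # runtime override: 0.0
```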
Summary
Closes the KTO item of the roadmap: #371
Implements the Kahneman-Tversky Optimization (KTO) loss function.
KTO Loss Function

For a policy $\pi$ compared to a reference policy $\pi_0$, with reward $r(x, y) = \log \frac{\pi(y|x)}{\pi_0(y|x)}$:

When $y$ is chosen (desirable):

$$\mathcal{L}_{\text{KTO}} = \lambda_D \left( 1 - \sigma\big(\beta \, (r(x, y) - z_{\text{ref}})\big) \right)$$

When $y$ is rejected (undesirable):

$$\mathcal{L}_{\text{KTO}} = \lambda_U \left( 1 - \sigma\big(\beta \, (z_{\text{ref}} - r(x, y))\big) \right)$$

where:

- $\sigma$ is the sigmoid function and $\beta$ is a scaling hyperparameter,
- $\lambda_D$ and $\lambda_U$ weight the desirable (chosen) and undesirable (rejected) examples,
- $z_{\text{ref}} = \max\big(0, \mathbb{E}\left[\mathrm{KL}(\pi(y'|x) \,\|\, \pi_0(y'|x))\right]\big)$ is the reference point, estimated over the batch.
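A small numeric sketch of these formulas (illustrative only; the PR computes this inside a fused, chunked kernel):

```python
import torch

beta, lambda_d, lambda_u = 0.1, 1.0, 1.0

# Per-example log-ratios r(x, y) = log(pi(y|x) / pi_0(y|x)) and chosen labels.
log_ratios = torch.tensor([1.2, -0.4, 0.3])
is_chosen = torch.tensor([True, False, True])

# Reference point z_ref: a batch estimate of KL(pi || pi_0), clamped at zero.
z_ref = torch.tensor(0.05).clamp(min=0.0)

chosen_losses = lambda_d * (1 - torch.sigmoid(beta * (log_ratios - z_ref)))
rejected_losses = lambda_u * (1 - torch.sigmoid(beta * (z_ref - log_ratios)))

losses = torch.where(is_chosen, chosen_losses, rejected_losses)
print(losses.mean())  # mean-reduced scalar, per the scaling discussion above
```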
Intuition
KTO loss is inspired by prospect theory from behavioral economics, which models how humans make decisions under uncertainty.
The loss function is asymmetric, treating gains and losses differently, similar to
human decision-making patterns.
Credit: https://www.youtube.com/watch?v=nSrj1J6ODoM&t=422s
Benchmark Result
Special thanks to @shivam15s for the optimization PR #491; without it, my implementation would not achieve the speeds listed below.
Memory:
Speed:
Notable learnings on optimizing the speed:
Key Changes

- Added the `LigerFusedLinearKTOLoss` class
- Added `LigerFusedLinearKTOFunction` for the core KTO computation
- Added `test_kto_loss.py`, which tests against a reference implementation (`HFKTOLoss`) based on Hugging Face's implementation
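Below is a hypothetical usage sketch; the import path, constructor arguments, and forward signature are assumptions modeled on the repo's other chunked losses, not confirmed by this excerpt:

```python
import torch

# Hypothetical wiring; the class name follows this PR, but the exact
# signature is an assumption and may differ from the merged code.
from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss

B, T, H, V = 2, 8, 64, 128
hidden = torch.randn(B * T, H, requires_grad=True)  # flattened hidden states
lm_head_weight = torch.randn(V, H)                  # lm_head projection weight
targets = torch.randint(0, V, (B * T,))
preference_labels = torch.tensor([True, False])     # chosen/rejected per sequence

loss_fn = LigerFusedLinearKTOLoss(beta=0.1)         # beta as in the KTO paper
# loss = loss_fn(lm_head_weight, hidden, targets, preference_labels)
```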
Testing Done

Tests are passing now:

```
pytest test/chunked_loss/test_kto_loss.py
```

- run `make test` to ensure correctness
- run `make checkstyle` to ensure code style
- run `make test-convergence` to ensure convergence