-
Notifications
You must be signed in to change notification settings - Fork 243
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add average_log_prob
args for cpo
#510
Conversation
Signed-off-by: Mecoli1219 <[email protected]>
Signed-off-by: Mecoli1219 <[email protected]>
Signed-off-by: Mecoli1219 <[email protected]>
TRL is using the default as in the official repo for CPO: https://github.com/fe1ixxu/CPO_SIMPO/blob/main/scripts/cpo_trainer.py#L626 |
Co-authored-by: Kashif Rasul <[email protected]>
Signed-off-by: Mecoli1219 <[email protected]>
@@ -139,7 +140,7 @@ def forward(self, x, y): | |||
@pytest.mark.parametrize( | |||
"scalar, dtype, atol, rtol", | |||
[ | |||
(1.0, torch.bfloat16, 5e-3, 5e-3), | |||
(1.0, torch.bfloat16, 5e-2, 5e-2), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the reasoning behind this adjustment?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@kashif and I find that after disabling average_log_prob
for CPO, it will have a higher deviation from HF implementation when the model is large and the data type is bf16
. Since the result is still close within both methods, we increase atol
and rtol
to make this test pass.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as bfloat16 is less accurate for larger numbers, this is needed to make the test pass and is the same as in the other bfloat16 tests
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Then adjusting tol makes sense. ❤️
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you both for making this PR. Hopefully, it unblocks huggingface/trl#2506.
awesome thank you! we would still need a release of liger-kernel for the CI to pass but yes it will hopefully unblock! |
Summary
trl
CPO implementation didn't average the log_probs, while the liger kernel averages it when computing the loss. This will cause a mismatch when integrating them.Testing Done
make test
to ensure correctnessmake checkstyle
to ensure code stylemake test-convergence
to ensure convergence