
Add average_log_prob args for cpo #510

Merged
merged 7 commits into from
Jan 8, 2025

Conversation

Mecoli1219
Contributor

@Mecoli1219 Mecoli1219 commented Jan 3, 2025

Summary

The trl CPO implementation doesn't average the log_probs, while the Liger kernel averages them when computing the loss. This causes a mismatch when integrating the two.
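A minimal sketch of the mismatch being described (not the actual trl or Liger code; `per_sequence_logps` and its arguments are hypothetical names for illustration). The `average_log_prob` flag decides whether the per-sequence log-probability fed into the loss is the raw token sum or the length-normalized average:

```python
import torch

def per_sequence_logps(token_logps, mask, average_log_prob):
    # token_logps: (batch, seq_len) log-probability of each label token
    # mask: (batch, seq_len), 1 for real tokens, 0 for padding
    summed = (token_logps * mask).sum(-1)
    if average_log_prob:
        # length-normalized log-prob (what the kernel computed before this PR)
        return summed / mask.sum(-1)
    # raw token sum (the trl CPO / official CPO repo convention)
    return summed

logps = torch.tensor([[-1.0, -2.0, -3.0, 0.0]])
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])
print(per_sequence_logps(logps, mask, True).item())   # -2.0
print(per_sequence_logps(logps, mask, False).item())  # -6.0
```

The two conventions differ by a factor of the sequence length, so a loss built on one cannot match a reference built on the other.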

Testing Done

  • Updated the unit tests (still investigating why the unit tests fail locally)
  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Signed-off-by: Mecoli1219 <[email protected]>
@kashif
Contributor

kashif commented Jan 3, 2025

TRL uses the same default as the official CPO repo: https://github.com/fe1ixxu/CPO_SIMPO/blob/main/scripts/cpo_trainer.py#L626

@@ -139,7 +140,7 @@ def forward(self, x, y):
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
-        (1.0, torch.bfloat16, 5e-3, 5e-3),
+        (1.0, torch.bfloat16, 5e-2, 5e-2),
@austin362667
Collaborator

austin362667 commented Jan 8, 2025

What's the reasoning behind this adjustment?

@Mecoli1219
Contributor Author

@kashif and I found that after disabling average_log_prob for CPO, the result deviates more from the HF implementation when the model is large and the dtype is bf16. Since the two methods still produce close results, we increased atol and rtol to make this test pass.

Contributor

As bfloat16 is less accurate for larger numbers, this is needed to make the test pass; it matches the tolerances used in the other bfloat16 tests.
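A quick illustration of why larger values need a looser tolerance (a sketch, not code from this PR): bfloat16 keeps only 8 significand bits, so the spacing between representable values grows with magnitude, and the rounding error on a value around 12 already exceeds the old 5e-3 atol:

```python
import torch

# Round-trip a float32 value through bfloat16 and measure the absolute error.
# With ~8 significand bits, the spacing near 12.3 is 2**3 * 2**-7 = 0.0625,
# so the rounding error can be a few hundredths, well above 5e-3.
x = torch.tensor(12.345678, dtype=torch.float32)
err = (x.to(torch.bfloat16).to(torch.float32) - x).abs().item()
print(err)  # larger than 5e-3, smaller than 5e-2
```

Since summed (un-averaged) log-probs scale with sequence length, the quantities being compared are larger in magnitude, and the 5e-2 tolerance absorbs this representation error.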

Collaborator

Then adjusting tol makes sense. ❤️

@austin362667
Collaborator

austin362667 left a comment

Thank you both for making this PR. Hopefully, it unblocks huggingface/trl#2506.

@austin362667 austin362667 merged commit 23e3772 into linkedin:main Jan 8, 2025
2 of 5 checks passed
@kashif
Contributor

kashif commented Jan 8, 2025

Awesome, thank you! We would still need a liger-kernel release for the CI to pass, but yes, this will hopefully unblock it!
