
Commit 323fcef

resolve conflict
ByronHsu committed Dec 23, 2024
1 parent 22428a2 commit 323fcef
Showing 6 changed files with 17 additions and 41 deletions.
11 changes: 4 additions & 7 deletions src/liger_kernel/chunked_loss/cpo_loss.py
@@ -6,9 +6,7 @@
 
 class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
     @staticmethod
-    def preference_loss_fn(
-        chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0
-    ):
+    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1, label_smoothing=0.0):
         """
         Paper: https://arxiv.org/pdf/2401.08417
@@ -32,10 +30,9 @@ def preference_loss_fn(
             label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
         """
         logits = beta * (chosen_logps - rejected_logps)
-        loss = (
-            - F.logsigmoid(logits) * (1 - label_smoothing)
-            - F.logsigmoid(-logits) * label_smoothing
-        ).sum() / (full_target.shape[0] // 2)
+        loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
+            full_target.shape[0] // 2
+        )
 
         return loss
 
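Note: both hunks above are formatting-only; the one-line expression computes the same label-smoothed CPO preference loss as the multi-line version it replaces. As a rough standalone sketch of what that expression evaluates (tensor values and shapes here are illustrative, not taken from this commit):

    import torch
    import torch.nn.functional as F

    # Illustrative per-sequence log-probabilities for two chosen/rejected pairs.
    chosen_logps = torch.tensor([-12.3, -9.8])
    rejected_logps = torch.tensor([-14.1, -11.0])
    full_target = torch.zeros(4, 16)  # stand-in; only full_target.shape[0] // 2 (the pair count) is used
    beta, label_smoothing = 0.1, 0.0

    logits = beta * (chosen_logps - rejected_logps)
    # Label-smoothed sigmoid preference loss, averaged over the number of pairs.
    loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
        full_target.shape[0] // 2
    )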
7 changes: 3 additions & 4 deletions src/liger_kernel/chunked_loss/simpo_loss.py
@@ -38,10 +38,9 @@ def preference_loss_fn(
             label_smoothing (float): Label smoothing factor, will reduce to Equation above when label_smoothing -> 0.
         """
         logits = beta * (chosen_logps - rejected_logps) - gamma
-        loss = (
-            - F.logsigmoid(logits) * (1 - label_smoothing)
-            - F.logsigmoid(-logits) * label_smoothing
-        ).sum() / (full_target.shape[0] // 2)
+        loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
+            full_target.shape[0] // 2
+        )
 
         return loss
 
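Note: the SimPO hunk is the same line-length reflow; relative to the CPO loss, the only difference in the expression is the target reward margin gamma subtracted from the logits. Continuing the illustrative sketch above (gamma value chosen arbitrarily):

    gamma = 0.5
    logits = beta * (chosen_logps - rejected_logps) - gamma
    loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
        full_target.shape[0] // 2
    )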
8 changes: 2 additions & 6 deletions test/chunked_loss/test_cpo_loss.py
@@ -117,9 +117,7 @@ def __init__(
         label_smoothing: float = 0.0,
     ):
         super().__init__()
-        self.lin = torch.nn.Linear(
-            in_features=H, out_features=V, bias=bias, dtype=dtype
-        )
+        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
         self.cpo_loss = LigerFusedLinearCPOLoss(
             ignore_index=ignore_index,
             beta=beta,
@@ -146,9 +144,7 @@ def forward(self, x, y):
     ],
 )
 @pytest.mark.parametrize("bias", [True, False])
-@pytest.mark.parametrize(
-    "ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)]
-)
+@pytest.mark.parametrize("ignore_index, beta, alpha", [(-100, 0.1, 1.0), (42, 0.2, 0.85)])
 @pytest.mark.parametrize("label_smoothing", [0.0, 0.1])
 def test_correctness(
     B,
20 changes: 5 additions & 15 deletions test/chunked_loss/test_dpo_loss.py
@@ -74,12 +74,8 @@ def __init__(
         beta: float = 0.1,
     ):
         super().__init__()
-        self.lin = torch.nn.Linear(
-            in_features=H, out_features=V, bias=bias, dtype=dtype
-        )
-        self.ref_lin = torch.nn.Linear(
-            in_features=H, out_features=V, bias=ref_bias, dtype=dtype
-        )
+        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
+        self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype)
         self.dpo_loss = HFDPOLoss(
             ignore_index=ignore_index,
             beta=beta,
@@ -112,12 +108,8 @@ def __init__(
         beta: float = 0.1,
     ):
         super().__init__()
-        self.lin = torch.nn.Linear(
-            in_features=H, out_features=V, bias=bias, dtype=dtype
-        )
-        self.ref_lin = torch.nn.Linear(
-            in_features=H, out_features=V, bias=ref_bias, dtype=dtype
-        )
+        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
+        self.ref_lin = torch.nn.Linear(in_features=H, out_features=V, bias=ref_bias, dtype=dtype)
         self.dpo_loss = LigerFusedLinearDPOLoss(
             ignore_index=ignore_index,
             beta=beta,
@@ -279,9 +271,7 @@ def test_correctness(
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("ref_bias", [True, False])
 @pytest.mark.parametrize("compute_nll_loss", [True, False])
-def test_correctness_functional(
-    B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, compute_nll_loss
-):
+def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias, ref_bias, compute_nll_loss):
     B = 2 * B
 
     _input = torch.randn(B, T, H, device=device, dtype=dtype) * scalar
8 changes: 2 additions & 6 deletions test/chunked_loss/test_simpo_loss.py
@@ -29,9 +29,7 @@ def __init__(
         gamma: float = 0.5,
     ):
         super().__init__()
-        self.lin = torch.nn.Linear(
-            in_features=H, out_features=V, bias=bias, dtype=dtype
-        )
+        self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype)
         self.simpo_loss = LigerFusedLinearSimPOLoss(
             ignore_index=ignore_index,
             beta=beta,
@@ -59,9 +57,7 @@ def forward(self, x, y):
     ],
 )
 @pytest.mark.parametrize("bias", [True, False])
-@pytest.mark.parametrize(
-    "ignore_index, beta, gamma", [(-100, 0.1, 0.5), (42, 0.2, 0.85)]
-)
+@pytest.mark.parametrize("ignore_index, beta, gamma", [(-100, 0.1, 0.5), (42, 0.2, 0.85)])
 @pytest.mark.parametrize("label_smoothing", [0.0, 0.1])
 def test_correctness(
     B,
4 changes: 1 addition & 3 deletions test/utils.py
@@ -433,9 +433,7 @@ def cross_entropy_loss(logits, labels):
         labels = target
         chosen_nll_loss = torch.tensor(0.0, device=all_logits.device)
         if self.compute_nll_loss:
-            chosen_nll_loss = cross_entropy_loss(
-                all_logits[:len_chosen], labels[:len_chosen]
-            )
+            chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
 
         all_logps = self.get_batch_logps(
             all_logits,
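Note: the collapsed call in test/utils.py is also formatting-only; the NLL term is still computed over the chosen half of the concatenated batch. A small illustrative sketch of that slicing (all sizes made up for the example):

    import torch

    B, T, V = 2, 8, 32                       # illustrative batch, sequence, vocab sizes
    len_chosen = B                           # chosen sequences are stacked before rejected ones
    all_logits = torch.randn(2 * B, T, V)    # logits for chosen + rejected responses
    labels = torch.randint(0, V, (2 * B, T))

    # Only the first half (the chosen responses) feeds the cross-entropy / NLL term.
    nll_logits = all_logits[:len_chosen]
    nll_labels = labels[:len_chosen]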
