Skip to content

Commit

Permalink
Merge branch 'main' into kto_loss
Browse files Browse the repository at this point in the history
  • Loading branch information
hebiao064 authored Jan 15, 2025
2 parents 6d33947 + ba72b8e commit 26f48d0
Show file tree
Hide file tree
Showing 15 changed files with 155 additions and 122 deletions.
23 changes: 8 additions & 15 deletions benchmark/scripts/benchmark_cpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction
from liger_kernel.utils import infer_device

device = infer_device()
Expand All @@ -27,7 +26,8 @@
def bench_memory_fused_linear_cpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO, TorchLMHeadCPO
from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO
from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO

B = input.x
T = input.extra_benchmark_config["T"]
Expand All @@ -36,12 +36,8 @@ def bench_memory_fused_linear_cpo_loss(
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
Expand Down Expand Up @@ -72,7 +68,8 @@ def full():
def bench_speed_fused_linear_cpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO, TorchLMHeadCPO
from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO
from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO

B = input.x
T = input.extra_benchmark_config["T"]
Expand All @@ -82,12 +79,8 @@ def bench_speed_fused_linear_cpo_loss(
provider = input.kernel_provider
mode = input.kernel_operation_mode

torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
Expand Down
8 changes: 4 additions & 4 deletions benchmark/scripts/benchmark_dpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,13 @@
import torch
import triton

from test.chunked_loss.test_dpo_loss import HF_DPO_Loss
from utils import QUANTILES
from utils import SingleBenchmarkRunInput
from utils import SingleBenchmarkRunOutput
from utils import _test_memory
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction
from liger_kernel.utils import infer_device

device = infer_device()
Expand All @@ -21,7 +19,8 @@


def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO, TorchLMHeadDPO
from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO
from test.chunked_loss.test_dpo_loss import TorchLMHeadDPO

B = input.x
T = input.extra_benchmark_config["T"]
Expand Down Expand Up @@ -70,7 +69,8 @@ def full():


def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO, TorchLMHeadDPO
from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO
from test.chunked_loss.test_dpo_loss import TorchLMHeadDPO

B = input.x
T = input.extra_benchmark_config["T"]
Expand Down
33 changes: 14 additions & 19 deletions benchmark/scripts/benchmark_orpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction
from liger_kernel.utils import infer_device

device = infer_device()
Expand All @@ -27,7 +26,8 @@
def bench_memory_fused_linear_orpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO, TorchLMHeadORPO
from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO
from test.chunked_loss.test_orpo_loss import TorchLMHeadORPO

B = input.x
T = input.extra_benchmark_config["T"]
Expand All @@ -36,21 +36,18 @@ def bench_memory_fused_linear_orpo_loss(
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

torch_lm_head_orpo = lambda x, target: TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
liger_lm_head_orpo = lambda x, target: LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
torch_lm_head_orpo = lambda x, target: TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
liger_lm_head_orpo = lambda x, target: LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)

def fwd():
if provider == "liger":
return liger_lm_head_orpo(_input, target)
return liger_lm_head_orpo(_input, target, nll_target)
elif provider == "huggingface":
return torch_lm_head_orpo(_input, target)
return torch_lm_head_orpo(_input, target, nll_target)

def full():
y = fwd()
Expand All @@ -72,7 +69,8 @@ def full():
def bench_speed_fused_linear_orpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO, TorchLMHeadORPO
from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO
from test.chunked_loss.test_orpo_loss import TorchLMHeadORPO

B = input.x
T = input.extra_benchmark_config["T"]
Expand All @@ -82,21 +80,18 @@ def bench_speed_fused_linear_orpo_loss(
provider = input.kernel_provider
mode = input.kernel_operation_mode

torch_lm_head_orpo = lambda x, target: TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
liger_lm_head_orpo = lambda x, target: LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
torch_lm_head_orpo = lambda x, target: TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
liger_lm_head_orpo = lambda x, target: LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device)

def fwd():
if provider == "liger":
return liger_lm_head_orpo(_input, target)
return liger_lm_head_orpo(_input, target, nll_target)
elif provider == "huggingface":
return torch_lm_head_orpo(_input, target)
return torch_lm_head_orpo(_input, target, nll_target)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
Expand Down
23 changes: 8 additions & 15 deletions benchmark/scripts/benchmark_simpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from utils import parse_benchmark_script_args
from utils import run_benchmarks

from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction
from liger_kernel.utils import infer_device

device = infer_device()
Expand All @@ -27,7 +26,8 @@
def bench_memory_fused_linear_simpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO, TorchLMHeadCPO
from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO
from test.chunked_loss.test_simpo_loss import TorchLMHeadCPO

B = input.x
T = input.extra_benchmark_config["T"]
Expand All @@ -36,12 +36,8 @@ def bench_memory_fused_linear_simpo_loss(
dtype = input.extra_benchmark_config["dtype"]
provider = input.kernel_provider

torch_lm_head_simpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
liger_lm_head_simpo = lambda x, target: LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
torch_lm_head_simpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
liger_lm_head_simpo = lambda x, target: LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
Expand Down Expand Up @@ -72,7 +68,8 @@ def full():
def bench_speed_fused_linear_simpo_loss(
input: SingleBenchmarkRunInput,
) -> SingleBenchmarkRunOutput:
from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO, TorchLMHeadCPO
from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO
from test.chunked_loss.test_simpo_loss import TorchLMHeadCPO

B = input.x
T = input.extra_benchmark_config["T"]
Expand All @@ -82,12 +79,8 @@ def bench_speed_fused_linear_simpo_loss(
provider = input.kernel_provider
mode = input.kernel_operation_mode

torch_lm_head_simpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
liger_lm_head_simpo = lambda x, target: LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(
device
)(x, target)[0]
torch_lm_head_simpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]
liger_lm_head_simpo = lambda x, target: LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0]

_input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device)
target = torch.randint(V, (B, T), dtype=torch.long, device=device)
Expand Down
7 changes: 6 additions & 1 deletion src/liger_kernel/chunked_loss/cpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1, labe
loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / (
full_target.shape[0] // 2
)
return loss

chosen_rewards = beta * chosen_logps
rejected_rewards = beta * rejected_logps

return loss, chosen_rewards, rejected_rewards

@staticmethod
def forward(
Expand Down Expand Up @@ -61,6 +65,7 @@ def forward(
beta=beta,
label_smoothing=label_smoothing,
compute_nll_loss=compute_nll_loss,
average_log_prob=False,
compiled=compiled,
)

Expand Down
Loading

0 comments on commit 26f48d0

Please sign in to comment.