Add KTO Loss #475

Open: wants to merge 27 commits from kto_loss into main (the diff below shows changes from 7 of the 27 commits).

Commits (27)
4471ba6  Add KTO Loss (hebiao064, Dec 12, 2024)
d3f565b  Fix Tests (hebiao064, Dec 13, 2024)
ab08ab6  formatting (hebiao064, Dec 13, 2024)
98aa519  Add more docstrings (hebiao064, Dec 13, 2024)
6e67869  Fix tests (hebiao064, Dec 13, 2024)
3a76c76  Add Benchmark Result (hebiao064, Dec 13, 2024)
7114a19  Merge branch 'main' into kto_loss (hebiao064, Dec 13, 2024)
f46a595  Merge branch 'main' into kto_loss (hebiao064, Dec 17, 2024)
1992153  Add KL, Unpair flag, Preference Labels (hebiao064, Dec 18, 2024)
3aebf8d  Reorder Preference Labels and Bias (hebiao064, Dec 19, 2024)
1dd10a7  Fix Tests (hebiao064, Dec 19, 2024)
130e909  Fix all tests (hebiao064, Dec 20, 2024)
c2fab1d  make it random number for chosen (hebiao064, Dec 20, 2024)
fe7eafc  Merge branch 'main' into kto_loss (hebiao064, Dec 20, 2024)
eeaa570  Update benchmark (hebiao064, Dec 20, 2024)
85a582e  speed up kto loss and some refactor (#495) (shivam15s, Dec 21, 2024)
6d50d44  Merge branch 'main' into kto_loss (hebiao064, Dec 21, 2024)
846dc2e  Change sign of loss to align with merged changes (hebiao064, Dec 21, 2024)
33fa548  Add KL into KTO Test (hebiao064, Dec 23, 2024)
f7b29d5  Add KL and Benchmark (hebiao064, Dec 23, 2024)
29e818e  Fix the speed slow down by removing .item() which would incur gpu-cpu… (hebiao064, Dec 23, 2024)
c478e75  Merge branch 'main' into kto_loss (hebiao064, Dec 23, 2024)
06b2350  Fix checkstyle (hebiao064, Dec 23, 2024)
3cf3771  Remove unnecessary change from conflict merge (hebiao064, Dec 23, 2024)
6d33947  Merge branch 'main' into kto_loss (hebiao064, Jan 3, 2025)
26f48d0  Merge branch 'main' into kto_loss (hebiao064, Jan 15, 2025)
71b1773  Fix comments (hebiao064, Jan 15, 2025)
1 change: 1 addition & 0 deletions .gitignore
@@ -5,6 +5,7 @@ site/
.venv/
venv/
.ipynb_checkpoints/
.vscode/

# Misc
.DS_Store
30 changes: 30 additions & 0 deletions benchmark/README.md
@@ -0,0 +1,30 @@
## Benchmarking Liger Kernels

Follow these steps to benchmark and visualize kernel performance:

1. Create a benchmark script
- Add your script under `benchmark/scripts/`
- Name it according to the kernel (e.g., `benchmark_<kernel_name>.py`)
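
Example: skeleton of a new benchmark script (a minimal sketch following the structure of `scripts/benchmark_kto_loss.py`; the `my_kernel` name and the ReLU op it times are placeholders, not an actual kernel)
```python
import torch
import triton
from utils import (
    QUANTILES,
    SingleBenchmarkRunInput,
    SingleBenchmarkRunOutput,
    parse_benchmark_script_args,
    run_benchmarks,
)

from liger_kernel.utils import infer_device

device = infer_device()


def bench_speed_my_kernel(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
    # Build inputs from the swept value (input.x) and the extra config,
    # time the operation, and report the 20th/50th/80th percentile latencies.
    B = input.x
    H = input.extra_benchmark_config["H"]
    x = torch.randn(B, H, device=device)
    ms_50, ms_20, ms_80 = triton.testing.do_bench(
        lambda: torch.nn.functional.relu(x),  # placeholder op standing in for the kernel under test
        rep=100,
        quantiles=QUANTILES,
    )
    return SingleBenchmarkRunOutput(y_20=ms_20, y_50=ms_50, y_80=ms_80)


if __name__ == "__main__":
    args = parse_benchmark_script_args()
    run_benchmarks(
        bench_test_fn=bench_speed_my_kernel,
        kernel_name="my_kernel",  # placeholder; rows in the CSV are keyed on this name
        kernel_operation_modes=["forward"],
        metric_name="speed",
        metric_unit="ms",
        x_name="B",
        x_label="Batch Size (B)",
        x_values=[2**i for i in range(1, 6)],
        kernel_providers=["liger"],
        extra_benchmark_configs=[{"H": 1024}],
        overwrite=args.overwrite,
    )
```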

2. Run the benchmark
- Results will be saved to `benchmark/data/all_benchmark_data.csv`

Example: Benchmarking KTO Loss
```bash
cd benchmark
python scripts/benchmark_kto_loss.py
```

3. Visualize results
- Run `benchmarks_visualizer.py`, specifying the kernel name, metric, and operation mode

Example: Visualizing KTO Loss benchmark results
```bash
python benchmarks_visualizer.py \
--kernel-name kto_loss \
--metric-name memory \
--kernel-operation-mode full
```

4. View results
- Generated plots will be saved in `benchmark/visualizations/`
30 changes: 30 additions & 0 deletions benchmark/data/all_benchmark_data.csv
@@ -715,3 +715,33 @@ fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,2,8645.314453125,8645.314
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,4,12184.330078125,12184.330078125,12184.330078125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,8,19262.361328125,19262.361328125,19262.361328125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
fused_linear_simpo_loss,huggingface,full,memory,MB,B,B,16,33418.42578125,33418.42578125,33418.42578125,"{""T"": 1024, ""H"": 4096, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16""}",NVIDIA A100-SXM4-80GB,2024-11-15 14:30:01,0.4.1
kto_loss,liger,forward,speed,ms,B,Batch Size (B),2,8.2532958984375,8.235372543334961,8.274937629699707,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:29,0.5.1
kto_loss,liger,forward,speed,ms,B,Batch Size (B),4,16.888959884643555,16.879615783691406,16.898893356323242,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:29,0.5.1
kto_loss,liger,forward,speed,ms,B,Batch Size (B),8,32.13854217529297,32.12795639038086,32.149131774902344,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:29,0.5.1
kto_loss,liger,forward,speed,ms,B,Batch Size (B),16,64.81161499023438,64.81161499023438,64.81161499023438,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:29,0.5.1
kto_loss,liger,forward,speed,ms,B,Batch Size (B),32,128.68646240234375,128.68646240234375,128.68646240234375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:29,0.5.1
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),2,7.146656036376953,7.143622398376465,7.152345657348633,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:49,0.5.1
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),4,12.538240432739258,12.521356582641602,12.540371894836426,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:49,0.5.1
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),8,26.29542350769043,25.303590774536133,26.88591957092285,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:49,0.5.1
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),16,49.26508712768555,49.26508712768555,49.26508712768555,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:49,0.5.1
kto_loss,huggingface,forward,speed,ms,B,Batch Size (B),32,98.9525146484375,98.9525146484375,98.9525146484375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:05:49,0.5.1
kto_loss,liger,full,speed,ms,B,Batch Size (B),2,9.005151748657227,8.97766399383545,9.046483039855957,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:10,0.5.1
kto_loss,liger,full,speed,ms,B,Batch Size (B),4,19.108863830566406,19.09713363647461,19.185260772705078,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:10,0.5.1
kto_loss,liger,full,speed,ms,B,Batch Size (B),8,32.80137634277344,32.775360107421875,32.827388763427734,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:10,0.5.1
kto_loss,liger,full,speed,ms,B,Batch Size (B),16,65.46678161621094,65.46678161621094,65.46678161621094,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:10,0.5.1
kto_loss,liger,full,speed,ms,B,Batch Size (B),32,129.91734313964844,129.91734313964844,129.91734313964844,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:10,0.5.1
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),2,16.091487884521484,14.86076831817627,16.23084831237793,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:32,0.5.1
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),4,28.04204750061035,28.03957176208496,28.055641174316406,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:32,0.5.1
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),8,54.70073699951172,54.70073699951172,54.70073699951172,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:32,0.5.1
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),16,108.09929656982422,108.09929656982422,108.09929656982422,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:32,0.5.1
kto_loss,huggingface,full,speed,ms,B,Batch Size (B),32,215.1945343017578,215.1945343017578,215.1945343017578,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:32,0.5.1
kto_loss,liger,full,memory,MB,B,Batch Size (B),2,3037.75390625,3037.75390625,3037.75390625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:59,0.5.1
kto_loss,liger,full,memory,MB,B,Batch Size (B),4,3800.0126953125,3800.0126953125,3800.0126953125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:59,0.5.1
kto_loss,liger,full,memory,MB,B,Batch Size (B),8,4565.28076171875,4565.28076171875,4565.28076171875,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:59,0.5.1
kto_loss,liger,full,memory,MB,B,Batch Size (B),16,4589.31787109375,4589.31787109375,4589.31787109375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:59,0.5.1
kto_loss,liger,full,memory,MB,B,Batch Size (B),32,4637.39208984375,4637.39208984375,4637.39208984375,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:06:59,0.5.1
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),2,4793.7626953125,4793.7626953125,4793.7626953125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:07:18,0.5.1
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),4,6551.2978515625,6551.2978515625,6551.2978515625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:07:18,0.5.1
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),8,10063.3681640625,10063.3681640625,10063.3681640625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:07:18,0.5.1
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),16,17093.5078125,17093.5078125,17093.5078125,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:07:18,0.5.1
kto_loss,huggingface,full,memory,MB,B,Batch Size (B),32,31153.7890625,31153.7890625,31153.7890625,"{""T"": 512, ""H"": 1024, ""V"": 128256, ""mode"": ""forward"", ""dtype"": ""torch.bfloat16"", ""bias"": true, ""beta"": 0.1, ""ignore_index"": 42}",NVIDIA A100-SXM4-80GB,2024-12-13 08:07:18,0.5.1
264 changes: 264 additions & 0 deletions benchmark/scripts/benchmark_kto_loss.py
@@ -0,0 +1,264 @@
import os
import sys

import torch
import triton
from utils import (
QUANTILES,
SingleBenchmarkRunInput,
SingleBenchmarkRunOutput,
_test_memory,
parse_benchmark_script_args,
run_benchmarks,
)

from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss
from liger_kernel.utils import infer_device

device = infer_device()
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../..")))


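# Reference implementation: wraps the HF-style KTO loss from the test suite,
# computed on top of separate policy and reference linear heads.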
class TorchKTOLoss(torch.nn.Module):
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
bias: bool = False,
ref_bias: bool = False,
ignore_index: int = -100,
beta: float = 0.1,
):
from test.chunked_loss.test_kto_loss import HFKTOLoss

super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=bias, dtype=dtype
)
self.ref_lin = torch.nn.Linear(
in_features=H, out_features=V, bias=ref_bias, dtype=dtype
)
self.kto_loss = HFKTOLoss(
ignore_index=ignore_index, beta=beta, use_ref_model=True
).get_batch_loss_metrics

def forward(self, x, ref_x, y):
return self.kto_loss(
self.lin.weight,
x,
y,
self.lin.bias,
ref_x,
self.ref_lin.weight,
self.ref_lin.bias,
)[0]


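# Liger implementation: fused linear + KTO loss (LigerFusedLinearKTOLoss),
# called with the same argument order as the reference wrapper above.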
class LigerKTOLoss(torch.nn.Module):
def __init__(
self,
H: int,
V: int,
dtype: torch.dtype,
bias: bool = False,
ref_bias: bool = False,
ignore_index: int = -100,
beta: float = 0.1,
):
super().__init__()
self.lin = torch.nn.Linear(
in_features=H, out_features=V, bias=bias, dtype=dtype
)
self.ref_lin = torch.nn.Linear(
in_features=H, out_features=V, bias=ref_bias, dtype=dtype
)
self.kto_loss = LigerFusedLinearKTOLoss(
ignore_index=ignore_index, beta=beta, use_ref_model=True
)

def forward(self, x, ref_x, y):
return self.kto_loss(
self.lin.weight,
x,
y,
self.lin.bias,
ref_x,
self.ref_lin.weight,
self.ref_lin.bias,
)[0]


def bench_memory_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
bias = input.extra_benchmark_config["bias"]
beta = input.extra_benchmark_config["beta"]
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider

torch_kto_loss = TorchKTOLoss(
H=H,
V=V,
dtype=dtype,
bias=bias,
ref_bias=bias,
ignore_index=ignore_index,
beta=beta,
).to(device)

liger_kto_loss = LigerKTOLoss(
H=H,
V=V,
dtype=dtype,
bias=bias,
ref_bias=bias,
ignore_index=ignore_index,
beta=beta,
).to(device)

# Input shape: [B, T, H]
_input = torch.randn(B, T, H, device=device, dtype=dtype)
# Target shape: [B, T]
target = torch.randint(V, (B, T), dtype=torch.long, device=device)

# Add ignore_index tokens to simulate padding
num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
target.view(-1)[indices_to_assign] = ignore_index

# Add ref_x with the same shape as _input
ref_input = torch.randn(B, T, H, device=device, dtype=dtype)

def fwd():
if provider == "liger":
return liger_kto_loss(_input, ref_input, target)
elif provider == "huggingface":
return torch_kto_loss(_input, ref_input, target)

def full():
y = fwd()
y.backward()

mem_50, mem_20, mem_80 = _test_memory(full, _iter=10, quantiles=QUANTILES)
return SingleBenchmarkRunOutput(
y_20=mem_20,
y_50=mem_50,
y_80=mem_80,
)


def bench_speed_kto_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput:
B = input.x
T = input.extra_benchmark_config["T"]
H = input.extra_benchmark_config["H"]
V = input.extra_benchmark_config["V"]
dtype = input.extra_benchmark_config["dtype"]
bias = input.extra_benchmark_config["bias"]
beta = input.extra_benchmark_config["beta"]
ignore_index = input.extra_benchmark_config["ignore_index"]
provider = input.kernel_provider
mode = input.kernel_operation_mode

torch_kto_loss = TorchKTOLoss(
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
).to(device)
liger_kto_loss = LigerKTOLoss(
H=H, V=V, dtype=dtype, beta=beta, ignore_index=ignore_index, bias=bias
).to(device)

# Input shape: [B, T, H]
_input = torch.randn(B, T, H, device=device, dtype=dtype)

# Target shape: [B, T]
target = torch.randint(V, (B, T), device=device, dtype=torch.long)

# Add ignore_index tokens
num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item()
indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign]
target.view(-1)[indices_to_assign] = ignore_index

# Add ref_x with the same shape as _input
ref_input = torch.randn(B, T, H, device=device, dtype=dtype)

def fwd():
if provider == "liger":
return liger_kto_loss(_input, ref_input, target)
elif provider == "huggingface":
return torch_kto_loss(_input, ref_input, target)

if mode == "forward":
ms_50, ms_20, ms_80 = triton.testing.do_bench(
fwd,
rep=100,
quantiles=QUANTILES,
)
elif mode == "backward":
y = fwd()
ms_50, ms_20, ms_80 = triton.testing.do_bench(
lambda: y.backward(retain_graph=True),
grad_to_none=[_input],
rep=100,
quantiles=QUANTILES,
)
elif mode == "full":

def full():
y = fwd()
y.backward()

ms_50, ms_20, ms_80 = triton.testing.do_bench(
full,
rep=100,
quantiles=QUANTILES,
)

return SingleBenchmarkRunOutput(
y_20=ms_20,
y_50=ms_50,
y_80=ms_80,
)


if __name__ == "__main__":
args = parse_benchmark_script_args()

common_configs = {
"kernel_name": "kto_loss",
"x_name": "B",
"x_label": "Batch Size (B)",
"x_values": [2**i for i in range(1, 6)],
"kernel_providers": ["liger", "huggingface"],
"extra_benchmark_configs": [
{
"T": 512,
"H": 1024,
"V": 128256,
"mode": "forward",
"dtype": torch.bfloat16,
"bias": True,
"beta": 0.1,
"ignore_index": 42,
}
],
"overwrite": args.overwrite,
}

run_benchmarks(
bench_test_fn=bench_speed_kto_loss,
kernel_operation_modes=["forward", "full"],
metric_name="speed",
metric_unit="ms",
**common_configs
)

run_benchmarks(
bench_test_fn=bench_memory_kto_loss,
kernel_operation_modes=["full"],
metric_name="memory",
metric_unit="MB",
**common_configs
)
2 changes: 1 addition & 1 deletion src/liger_kernel/chunked_loss/README.md
@@ -1,6 +1,6 @@
# Liger FlexChunkLoss: Alignment and Distillation loss

-Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO) and very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases.
+Liger FlexChunkLoss offers a versatile interface, delivering up to 80% memory savings and a 10% throughput boost for post-training loss functions, including alignment (DPO, ORPO, CPO, KTO) and very soon, distillation. Its flexible design supports custom losses, ensuring efficiency gains across diverse use cases.

### User interface

1 change: 1 addition & 0 deletions src/liger_kernel/chunked_loss/__init__.py
@@ -1,4 +1,5 @@
from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOLoss # noqa: F401
from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOLoss # noqa: F401
from liger_kernel.chunked_loss.kto_loss import LigerFusedLinearKTOLoss # noqa: F401
from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOLoss # noqa: F401
from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOLoss # noqa: F401
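
For orientation, a minimal usage sketch of the newly exported `LigerFusedLinearKTOLoss` (modeled on the calling convention in `benchmark/scripts/benchmark_kto_loss.py` above; the shapes, dtype, and hyperparameters here are illustrative assumptions, not part of this change):
```python
import torch

from liger_kernel.chunked_loss import LigerFusedLinearKTOLoss
from liger_kernel.utils import infer_device

device = infer_device()
dtype = torch.bfloat16
B, T, H, V = 2, 512, 1024, 128256  # illustrative sizes only

# Policy and reference linear heads (lm_head-style projections)
lin = torch.nn.Linear(H, V, bias=True, dtype=dtype, device=device)
ref_lin = torch.nn.Linear(H, V, bias=True, dtype=dtype, device=device)

kto_loss = LigerFusedLinearKTOLoss(ignore_index=-100, beta=0.1, use_ref_model=True)

x = torch.randn(B, T, H, dtype=dtype, device=device)     # policy hidden states
ref_x = torch.randn(B, T, H, dtype=dtype, device=device)  # reference-model hidden states
target = torch.randint(V, (B, T), device=device)          # token ids; pad positions set to ignore_index

# Same argument order as the benchmark wrappers: weight, input, target, bias,
# then the reference input and reference head parameters. The first return value is the loss.
loss = kto_loss(lin.weight, x, target, lin.bias, ref_x, ref_lin.weight, ref_lin.bias)[0]
loss.backward()
```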