From 5c5a7b4f67772f2196efd69d0821b28c0e4ac45e Mon Sep 17 00:00:00 2001 From: Peng Date: Sat, 4 Jan 2025 13:35:55 +0800 Subject: [PATCH 1/6] Set z_loss_1d=None when return_z_loss=False in cross_entropy_loss to avoid tl.store fail when triton_interpret=1(for tl.device_print etc.) (#508) For [issue-507](https://github.com/linkedin/Liger-Kernel/issues/507) ## Summary In cross_entropy_loss kernel, `tl.store(loss_ptr, loss)` doesn't work when `return_z_loss=False` and `triton_interpret=1`, because loss_1d is assigned to tensor z_loss_1d, So I set `z_loss_1d = None` in this situation and it works well. ## Testing Done I test it on my code and [this most simplified example](https://github.com/linkedin/Liger-Kernel/issues/507#issuecomment-2566829506), both work well. Hardware Type: T4 GPU - [ ] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --- src/liger_kernel/ops/cross_entropy.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 3fa12c1ef..1c34decf1 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -95,7 +95,8 @@ def liger_cross_entropy_kernel( return loss_ptr += program_id * loss_stride - z_loss_ptr += program_id * loss_stride + if RETURN_Z_LOSS == _TRUE: + z_loss_ptr += program_id * loss_stride if HAS_WEIGHT: weight_y = tl.load(weight_ptr + y).cast(tl.float32) @@ -296,7 +297,7 @@ def cross_entropy_forward( if return_z_loss == _TRUE.value: z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) else: - z_loss_1d = loss_1d # dummy ptr when return_z_loss == False + z_loss_1d = None # set None when return_z_loss == False target_mask = target != ignore_index n_non_ignore = target_mask.sum().item() From f8bb86fa5d351d3965cbea700a9234b984e52ff5 Mon Sep 17 00:00:00 2001 From: Chun-Mao Lai <72752478+Mecoli1219@users.noreply.github.com> Date: Tue, 7 Jan 2025 23:28:19 -0800 Subject: [PATCH 2/6] Add `aux_outputs` for CPO and SimPO (#492) ## Summary In trl [implementation](https://github.com/huggingface/trl/blob/8c49ea39ec2e11ce4c61291ff2ad59edb46522aa/trl/trainer/cpo_trainer.py#L669C1-L672C56), CPO should have 2 extra return values (`chosen_rewards`, `rejected_rewards`), but this is not implemented in Liger-kernel. ## Testing Done - Hardware Type: - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Signed-off-by: Mecoli1219 Co-authored-by: Austin Liu --- src/liger_kernel/chunked_loss/cpo_loss.py | 6 +++++- src/liger_kernel/chunked_loss/simpo_loss.py | 5 ++++- test/chunked_loss/test_cpo_loss.py | 6 +++++- 3 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py index 41ec78a9d..aa7f078ba 100644 --- a/src/liger_kernel/chunked_loss/cpo_loss.py +++ b/src/liger_kernel/chunked_loss/cpo_loss.py @@ -33,7 +33,11 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1, labe loss = (-F.logsigmoid(logits) * (1 - label_smoothing) - F.logsigmoid(-logits) * label_smoothing).sum() / ( full_target.shape[0] // 2 ) - return loss + + chosen_rewards = beta * chosen_logps + rejected_rewards = beta * rejected_logps + + return loss, chosen_rewards, rejected_rewards @staticmethod def forward( diff --git a/src/liger_kernel/chunked_loss/simpo_loss.py b/src/liger_kernel/chunked_loss/simpo_loss.py index 975bcefab..7f7c75bc1 100644 --- a/src/liger_kernel/chunked_loss/simpo_loss.py +++ b/src/liger_kernel/chunked_loss/simpo_loss.py @@ -42,7 +42,10 @@ def preference_loss_fn( full_target.shape[0] // 2 ) - return loss + chosen_rewards = beta * chosen_logps + rejected_rewards = beta * rejected_logps + + return loss, chosen_rewards, rejected_rewards @staticmethod def forward( diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py index 5f5e15ad1..a59e3db19 100644 --- a/test/chunked_loss/test_cpo_loss.py +++ b/test/chunked_loss/test_cpo_loss.py @@ -73,7 +73,11 @@ def alignment_loss( ) else: raise ValueError(f"Unknown loss type: {self.loss_type}. Should be one of ['sigmoid']") - return losses + + chosen_rewards = self.beta * policy_chosen_logps + rejected_rewards = self.beta * policy_rejected_logps + + return losses, chosen_rewards, rejected_rewards class TorchLMHeadCPO(torch.nn.Module): From 23e37726b8571bd829db16add71db9b2ce6b3b97 Mon Sep 17 00:00:00 2001 From: Chun-Mao Lai <72752478+Mecoli1219@users.noreply.github.com> Date: Tue, 7 Jan 2025 23:33:18 -0800 Subject: [PATCH 3/6] Add `average_log_prob` args for cpo (#510) ## Summary `trl` CPO implementation didn't average the log_probs, while the liger kernel averages it when computing the loss. This will cause a mismatch when integrating them. ## Testing Done - [x] Updating unit test (still investigating why unit test fail locally) - Hardware Type: - [ ] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence --------- Signed-off-by: Mecoli1219 Co-authored-by: Kashif Rasul Co-authored-by: Austin Liu --- src/liger_kernel/chunked_loss/cpo_loss.py | 1 + .../chunked_loss/fused_linear_preference.py | 17 ++++++++++++++--- test/chunked_loss/test_cpo_loss.py | 5 +++-- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py index aa7f078ba..5a83ab026 100644 --- a/src/liger_kernel/chunked_loss/cpo_loss.py +++ b/src/liger_kernel/chunked_loss/cpo_loss.py @@ -65,6 +65,7 @@ def forward( beta=beta, label_smoothing=label_smoothing, compute_nll_loss=compute_nll_loss, + average_log_prob=False, compiled=compiled, ) diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index f389050c0..31d95480c 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -32,6 +32,7 @@ def forward( ref_input=None, ref_weight=None, ref_bias=None, + average_log_prob=True, **loss_kwargs, ): """ @@ -61,6 +62,7 @@ def forward( use_ref_model (bool): Whether to use a reference model for the alignment loss. ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). + average_log_prob (bool): Whether to average log probabilities or to sum them over the completion. loss_kwargs (dict): Other possible arguments that a loss function might need """ # TODO: Tune CHUNK_SIZE to fully utilize the GPU @@ -94,6 +96,7 @@ def forward( use_ref_model=use_ref_model, ref_weight=ref_weight, ref_bias=ref_bias, + average_log_prob=average_log_prob, **loss_kwargs, ) @@ -265,6 +268,7 @@ def chunk_forward( bias=None, ignore_index=-100, compute_nll_loss=True, + average_log_prob=True, ): len_chosen_chunk = target_chunk.shape[0] // 2 logits_chunk = input_chunk @ weight.t() @@ -285,10 +289,13 @@ def chunk_forward( label_chunk = torch.where(loss_mask, target_chunk, 0) per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(-1) - average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + if average_log_prob: + log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1) + else: + log_prob = (per_token_logps * loss_mask).sum(-1) - chosen_logps = average_log_prob[:len_chosen_chunk] - rejected_logps = average_log_prob[len_chosen_chunk:] + chosen_logps = log_prob[:len_chosen_chunk] + rejected_logps = log_prob[len_chosen_chunk:] chosen_logits = logits_chunk[:len_chosen_chunk] rejected_logits = logits_chunk[len_chosen_chunk:] @@ -317,6 +324,7 @@ def _compute_loss( ref_input_chunk=None, ref_weight=None, ref_bias=None, + average_log_prob=True, **loss_kwargs, ): """ @@ -335,6 +343,7 @@ def _compute_loss( use_ref_model (bool): Whether to use a reference model for the alignment loss. ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). + average_log_prob (bool): Whether to average log probabilities or the sum. loss_kwargs (dict): Additional arguments for the loss function. """ ( @@ -350,6 +359,7 @@ def _compute_loss( bias=bias, ignore_index=ignore_index, compute_nll_loss=compute_nll_loss, + average_log_prob=average_log_prob, ) chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]) @@ -372,6 +382,7 @@ def _compute_loss( ref_bias, ignore_index=ignore_index, compute_nll_loss=False, # We don't need NLL loss for the reference model + average_log_prob=average_log_prob, ) loss_kwargs["ref_chosen_logps"] = ref_chosen_logps loss_kwargs["ref_rejected_logps"] = ref_rejected_logps diff --git a/test/chunked_loss/test_cpo_loss.py b/test/chunked_loss/test_cpo_loss.py index a59e3db19..c996e57f9 100644 --- a/test/chunked_loss/test_cpo_loss.py +++ b/test/chunked_loss/test_cpo_loss.py @@ -103,9 +103,10 @@ def __init__( label_smoothing=label_smoothing, simpo_gamma=simpo_gamma, ).get_batch_loss_metrics + self.average_log_prob = loss_type == "simpo" def forward(self, x, y): - return self.cpo_loss(self.lin.weight, x, y, self.lin.bias) + return self.cpo_loss(self.lin.weight, x, y, self.lin.bias, average_log_prob=self.average_log_prob) class LigerLMHeadCPO(torch.nn.Module): @@ -143,7 +144,7 @@ def forward(self, x, y): @pytest.mark.parametrize( "scalar, dtype, atol, rtol", [ - (1.0, torch.bfloat16, 5e-3, 5e-3), + (1.0, torch.bfloat16, 5e-2, 5e-2), (1.0, torch.float32, 1e-5, 5e-4), ], ) From 134a13e8f0ef59a98e046a61a85197902b958cea Mon Sep 17 00:00:00 2001 From: Tcc0403 <76503978+Tcc0403@users.noreply.github.com> Date: Wed, 8 Jan 2025 18:21:07 +0800 Subject: [PATCH 4/6] Refactor CrossEntropy and FusedLinearCrossEntropy (#511) ## Summary Remove redundant codes. ## Testing Done - Hardware Type: - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --- src/liger_kernel/ops/cross_entropy.py | 32 +++++-------------- .../ops/fused_linear_cross_entropy.py | 8 ++--- .../transformers/cross_entropy.py | 3 -- 3 files changed, 12 insertions(+), 31 deletions(-) diff --git a/src/liger_kernel/ops/cross_entropy.py b/src/liger_kernel/ops/cross_entropy.py index 1c34decf1..9e4ab69e8 100644 --- a/src/liger_kernel/ops/cross_entropy.py +++ b/src/liger_kernel/ops/cross_entropy.py @@ -20,9 +20,6 @@ else: from triton.language.math import tanh -_TRUE: tl.constexpr = tl.constexpr(1) -_FALSE: tl.constexpr = tl.constexpr(0) - @triton.jit def liger_cross_entropy_kernel( @@ -95,7 +92,7 @@ def liger_cross_entropy_kernel( return loss_ptr += program_id * loss_stride - if RETURN_Z_LOSS == _TRUE: + if RETURN_Z_LOSS: z_loss_ptr += program_id * loss_stride if HAS_WEIGHT: @@ -254,7 +251,7 @@ def liger_cross_entropy_kernel( loss += z_loss tl.store(loss_ptr, loss) - if RETURN_Z_LOSS == _TRUE: + if RETURN_Z_LOSS: tl.store(z_loss_ptr, z_loss) @@ -264,12 +261,6 @@ def liger_cross_entropy_kernel( MAX_FUSED_SIZE = 65536 // 2 # the best size we found by manually tuning -_bool_to_return_z_loss = { - True: _TRUE.value, - False: _FALSE.value, -} - - def cross_entropy_forward( _input, target, @@ -281,11 +272,7 @@ def cross_entropy_forward( softcap, return_z_loss, ): - if not isinstance(return_z_loss, int): - assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}" - return_z_loss = _bool_to_return_z_loss[return_z_loss] - else: - assert return_z_loss in _bool_to_return_z_loss, f"return_z_loss must be True or False. Got: {return_z_loss}" + assert isinstance(return_z_loss, bool), f"return_z_loss must be True or False. Got: {return_z_loss}" BT, V = _input.shape n_rows = BT @@ -294,10 +281,7 @@ def cross_entropy_forward( # unreduced loss loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) - if return_z_loss == _TRUE.value: - z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) - else: - z_loss_1d = None # set None when return_z_loss == False + z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device) if return_z_loss else None target_mask = target != ignore_index n_non_ignore = target_mask.sum().item() @@ -326,7 +310,7 @@ def cross_entropy_forward( X_stride=_input.stride(-2), Y_ptr=target, Y_stride=target.stride(-1), # always 1 - weight_ptr=weight if weight is not None else _input, # dummy if None + weight_ptr=weight, # dummy if None loss_ptr=loss_1d, z_loss_ptr=z_loss_1d, loss_stride=loss_1d.stride(-1), # always 1 @@ -338,7 +322,7 @@ def cross_entropy_forward( lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, - softcap=softcap if softcap is not None else 0.0, + softcap=softcap, RETURN_Z_LOSS=return_z_loss, BLOCK_SIZE=BLOCK_SIZE, HAS_WEIGHT=True if weight is not None else False, @@ -350,10 +334,10 @@ def cross_entropy_forward( if reduction == "none": loss = loss_1d - z_loss = z_loss_1d if return_z_loss == _TRUE.value else None + z_loss = z_loss_1d if return_z_loss else None else: loss = torch.sum(loss_1d) - z_loss = torch.sum(z_loss_1d) if return_z_loss == _TRUE.value else None + z_loss = torch.sum(z_loss_1d) if return_z_loss else None return loss, z_loss, _input diff --git a/src/liger_kernel/ops/fused_linear_cross_entropy.py b/src/liger_kernel/ops/fused_linear_cross_entropy.py index 4df484135..d62ff40bb 100644 --- a/src/liger_kernel/ops/fused_linear_cross_entropy.py +++ b/src/liger_kernel/ops/fused_linear_cross_entropy.py @@ -92,9 +92,9 @@ def fused_linear_cross_entropy_forward( X_stride=logits_chunk.stride(-2), Y_ptr=target_chunk, Y_stride=target_chunk.stride(-1), # always 1 - weight_ptr=ce_weight if ce_weight is not None else _input, # dummy if None + weight_ptr=ce_weight, loss_ptr=loss_1d_slice, - z_loss_ptr=loss_1d_slice, # dummy ptr, not used + z_loss_ptr=None, loss_stride=loss_1d_slice.stride(-1), # always 1 n_cols=V, n_non_ignore=total_n_non_ignore, @@ -104,8 +104,8 @@ def fused_linear_cross_entropy_forward( lse_square_scale=lse_square_scale, label_smoothing=label_smoothing, reduction=reduction, - softcap=softcap if softcap is not None else 0.0, - RETURN_Z_LOSS=0, # False + softcap=softcap, + RETURN_Z_LOSS=False, HAS_WEIGHT=True if ce_weight is not None else False, HAS_SOFTCAPPING=True if softcap is not None else False, BLOCK_SIZE=BLOCK_SIZE, diff --git a/src/liger_kernel/transformers/cross_entropy.py b/src/liger_kernel/transformers/cross_entropy.py index d72fc3b00..ac522c2e0 100644 --- a/src/liger_kernel/transformers/cross_entropy.py +++ b/src/liger_kernel/transformers/cross_entropy.py @@ -20,9 +20,6 @@ def __init__( assert (label_smoothing >= 0) and ( label_smoothing <= 1 ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}" - assert (label_smoothing >= 0) and ( - label_smoothing <= 1 - ), f"label_smoothing must be between 0.0 and 1.0. Got: {label_smoothing}" assert reduction in { "mean", "sum", From 9586a8766d885bbd44fa874b5e04fbe3dfd436c0 Mon Sep 17 00:00:00 2001 From: Kashif Rasul Date: Thu, 9 Jan 2025 03:36:55 +0100 Subject: [PATCH 5/6] [ORPO] add nll_target for orpo nll loss (#503) ## Summary add optional nll_target argument to calculate nll (needed for ORPO nll loss) - Hardware Type: - [x] run `make test` to ensure correctness - [x] run `make checkstyle` to ensure code style - [x] run `make test-convergence` to ensure convergence --- benchmark/scripts/benchmark_orpo_loss.py | 10 ++-- .../chunked_loss/fused_linear_preference.py | 52 ++++++++++++++----- src/liger_kernel/chunked_loss/orpo_loss.py | 7 ++- .../transformers/trainer/orpo_trainer.py | 20 +++++-- test/chunked_loss/test_orpo_loss.py | 18 ++++--- test/utils.py | 8 +-- 6 files changed, 82 insertions(+), 33 deletions(-) diff --git a/benchmark/scripts/benchmark_orpo_loss.py b/benchmark/scripts/benchmark_orpo_loss.py index a94992019..e86129966 100644 --- a/benchmark/scripts/benchmark_orpo_loss.py +++ b/benchmark/scripts/benchmark_orpo_loss.py @@ -45,12 +45,13 @@ def bench_memory_fused_linear_orpo_loss( _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) target = torch.randint(V, (B, T), dtype=torch.long, device=device) + nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device) def fwd(): if provider == "liger": - return liger_lm_head_orpo(_input, target) + return liger_lm_head_orpo(_input, target, nll_target) elif provider == "huggingface": - return torch_lm_head_orpo(_input, target) + return torch_lm_head_orpo(_input, target, nll_target) def full(): y = fwd() @@ -91,12 +92,13 @@ def bench_speed_fused_linear_orpo_loss( _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) target = torch.randint(V, (B, T), dtype=torch.long, device=device) + nll_target = torch.randint(V, (B, T), dtype=torch.long, device=device) def fwd(): if provider == "liger": - return liger_lm_head_orpo(_input, target) + return liger_lm_head_orpo(_input, target, nll_target) elif provider == "huggingface": - return torch_lm_head_orpo(_input, target) + return torch_lm_head_orpo(_input, target, nll_target) if mode == "forward": ms_50, ms_20, ms_80 = triton.testing.do_bench( diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py index 31d95480c..8b3ec5255 100644 --- a/src/liger_kernel/chunked_loss/fused_linear_preference.py +++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py @@ -27,6 +27,7 @@ def forward( alpha=1.0, beta=0.1, compute_nll_loss=True, + nll_target=None, compiled=True, use_ref_model=False, ref_input=None, @@ -58,6 +59,7 @@ def forward( alpha (float): Weight for the NLL loss. beta (float): Weight for the preference loss. compute_nll_loss (bool): Whether to compute NLL loss. + nll_target (torch.Tensor, optional): Target tensor for NLL loss. Shape: (batch_size, seq_len). If not provided the target is used. compiled (bool): Whether to use torch compile for chunk accumulation. use_ref_model (bool): Whether to use a reference model for the alignment loss. ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). @@ -96,11 +98,12 @@ def forward( use_ref_model=use_ref_model, ref_weight=ref_weight, ref_bias=ref_bias, + full_nll_target=nll_target, average_log_prob=average_log_prob, **loss_kwargs, ) - def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk): + def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk): """ Fused forward and backward pass for a chunk of input and target. """ @@ -111,13 +114,18 @@ def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk): target_chunk, bias, ref_input_chunk=ref_input_chunk, + chosen_nll_target_chunk=chosen_nll_target_chunk, ) else: return torch.func.grad_and_value(compute_loss, argnums=(0, 1), has_aux=True)( - input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk + input_chunk, + weight, + target_chunk, + ref_input_chunk=ref_input_chunk, + chosen_nll_target_chunk=chosen_nll_target_chunk, ) - def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None): + def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None, chosen_nll_target_chunk=None): if bias is not None: ( (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), @@ -132,7 +140,7 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None): *aux_outputs, ), ), - ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk) + ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk) grad_bias.add_(chunk_grad_bias) # accumulate bias gradient else: ( @@ -148,7 +156,7 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None): *aux_outputs, ), ), - ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk) + ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk) # Accumulate gradients grad_weight.add_(chunk_grad_weight) @@ -191,6 +199,9 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None): _rejected_input_chunks = torch.chunk(_input[len_chosen:], chunks=chunks, dim=0) _rejected_target_chunks = torch.chunk(target[len_chosen:], chunks=chunks, dim=0) + if nll_target is not None: + _chosen_nll_target_chunks = torch.chunk(nll_target[:len_chosen], chunks=chunks, dim=0) + if use_ref_model: _ref_chosen_input_chunks = torch.chunk(ref_input[:len_chosen], chunks=chunks, dim=0) _ref_rejected_input_chunks = torch.chunk(ref_input[len_chosen:], chunks=chunks, dim=0) @@ -202,6 +213,7 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None): rejected_target_chunk, ref_chosen_input_chunk, ref_rejected_input_chunk, + chosen_nll_target_chunk, ) in zip( _chosen_input_chunks, _rejected_input_chunks, @@ -209,6 +221,7 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None): _rejected_target_chunks, (_ref_chosen_input_chunks if use_ref_model else [None] * len(_chosen_input_chunks)), (_ref_rejected_input_chunks if use_ref_model else [None] * len(_rejected_input_chunks)), + (_chosen_nll_target_chunks if nll_target is not None else [None] * len(_chosen_input_chunks)), strict=False, ): input_chunk = torch.cat([chosen_input_chunk, rejected_input_chunk], dim=0) @@ -222,9 +235,10 @@ def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None): torch._dynamo.mark_dynamic(target_chunk, 1) torch._dynamo.mark_dynamic(target, 1) torch._dynamo.mark_dynamic(ref_input_chunk, 1) if use_ref_model else None + torch._dynamo.mark_dynamic(chosen_nll_target_chunk, 1) if nll_target is not None else None # accumulate loss, gradients, and metrics - accumulate_chunk(input_chunk, target_chunk, ref_input_chunk) + accumulate_chunk(input_chunk, target_chunk, ref_input_chunk, chosen_nll_target_chunk) # combine grad_chosen_inputs and grad_rejected_inputs grad_inputs = grad_chosen_inputs + grad_rejected_inputs @@ -258,7 +272,7 @@ def backward(ctx, *grad_output): grad_weight = grad_weight * grad_output[0][0] grad_bias = grad_bias * grad_output[0][0] if grad_bias is not None else None - return grad_input, grad_weight, None, grad_bias, None, None, None + return grad_input, grad_weight, None, grad_bias, None, None, None, None @staticmethod def chunk_forward( @@ -268,6 +282,7 @@ def chunk_forward( bias=None, ignore_index=-100, compute_nll_loss=True, + chosen_nll_target_chunk=None, average_log_prob=True, ): len_chosen_chunk = target_chunk.shape[0] // 2 @@ -278,9 +293,12 @@ def chunk_forward( chosen_nll_loss = 0.0 if compute_nll_loss: + nll_labels = ( + chosen_nll_target_chunk if chosen_nll_target_chunk is not None else target_chunk[:len_chosen_chunk] + ) chosen_nll_loss = F.nll_loss( log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]), - target_chunk[:len_chosen_chunk].view(-1), + nll_labels.view(-1), reduction="sum", ignore_index=ignore_index, ) @@ -324,6 +342,8 @@ def _compute_loss( ref_input_chunk=None, ref_weight=None, ref_bias=None, + full_nll_target=None, + chosen_nll_target_chunk=None, average_log_prob=True, **loss_kwargs, ): @@ -343,6 +363,8 @@ def _compute_loss( use_ref_model (bool): Whether to use a reference model for the alignment loss. ref_weight (torch.Tensor): Reference weight tensor. Shape: (vocab_size, hidden_size). ref_bias (torch.Tensor, optional): Reference bias tensor. Shape: (vocab_size,). + full_nll_target (torch.Tensor, optional): Full target tensor for NLL loss. Shape: (batch_size, sequence_length). + chosen_nll_target_chunk (torch.Tensor, optional): Target tensor for NLL loss. Shape: (chunk_size, sequence_length) If not provided the target_chunk is used. average_log_prob (bool): Whether to average log probabilities or the sum. loss_kwargs (dict): Additional arguments for the loss function. """ @@ -359,9 +381,14 @@ def _compute_loss( bias=bias, ignore_index=ignore_index, compute_nll_loss=compute_nll_loss, + chosen_nll_target_chunk=chosen_nll_target_chunk, average_log_prob=average_log_prob, ) - chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() + if full_nll_target is not None: + chosen_nll_loss = chosen_nll_loss / (full_nll_target[: full_nll_target.shape[0] // 2] != ignore_index).sum() + else: + chosen_nll_loss = chosen_nll_loss / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() + chosen_logits_mean = chosen_logits.sum() / (full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]) rejected_logits_mean = rejected_logits.sum() / ( full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] @@ -372,9 +399,9 @@ def _compute_loss( ( ref_chosen_logps, ref_rejected_logps, - ref_chosen_logits, - ref_rejected_logits, - ref_chosen_nll_loss, + _, + _, + _, ) = LigerFusedLinearPreferenceBase.chunk_forward( ref_input_chunk, ref_weight, @@ -382,6 +409,7 @@ def _compute_loss( ref_bias, ignore_index=ignore_index, compute_nll_loss=False, # We don't need NLL loss for the reference model + chosen_nll_target_chunk=None, average_log_prob=average_log_prob, ) loss_kwargs["ref_chosen_logps"] = ref_chosen_logps diff --git a/src/liger_kernel/chunked_loss/orpo_loss.py b/src/liger_kernel/chunked_loss/orpo_loss.py index dfed5d3a7..9bd553716 100644 --- a/src/liger_kernel/chunked_loss/orpo_loss.py +++ b/src/liger_kernel/chunked_loss/orpo_loss.py @@ -52,6 +52,7 @@ def forward( ignore_index=-100, beta=0.1, compute_nll_loss=True, + nll_target=None, compiled=True, ): return LigerFusedLinearPreferenceBase.forward( @@ -64,13 +65,14 @@ def forward( ignore_index=ignore_index, beta=beta, compute_nll_loss=compute_nll_loss, + nll_target=nll_target, compiled=compiled, ) @staticmethod def backward(ctx, *grad_output): grads = LigerFusedLinearPreferenceBase.backward(ctx, grad_output)[:4] - return *grads, None, None, None, None + return *grads, None, None, None, None, None class LigerFusedLinearORPOLoss(torch.nn.Module): @@ -96,7 +98,7 @@ def __init__( self.compute_nll_loss = compute_nll_loss self.compiled = compiled - def forward(self, lin_weight, _input, target, bias=None): + def forward(self, lin_weight, _input, target, bias=None, nll_target=None): return LigerFusedLinearORPOFunction.apply( _input, lin_weight, @@ -105,5 +107,6 @@ def forward(self, lin_weight, _input, target, bias=None): self.ignore_index, self.beta, self.compute_nll_loss, + nll_target, self.compiled, ) diff --git a/src/liger_kernel/transformers/trainer/orpo_trainer.py b/src/liger_kernel/transformers/trainer/orpo_trainer.py index ca54733d0..2a4dc377c 100644 --- a/src/liger_kernel/transformers/trainer/orpo_trainer.py +++ b/src/liger_kernel/transformers/trainer/orpo_trainer.py @@ -93,6 +93,13 @@ def concatenated_forward( if self.aux_loss_enabled: model_kwargs["output_router_logits"] = True + if self.is_encoder_decoder: + labels = concatenated_batch["concatenated_labels"].clone() + else: + labels = concatenated_batch["concatenated_input_ids"].clone() + attention_mask = concatenated_batch["concatenated_attention_mask"] + labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id) + if isinstance(model, FullyShardedDataParallel): outputs = _FSDPForwardRedirection()( model, @@ -114,15 +121,20 @@ def concatenated_forward( orpo_loss_fn = LigerFusedLinearORPOLoss(ignore_index=self.label_pad_token_id, beta=self.beta) - def orpo_partial(lm_head, last_hidden_state, concatenated_labels): - return orpo_loss_fn(lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias) + def orpo_partial(lm_head, last_hidden_state, concatenated_labels, nll_target): + return orpo_loss_fn( + lm_head.weight, last_hidden_state, concatenated_labels, lm_head.bias, nll_target=nll_target + ) orpo_loss, aux_outputs = _FSDPForwardRedirection()( model, orpo_partial, model.lm_head, - outputs.last_hidden_state, - concatenated_batch["concatenated_labels"], + outputs.last_hidden_state[:, :-1] if not self.is_encoder_decoder else outputs.last_hidden_state, + concatenated_batch["concatenated_labels"][:, 1:] + if not self.is_encoder_decoder + else concatenated_batch["concatenated_labels"], + labels[:, 1:] if not self.is_encoder_decoder else labels, ) # if aux_loss_enabled, add the aux_loss to the orpo_loss if self.aux_loss_enabled: diff --git a/test/chunked_loss/test_orpo_loss.py b/test/chunked_loss/test_orpo_loss.py index 529b7dff7..3bbd245a1 100644 --- a/test/chunked_loss/test_orpo_loss.py +++ b/test/chunked_loss/test_orpo_loss.py @@ -86,8 +86,8 @@ def __init__( self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) self.orpo_loss = HFORPOLoss(ignore_index=ignore_index, beta=beta).get_batch_loss_metrics - def forward(self, x, y): - return self.orpo_loss(self.lin.weight, x, y, self.lin.bias) + def forward(self, x, y, nll_target=None): + return self.orpo_loss(self.lin.weight, x, y, self.lin.bias, nll_target=nll_target) class LigerLMHeadORPO(torch.nn.Module): @@ -104,8 +104,8 @@ def __init__( self.lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype) self.orpo_loss = LigerFusedLinearORPOLoss(ignore_index=ignore_index, beta=beta) - def forward(self, x, y): - return self.orpo_loss(self.lin.weight, x, y, self.lin.bias) + def forward(self, x, y, nll_target=None): + return self.orpo_loss(self.lin.weight, x, y, self.lin.bias, nll_target=nll_target) @pytest.mark.parametrize( @@ -164,13 +164,15 @@ def test_correctness(B, T, H, V, scalar, dtype, atol, rtol, bias, ignore_index, device=device, dtype=torch.long, ) + nll_target = torch.randint(0, V, (B, T), device=device, dtype=torch.long) + # Assign some random number of elements as ignore_index num_elements_to_assign = torch.randint(1, B * T // 2, (1,)).item() indices_to_assign = torch.randperm(B * T)[:num_elements_to_assign] target.view(-1)[indices_to_assign] = ignore_index - loss1, aggregated_aux_outputs1 = torch_lm_head_orpo(input1, target) - loss2, aggregated_aux_outputs2 = liger_lm_head_orpo(input2, target) + loss1, aggregated_aux_outputs1 = torch_lm_head_orpo(input1, target, nll_target) + loss2, aggregated_aux_outputs2 = liger_lm_head_orpo(input2, target, nll_target) assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) @@ -244,8 +246,8 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, atol, rtol, bias): bias1 = _bias.detach().clone().requires_grad_(True) if bias else None bias2 = _bias.detach().clone().requires_grad_(True) if bias else None - loss1, aggregated_aux_outputs1 = LigerFusedLinearORPOFunction.apply(input1, weight1, target, bias1) - loss2, aggregated_aux_outputs2 = liger_fused_linear_orpo(input2, weight2, target, bias2) + loss1, _ = LigerFusedLinearORPOFunction.apply(input1, weight1, target, bias1) + loss2, _ = liger_fused_linear_orpo(input2, weight2, target, bias2) assert_verbose_allclose(loss1, loss2, atol=atol, rtol=rtol) diff --git a/test/utils.py b/test/utils.py index ec4abd5a8..283d04cb5 100644 --- a/test/utils.py +++ b/test/utils.py @@ -406,8 +406,9 @@ def concatenated_forward( _input: torch.FloatTensor, weight: torch.FloatTensor, target: torch.LongTensor, - bias: torch.FloatTensor = None, + bias: torch.FloatTensor | None = None, average_log_prob: bool = True, + nll_target: torch.LongTensor | None = None, ) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]: """Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together. @@ -430,7 +431,7 @@ def cross_entropy_loss(logits, labels): loss = loss_fct(logits, labels) return loss - labels = target + labels = nll_target if nll_target is not None else target chosen_nll_loss = torch.tensor(0.0, device=all_logits.device) if self.compute_nll_loss: chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen]) @@ -465,10 +466,11 @@ def get_batch_loss_metrics( ref_weight: torch.FloatTensor = None, ref_bias: torch.FloatTensor = None, average_log_prob: bool = True, + nll_target: torch.LongTensor = None, ): """Compute the loss metrics for the given batch of inputs for train or test.""" - forward_output = self.concatenated_forward(_input, weight, target, bias, average_log_prob) + forward_output = self.concatenated_forward(_input, weight, target, bias, average_log_prob, nll_target) ( policy_chosen_logps, policy_rejected_logps, From ba72b8e249bd04e798aa85e5d6f072feabed23c1 Mon Sep 17 00:00:00 2001 From: Austin Liu Date: Fri, 10 Jan 2025 18:29:07 +0800 Subject: [PATCH 6/6] Format Benchmark Scripts with Ruff (#516) ## Summary Format benchmark scripts with (newly migrated) Ruff after https://github.com/linkedin/Liger-Kernel/pull/483. ## Testing Done - Hardware Type: - [ ] run `make test` to ensure correctness - [X] run `make checkstyle` to ensure code style - [ ] run `make test-convergence` to ensure convergence Signed-off-by: Austin Liu --- benchmark/scripts/benchmark_cpo_loss.py | 23 ++++++++--------------- benchmark/scripts/benchmark_dpo_loss.py | 8 ++++---- benchmark/scripts/benchmark_orpo_loss.py | 23 ++++++++--------------- benchmark/scripts/benchmark_simpo_loss.py | 23 ++++++++--------------- 4 files changed, 28 insertions(+), 49 deletions(-) diff --git a/benchmark/scripts/benchmark_cpo_loss.py b/benchmark/scripts/benchmark_cpo_loss.py index 54aa212cd..8d8ccea82 100644 --- a/benchmark/scripts/benchmark_cpo_loss.py +++ b/benchmark/scripts/benchmark_cpo_loss.py @@ -11,7 +11,6 @@ from utils import parse_benchmark_script_args from utils import run_benchmarks -from liger_kernel.chunked_loss.cpo_loss import LigerFusedLinearCPOFunction from liger_kernel.utils import infer_device device = infer_device() @@ -27,7 +26,8 @@ def bench_memory_fused_linear_cpo_loss( input: SingleBenchmarkRunInput, ) -> SingleBenchmarkRunOutput: - from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO, TorchLMHeadCPO + from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO + from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO B = input.x T = input.extra_benchmark_config["T"] @@ -36,12 +36,8 @@ def bench_memory_fused_linear_cpo_loss( dtype = input.extra_benchmark_config["dtype"] provider = input.kernel_provider - torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to( - device - )(x, target)[0] - liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to( - device - )(x, target)[0] + torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0] + liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0] _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) target = torch.randint(V, (B, T), dtype=torch.long, device=device) @@ -72,7 +68,8 @@ def full(): def bench_speed_fused_linear_cpo_loss( input: SingleBenchmarkRunInput, ) -> SingleBenchmarkRunOutput: - from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO, TorchLMHeadCPO + from test.chunked_loss.test_cpo_loss import LigerLMHeadCPO + from test.chunked_loss.test_cpo_loss import TorchLMHeadCPO B = input.x T = input.extra_benchmark_config["T"] @@ -82,12 +79,8 @@ def bench_speed_fused_linear_cpo_loss( provider = input.kernel_provider mode = input.kernel_operation_mode - torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to( - device - )(x, target)[0] - liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to( - device - )(x, target)[0] + torch_lm_head_cpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0] + liger_lm_head_cpo = lambda x, target: LigerLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0] _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) target = torch.randint(V, (B, T), dtype=torch.long, device=device) diff --git a/benchmark/scripts/benchmark_dpo_loss.py b/benchmark/scripts/benchmark_dpo_loss.py index 5078f6194..5eb83c3d3 100644 --- a/benchmark/scripts/benchmark_dpo_loss.py +++ b/benchmark/scripts/benchmark_dpo_loss.py @@ -4,7 +4,6 @@ import torch import triton -from test.chunked_loss.test_dpo_loss import HF_DPO_Loss from utils import QUANTILES from utils import SingleBenchmarkRunInput from utils import SingleBenchmarkRunOutput @@ -12,7 +11,6 @@ from utils import parse_benchmark_script_args from utils import run_benchmarks -from liger_kernel.chunked_loss.dpo_loss import LigerFusedLinearDPOFunction from liger_kernel.utils import infer_device device = infer_device() @@ -21,7 +19,8 @@ def bench_memory_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO, TorchLMHeadDPO + from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO + from test.chunked_loss.test_dpo_loss import TorchLMHeadDPO B = input.x T = input.extra_benchmark_config["T"] @@ -70,7 +69,8 @@ def full(): def bench_speed_dpo_loss(input: SingleBenchmarkRunInput) -> SingleBenchmarkRunOutput: - from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO, TorchLMHeadDPO + from test.chunked_loss.test_dpo_loss import LigerLMHeadDPO + from test.chunked_loss.test_dpo_loss import TorchLMHeadDPO B = input.x T = input.extra_benchmark_config["T"] diff --git a/benchmark/scripts/benchmark_orpo_loss.py b/benchmark/scripts/benchmark_orpo_loss.py index e86129966..65202ab98 100644 --- a/benchmark/scripts/benchmark_orpo_loss.py +++ b/benchmark/scripts/benchmark_orpo_loss.py @@ -11,7 +11,6 @@ from utils import parse_benchmark_script_args from utils import run_benchmarks -from liger_kernel.chunked_loss.orpo_loss import LigerFusedLinearORPOFunction from liger_kernel.utils import infer_device device = infer_device() @@ -27,7 +26,8 @@ def bench_memory_fused_linear_orpo_loss( input: SingleBenchmarkRunInput, ) -> SingleBenchmarkRunOutput: - from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO, TorchLMHeadORPO + from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO + from test.chunked_loss.test_orpo_loss import TorchLMHeadORPO B = input.x T = input.extra_benchmark_config["T"] @@ -36,12 +36,8 @@ def bench_memory_fused_linear_orpo_loss( dtype = input.extra_benchmark_config["dtype"] provider = input.kernel_provider - torch_lm_head_orpo = lambda x, target: TorchLMHeadORPO(H=H, V=V, dtype=dtype).to( - device - )(x, target)[0] - liger_lm_head_orpo = lambda x, target: LigerLMHeadORPO(H=H, V=V, dtype=dtype).to( - device - )(x, target)[0] + torch_lm_head_orpo = lambda x, target: TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0] + liger_lm_head_orpo = lambda x, target: LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0] _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) target = torch.randint(V, (B, T), dtype=torch.long, device=device) @@ -73,7 +69,8 @@ def full(): def bench_speed_fused_linear_orpo_loss( input: SingleBenchmarkRunInput, ) -> SingleBenchmarkRunOutput: - from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO, TorchLMHeadORPO + from test.chunked_loss.test_orpo_loss import LigerLMHeadORPO + from test.chunked_loss.test_orpo_loss import TorchLMHeadORPO B = input.x T = input.extra_benchmark_config["T"] @@ -83,12 +80,8 @@ def bench_speed_fused_linear_orpo_loss( provider = input.kernel_provider mode = input.kernel_operation_mode - torch_lm_head_orpo = lambda x, target: TorchLMHeadORPO(H=H, V=V, dtype=dtype).to( - device - )(x, target)[0] - liger_lm_head_orpo = lambda x, target: LigerLMHeadORPO(H=H, V=V, dtype=dtype).to( - device - )(x, target)[0] + torch_lm_head_orpo = lambda x, target: TorchLMHeadORPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0] + liger_lm_head_orpo = lambda x, target: LigerLMHeadORPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0] _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) target = torch.randint(V, (B, T), dtype=torch.long, device=device) diff --git a/benchmark/scripts/benchmark_simpo_loss.py b/benchmark/scripts/benchmark_simpo_loss.py index 13eaaed33..99daa8e52 100644 --- a/benchmark/scripts/benchmark_simpo_loss.py +++ b/benchmark/scripts/benchmark_simpo_loss.py @@ -11,7 +11,6 @@ from utils import parse_benchmark_script_args from utils import run_benchmarks -from liger_kernel.chunked_loss.simpo_loss import LigerFusedLinearSimPOFunction from liger_kernel.utils import infer_device device = infer_device() @@ -27,7 +26,8 @@ def bench_memory_fused_linear_simpo_loss( input: SingleBenchmarkRunInput, ) -> SingleBenchmarkRunOutput: - from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO, TorchLMHeadCPO + from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO + from test.chunked_loss.test_simpo_loss import TorchLMHeadCPO B = input.x T = input.extra_benchmark_config["T"] @@ -36,12 +36,8 @@ def bench_memory_fused_linear_simpo_loss( dtype = input.extra_benchmark_config["dtype"] provider = input.kernel_provider - torch_lm_head_simpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to( - device - )(x, target)[0] - liger_lm_head_simpo = lambda x, target: LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to( - device - )(x, target)[0] + torch_lm_head_simpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0] + liger_lm_head_simpo = lambda x, target: LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0] _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) target = torch.randint(V, (B, T), dtype=torch.long, device=device) @@ -72,7 +68,8 @@ def full(): def bench_speed_fused_linear_simpo_loss( input: SingleBenchmarkRunInput, ) -> SingleBenchmarkRunOutput: - from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO, TorchLMHeadCPO + from test.chunked_loss.test_simpo_loss import LigerLMHeadSimPO + from test.chunked_loss.test_simpo_loss import TorchLMHeadCPO B = input.x T = input.extra_benchmark_config["T"] @@ -82,12 +79,8 @@ def bench_speed_fused_linear_simpo_loss( provider = input.kernel_provider mode = input.kernel_operation_mode - torch_lm_head_simpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to( - device - )(x, target)[0] - liger_lm_head_simpo = lambda x, target: LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to( - device - )(x, target)[0] + torch_lm_head_simpo = lambda x, target: TorchLMHeadCPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0] + liger_lm_head_simpo = lambda x, target: LigerLMHeadSimPO(H=H, V=V, dtype=dtype).to(device)(x, target)[0] _input = torch.randn(B, T, H, requires_grad=True, dtype=dtype, device=device) target = torch.randint(V, (B, T), dtype=torch.long, device=device)