Cross Entropy Loss performance issue #5509

Open
wa008 opened this issue Dec 30, 2024 · 0 comments

Describe the issue

I implemented cross-entropy using Triton, but the performance is disappointingly low. Even after removing most of the code in the loss_kernel (producing incorrect results), the performance remains significantly worse than PyTorch. Are there common pitfalls I might have encountered, or does PyTorch apply specific optimizations that I might be missing?

[benchmark chart: runtime comparison of the four implementations below]
  • Torch_lib: torch.nn.CrossEntropyLoss
  • torch_loss_realize1: CrossEntropyLoss implemented in plain PyTorch
  • torch_loss_realize2: optimized CrossEntropyLoss implemented in plain PyTorch
  • Triton: optimized CrossEntropyLoss implemented in Triton
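For reference, the numerically stable per-row cross-entropy that all of these variants compute is loss_i = logsumexp(x_i) − x_i[target_i]. A minimal pure-Python sketch (illustrative only, not taken from the Colab notebook):

```python
import math

def cross_entropy_rows(logits, targets):
    """Numerically stable cross-entropy per row: logsumexp(x) - x[target]."""
    losses = []
    for row, t in zip(logits, targets):
        m = max(row)  # subtract the row max before exponentiating for stability
        lse = math.log(sum(math.exp(v - m) for v in row)) + m
        losses.append(lse - row[t])
    return losses
```

This is the quantity the kernels below produce per row (reduction, e.g. the mean, is applied afterwards).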

loss_kernel:

@triton.jit
def loss_kernel(input_ptr,
                target_ptr,
                output_ptr,
                M, N,
                BLOCK_SIZE_M: tl.constexpr,
                BLOCK_SIZE_N: tl.constexpr,
               ):
    # Each program handles BLOCK_SIZE_M rows, streaming over the N columns
    # in chunks of BLOCK_SIZE_N with an online (running-max) softmax.
    pid_m = tl.program_id(axis=0)
    offsets_rows = BLOCK_SIZE_M * pid_m + tl.arange(0, BLOCK_SIZE_M)
    offsets_cols = tl.arange(0, BLOCK_SIZE_N)

    target = tl.load(target_ptr + offsets_rows, mask=offsets_rows < M, other=0)
    max_val = tl.full(target.shape, -float("inf"), dtype=tl.float32)
    sumexp = tl.zeros_like(max_val)

    allcurx = tl.zeros_like(max_val)
    for index in tl.static_range(0, N, BLOCK_SIZE_N):
        offsets_input = offsets_rows[:, None] * N + (offsets_cols + index)[None, :]
        mask_input = (offsets_rows[:, None] < M) & ((offsets_cols + index)[None, :] < N)
        input_val = tl.load(input_ptr + offsets_input, mask=mask_input, other=-float("inf"))

        if index == 0:
            new_max_val = input_val.max(axis=1)
        else:
            new_max_val = tl.maximum(max_val, input_val.max(axis=1))

        # Rescale the running sum of exponentials to the new maximum,
        # then accumulate the current chunk.
        sumexp /= tl.exp(new_max_val)
        sumexp *= tl.exp(max_val)
        sumexp += tl.sum(tl.exp(input_val - new_max_val[:, None]), axis=1)
        max_val = new_max_val

        # Pick up the logit at the target index in the chunk that contains it.
        # The row-bounds check is needed here too, to avoid out-of-range loads.
        curx = tl.load(input_ptr + offsets_rows * N + target,
                       mask=(offsets_rows < M) & (target >= index) & (target < index + BLOCK_SIZE_N),
                       other=0.0)
        allcurx += curx
    # loss_i = logsumexp(x_i) - x_i[target_i]
    output = tl.log(sumexp) + max_val - allcurx
    tl.store(output_ptr + offsets_rows, output, mask=offsets_rows < M)
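The loop above is a streaming (online) log-sum-exp: whenever a chunk raises the running maximum, the accumulated sum of exponentials is rescaled before the chunk is added. A plain-Python sketch of that recurrence (block size is an arbitrary illustrative choice) shows it matches the direct computation:

```python
import math

def streaming_logsumexp(xs, block=4):
    """Chunked logsumexp, mirroring the kernel's rescale-then-accumulate loop."""
    max_val = -math.inf
    sumexp = 0.0
    for i in range(0, len(xs), block):
        chunk = xs[i:i + block]
        new_max = max(max_val, max(chunk))
        # Kernel does this as two steps (/= exp(new_max), *= exp(max_val));
        # combined here into one rescale to the new maximum.
        sumexp *= math.exp(max_val - new_max)
        sumexp += sum(math.exp(v - new_max) for v in chunk)
        max_val = new_max
    return math.log(sumexp) + max_val
```

The result is independent of the block size, which is what lets the kernel process N columns in BLOCK_SIZE_N-wide chunks without materializing the whole row.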

def triton_loss(input: torch.Tensor, target: torch.Tensor):
    M, N = input.shape
    output = torch.empty(target.shape, dtype=torch.float32, device=DEVICE)
    grid_loss = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), )
    # BLOCK_SIZE_M / BLOCK_SIZE_N are not passed here; presumably they are
    # supplied via @triton.autotune in the full code linked below.
    loss_kernel[grid_loss](input, target, output, M, N)
    return output

full code: https://colab.research.google.com/drive/1-gjk14MHtiP5Dc36mvqNWpFyONOmn54j?usp=sharing

Environment details

Triton: 3.1.0
GPU: Tesla T4
Platform: Google colab
