Describe the issue
I implemented cross-entropy using Triton, but the performance is disappointingly low. Even after removing most of the code in the loss_kernel (producing incorrect results), the performance remains significantly worse than PyTorch. Are there common pitfalls I might have encountered, or does PyTorch apply specific optimizations that I might be missing?
The benchmark compares four variants:

- `Torch_lib`: torch.nn.CrossEntropyLoss
- `torch_loss_realize1`: CrossEntropyLoss implemented in PyTorch
- `torch_loss_realize2`: an optimized CrossEntropyLoss implemented in PyTorch
- `Triton`: an optimized CrossEntropyLoss implemented in Triton
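For context, the unoptimized PyTorch baseline (`torch_loss_realize1`) presumably computes the per-row loss via a stable log-sum-exp; a minimal sketch of such an implementation (the exact code is in the linked Colab and may differ):

```python
import torch

def torch_loss_realize(input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Numerically stable per-row cross-entropy:
    # loss_i = logsumexp_j(x_ij) - x_{i, target_i}
    max_val = input.max(dim=1, keepdim=True).values
    logsumexp = (input - max_val).exp().sum(dim=1).log() + max_val.squeeze(1)
    return logsumexp - input.gather(1, target.unsqueeze(1)).squeeze(1)
```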
loss_kernel:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def loss_kernel(input_ptr,
                target_ptr,
                output_ptr,
                M, N,
                BLOCK_SIZE_M: tl.constexpr,
                BLOCK_SIZE_N: tl.constexpr,
                ):
    # Each program handles BLOCK_SIZE_M rows, streaming over the N columns
    # in chunks of BLOCK_SIZE_N while keeping a running (online) max and
    # sum of exponentials for a stable log-sum-exp.
    pid_m = tl.program_id(axis=0)
    offsets_rows = BLOCK_SIZE_M * pid_m + tl.arange(0, BLOCK_SIZE_M)
    offsets_cols = tl.arange(0, BLOCK_SIZE_N)
    mask_rows = offsets_rows < M
    target = tl.load(target_ptr + offsets_rows, mask=mask_rows, other=0)
    max_val = tl.full(target.shape, -float("inf"), dtype=tl.float32)
    sumexp = tl.zeros_like(max_val)
    allcurx = tl.zeros_like(max_val)
    # N is a runtime value, so use a regular loop; tl.static_range needs
    # compile-time bounds (and would fully unroll the loop).
    for index in range(0, N, BLOCK_SIZE_N):
        offsets_input = offsets_rows[:, None] * N + (offsets_cols + index)[None, :]
        mask_input = mask_rows[:, None] & ((offsets_cols + index)[None, :] < N)
        input_val = tl.load(input_ptr + offsets_input, mask=mask_input,
                            other=-float("inf"))
        # maximum(-inf, x) == x, so the first iteration needs no special case.
        new_max_val = tl.maximum(max_val, tl.max(input_val, axis=1))
        # Rescale the running sum in one step; dividing by exp(new_max_val)
        # as two separate ops can underflow to zero and produce inf.
        sumexp = (sumexp * tl.exp(max_val - new_max_val)
                  + tl.sum(tl.exp(input_val - new_max_val[:, None]), axis=1))
        max_val = new_max_val
        # Pick up the target logit when its column falls in this chunk;
        # out-of-range rows must be masked here as well.
        curx = tl.load(input_ptr + offsets_rows * N + target,
                       mask=mask_rows & (target >= index)
                            & (target < index + BLOCK_SIZE_N),
                       other=0.0)
        allcurx += curx
    # loss_i = log(sumexp_i) + max_i - x_{i, target_i}
    output = tl.log(sumexp) + max_val - allcurx
    tl.store(output_ptr + offsets_rows, output, mask=mask_rows)


def triton_loss(input: torch.Tensor, target: torch.Tensor):
    M, N = input.shape
    # Allocate on the same device as the inputs (the Colab uses a global DEVICE).
    output = torch.empty(target.shape, dtype=torch.float32, device=input.device)
    grid_loss = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']),)
    # Block sizes here are example values; the Colab version presumably
    # selects them via autotuning.
    loss_kernel[grid_loss](input, target, output, M, N,
                           BLOCK_SIZE_M=16, BLOCK_SIZE_N=128)
    return output
```
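For timing, a sketch using `triton.testing.do_bench`, which warms up and synchronizes properly (timing asynchronous kernel launches with wall-clock time is a common pitfall). The shapes and tolerances below are arbitrary examples, not the ones from the Colab:

```python
import torch
import triton

def benchmark(M=4096, N=32000):
    input = torch.randn(M, N, device="cuda", dtype=torch.float32)
    target = torch.randint(0, N, (M,), device="cuda")

    # Sanity-check correctness against the PyTorch reference before timing.
    ref = torch.nn.functional.cross_entropy(input, target, reduction="none")
    torch.testing.assert_close(triton_loss(input, target), ref,
                               rtol=1e-4, atol=1e-4)

    t_triton = triton.testing.do_bench(lambda: triton_loss(input, target))
    t_torch = triton.testing.do_bench(
        lambda: torch.nn.functional.cross_entropy(input, target,
                                                  reduction="none"))
    print(f"triton: {t_triton:.3f} ms, torch: {t_torch:.3f} ms")
```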
full code: https://colab.research.google.com/drive/1-gjk14MHtiP5Dc36mvqNWpFyONOmn54j?usp=sharing
Environment details
Triton: 3.1.0
GPU: Tesla T4
Platform: Google Colab