Assertion failure in LinearLayoutConversions on H100s when num_warps=8 #5609

Rifur13 opened this issue Jan 14, 2025 · 1 comment

Rifur13 commented Jan 14, 2025

Describe the bug

I'm hitting assertion failures related to Linear Layouts when num_warps=8.

  1. After commit e57b468
python: triton/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp:1008:  
mlir::triton::LinearLayout mlir::triton::gpu::{anonymous}::chooseStMatrixLayoutLeadingOffset(
mlir::MLIRContext*, mlir::RankedTensorType, int): Assertion `instrN >= numColsPerChunk && 
"Each chunk is filled in with a single warp"' failed.
  2. Before commit e57b468
python: triton/lib/Tools/LinearLayout.cpp:503: mlir::triton::LinearLayout 
mlir::triton::LinearLayout::reshapeOuts(llvm::ArrayRef<std::pair<mlir::StringAttr, int> >) 
const: Assertion `getTotalOutDimSize() == std::accumulate( newOutDims.begin(), newOutDims.end(), 1, 
[&](int32_t acc, auto &outDim) { return acc * outDim.second; })' failed.

The second error looks similar to #5265, so I wonder if that fix was only partial.
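
For context, the second assertion boils down to a size-preservation check: reshaping the output dimensions must keep the product of the new out-dim sizes equal to the layout's total output size. A rough Python paraphrase of that invariant (my own sketch, not the actual Triton code):

from math import prod

def can_reshape_outs(total_out_dim_size, new_out_dims):
    # new_out_dims is a list of (name, size) pairs, e.g. [("dim0", 64), ("dim1", 32)].
    # The reshape is only valid if the new sizes multiply out to the old total size.
    return total_out_dim_size == prod(size for _, size in new_out_dims)

So the failure means reshapeOuts was handed output dimensions whose product no longer matches the layout it was derived from.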

Python reproducer:

import torch
import triton
import triton.language as tl


@triton.jit
def repro_kernel(q_ref,
                 k_ref,
                 v_ref,
                 output_ptr,
                 ):
    offsets64 = tl.arange(0, 64)
    offsets32 = tl.arange(0, 32)
    q = tl.load(q_ref + (offsets64[:, None] * 32 + offsets32[None, :]))
    k = tl.load(k_ref + (offsets32[:, None] * 64 + offsets64[None, :]))
    v = tl.load(v_ref + (offsets64[:, None] * 32 + offsets32[None, :]))

    qk = tl.dot(q, k).to(tl.bfloat16)
    o = tl.dot(qk.T, v)
    tl.store(output_ptr + (offsets64[:, None] * 32 + offsets32[None, :]), o.to(tl.bfloat16))


def repro(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
    output = torch.empty((64, 32), dtype=torch.bfloat16, device='cuda')
    grid = lambda meta: (1, 1)
    repro_kernel[grid](q, k, v, output, num_warps=8, num_ctas=1, num_stages=1)
    return output

torch.manual_seed(0)
q = torch.randn((64, 32), dtype=torch.bfloat16, device='cuda')
k = torch.randn((32, 64), dtype=torch.bfloat16, device='cuda')
v = torch.randn((64, 32), dtype=torch.bfloat16, device='cuda')

out = repro(q, k, v)
out_ref = (q @ k).T @ v

assert torch.allclose(out, out_ref)
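
For what it's worth, the failure seems specific to num_warps=8 (see the title), so my untested assumption is that the same launch with fewer warps avoids the failing code path and could serve as a temporary workaround:

# Assumption on my part (not verified in this thread): lowering num_warps sidesteps the assertion.
repro_kernel[grid](q, k, v, output, num_warps=4, num_ctas=1, num_stages=1)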

Environment details

H100 with commit fc88ce5

@Rifur13 Rifur13 added the bug label Jan 14, 2025

Jokeren commented Jan 14, 2025

Thanks for the reproducer. For now, I'd suggest disabling stmatrix for this code path.

It might be better to wait until all of the ldmatrix/stmatrix refactoring work is done.

@Jokeren Jokeren self-assigned this Jan 14, 2025
@Jokeren Jokeren moved this to Todo in Linear Layout Jan 14, 2025