FMAs now using a lot more registers than necessary (and often spilling) #5572

Open

vwbaker opened this issue Jan 10, 2025 · 1 comment

Comments

@vwbaker
Collaborator

vwbaker commented Jan 10, 2025

Describe the issue

We found that for small dots, Triton now uses a concerning number of registers. In the example below, a [16,256]x[256,16] dot, compiling through Triton and then running the generated PTX through ptxas previously used 34 registers. As of the culprit commit, it uses 255 registers and spills: 120 bytes stack frame, 116 bytes spill stores, 116 bytes spill loads.

Any thoughts as to what's causing this and how we can fix it?

Culprit commit: d9facf3

Here is a sample of the TTIR:

module {
  tt.func @gemm_fusion_r_1_impl(%arg0: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i8> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}) {
    %cst = arith.constant dense<5> : tensor<256x1xi32>
    %cst_0 = arith.constant dense<5> : tensor<1x256xi32>
    %cst_1 = arith.constant dense<0.000000e+00> : tensor<256x16xf32>
    %cst_2 = arith.constant dense<0.000000e+00> : tensor<16x256xf32>
    %c1_i64 = arith.constant 1 : i64
    %c0_i32 = arith.constant 0 : i32
    %c5_i64 = arith.constant 5 : i64
    %c16_i32 = arith.constant 16 : i32
    %c1_i32 = arith.constant 1 : i32
    %c8_i32 = arith.constant 8 : i32
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<16x16xf32>
    %0 = tt.get_program_id x : i32
    %1 = arith.divsi %0, %c8_i32 : i32
    %2 = arith.muli %1, %c8_i32 : i32
    %3 = arith.subi %c1_i32, %2 : i32
    %4 = arith.cmpi slt, %3, %c8_i32 : i32
    %5 = arith.select %4, %3, %c8_i32 : i32
    %6 = arith.remsi %0, %5 : i32
    %7 = arith.addi %2, %6 : i32
    %8 = arith.remsi %0, %c8_i32 : i32
    %9 = arith.divsi %8, %5 : i32
    %10 = arith.muli %7, %c16_i32 : i32
    %11 = tt.make_tensor_ptr %arg0, [%c5_i64, %c5_i64], [%c5_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x256xi8>>
    %12 = tt.advance %11, [%10, %c0_i32] : <tensor<16x256xi8>>
    %13 = arith.muli %9, %c16_i32 : i32
    %14 = tt.make_tensor_ptr %arg1, [%c5_i64, %c5_i64], [%c1_i64, %c5_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x16xi8>>
    %15 = tt.advance %14, [%c0_i32, %13] : <tensor<256x16xi8>>
    %16 = tt.load %12 {boundaryCheck = array<i32: 0, 1>, padding = 1 : i32} : !tt.ptr<tensor<16x256xi8>>
    %17 = arith.trunci %16 : tensor<16x256xi8> to tensor<16x256xi1>
    %18 = tt.load %15 {boundaryCheck = array<i32: 0, 1>, padding = 1 : i32} : !tt.ptr<tensor<256x16xi8>>
    %19 = arith.trunci %18 : tensor<256x16xi8> to tensor<256x16xi1>
    %20 = arith.uitofp %17 : tensor<16x256xi1> to tensor<16x256xf32>
    %21 = arith.uitofp %19 : tensor<256x16xi1> to tensor<256x16xf32>
    %22 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32>
    %23 = tt.expand_dims %22 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32>
    %24 = arith.cmpi slt, %23, %cst_0 : tensor<1x256xi32>
    %25 = tt.broadcast %24 : tensor<1x256xi1> -> tensor<16x256xi1>
    %26 = arith.select %25, %20, %cst_2 : tensor<16x256xi1>, tensor<16x256xf32>
    %27 = tt.expand_dims %22 {axis = 1 : i32} : tensor<256xi32> -> tensor<256x1xi32>
    %28 = arith.cmpi slt, %27, %cst : tensor<256x1xi32>
    %29 = tt.broadcast %28 : tensor<256x1xi1> -> tensor<256x16xi1>
    %30 = arith.select %29, %21, %cst_1 : tensor<256x16xi1>, tensor<256x16xf32>
    %31 = tt.dot %26, %30, %cst_3 : tensor<16x256xf32> * tensor<256x16xf32> -> tensor<16x16xf32>
    %32 = tt.make_tensor_ptr %arg2, [%c5_i64, %c5_i64], [%c5_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<16x16xf32>>
    %33 = tt.advance %32, [%10, %13] : <tensor<16x16xf32>>
    tt.store %33, %31 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<16x16xf32>>
    tt.return
  }
}
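For reference, the TTIR above amounts to a masked boolean GEMM: both operands are loaded as i8, truncated to i1, converted to f32, and zeroed outside the first 5 rows/columns of the K dimension (the dense<5> constants) before the dot. A rough NumPy sketch of the same math (shapes and the bound taken from the IR; the random inputs are just illustrative):

```python
import numpy as np

M, K, N = 16, 256, 16
bound = 5  # the dense<5> constants bound the valid slice of K

rng = np.random.default_rng(0)
a_i8 = rng.integers(0, 2, size=(M, K), dtype=np.int8)
b_i8 = rng.integers(0, 2, size=(K, N), dtype=np.int8)

# arith.trunci i8 -> i1 (keep low bit), then arith.uitofp i1 -> f32
a = (a_i8 & 1).astype(np.float32)
b = (b_i8 & 1).astype(np.float32)

# arith.select with the broadcasted cmpi masks: zero columns of A and
# rows of B at K-indices >= bound
k = np.arange(K)
a = np.where(k[None, :] < bound, a, 0.0)
b = np.where(k[:, None] < bound, b, 0.0)

c = a @ b  # tt.dot: [16,256] x [256,16] -> [16,16] f32 accumulator
```

Despite the tensors being mostly zeros, the register pressure comes from how the [16,256] operands are laid out, not from the arithmetic itself.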

Steps to reproduce: compile the kernel through Triton, grab the generated PTX from .triton/cache/, and run that file through ptxas -arch=sm_80 -v --warn-on-spills.

Environment details

GPU: A100 (also appears on H100)

Triton version: affects Triton built from source after d9facf3

@vwbaker
Collaborator Author

vwbaker commented Jan 10, 2025

@binarman who authored #5469
