Propagate upstream Marlin kernel fix
ahadnagy committed Jan 3, 2025
1 parent 66dca87 commit 40c17d6
Showing 2 changed files with 14 additions and 9 deletions.
3 changes: 3 additions & 0 deletions optimum/quanto/library/extensions/cuda/__init__.py
@@ -48,6 +48,8 @@ def get_max_cuda_arch():
 extra_cuda_cflags = [
     "--expt-extended-lambda",
     "--use_fast_math",
+    "-lineinfo",
+    '-O0'
 ]
 # We need to know the minimum CUDA Arch to select only the relevant kernels
 # but we cannot rely on __CUDA_ARCH__ as it is not set in host code (only on device code)
@@ -187,6 +189,7 @@ def gemm_f16i4_marlin(
         dtype=input.dtype,
         device=input.device,
     )
+    print(f"input shapes: {input.reshape((-1, input.shape[-1])).shape}, in2: {other.shape}, out: {output.reshape((-1, output.shape[-1])).shape}")
     ext.lib.marlin_gemm_f16i4(
         input.reshape((-1, input.shape[-1])),
         other,
11 changes: 11 additions & 9 deletions (Marlin CUDA kernel source)
@@ -374,8 +374,10 @@ __global__ void Marlin(
   int4* sh_a = sh;
   int4* sh_b = sh_a + (stages * a_sh_stage);
   int4* sh_s = sh_b + (stages * b_sh_stage);
+  int4* sh_red = sh_s + (stages * s_sh_stage);
   // ADDED: shared memory storage for scaled zero points
-  int4* sh_sz = sh_s + (stages * s_sh_stage);
+  int4* sh_sz = sh_red + (stages * s_sh_stage);
+
   // Register storage for double buffer of shared memory reads.
   FragA frag_a[2][thread_m_blocks];
   I4 frag_b_quant[2];
@@ -499,21 +501,21 @@ __global__ void Marlin(
             for (int j = 0; j < 4 * 2; j++) {
               int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
               if (i < red_off) {
-                float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
-                float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
+                float* c_rd = reinterpret_cast<float*>(&sh_red[red_sh_delta * j + red_sh_rd]);
+                float* c_wr = reinterpret_cast<float*>(&sh_red[red_sh_wr]);
                 #pragma unroll
                 for (int k = 0; k < 4; k++)
                   reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k];
               }
-              sh[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
+              sh_red[red_sh_wr] = reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
             }
           }
           __syncthreads();
         }
         if (red_idx == 0) {
           #pragma unroll
           for (int i = 0; i < 4 * 2; i++) {
-            float* c_rd = reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
+            float* c_rd = reinterpret_cast<float*>(&sh_red[red_sh_delta * i + red_sh_rd]);
             #pragma unroll
             for (int j = 0; j < 4; j++)
               reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] += c_rd[j];
@@ -548,7 +550,7 @@ __global__ void Marlin(
       #pragma unroll
       for (int i = 0; i < thread_m_blocks * 4; i++) {
         cp_async4_pred(
-          &sh[c_sh_wr + c_sh_wr_delta * i],
+          &sh_red[c_sh_wr + c_sh_wr_delta * i],
           &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)],
           i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m
         );
@@ -561,7 +563,7 @@ __global__ void Marlin(
       for (int i = 0; i < thread_m_blocks * 4; i++) {
         if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
           if (!first) {
-            int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
+            int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta];
             #pragma unroll
             for (int j = 0; j < 2 * 4; j++) {
               reinterpret_cast<float*>(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += __half2float(
@@ -605,7 +607,7 @@ __global__ void Marlin(
         half2 res = __halves2half2(__float2half(c0), __float2half(c1));
         if (group_blocks == -1) // for per-column quantization we finally apply the scale here
           res = __hmul2(res, s[0]);
-        ((half2*) sh)[idx] = res;
+        ((half2*) sh_red)[idx] = res;
       };
       if (threadIdx.x / 32 < thread_n_blocks / 4) {
         #pragma unroll
@@ -626,7 +628,7 @@ __global__ void Marlin(
       #pragma unroll
       for (int i = 0; i < ceildiv(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) {
         if (c_gl_wr < c_gl_wr_end) {
-          C[c_gl_wr] = sh[c_sh_rd];
+          C[c_gl_wr] = sh_red[c_sh_rd];
           c_gl_wr += c_gl_wr_delta;
           c_sh_rd += c_sh_rd_delta;
         }
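
For readers tracing the kernel change: the core of the propagated fix is that block-level reduction and output staging get their own shared-memory region (sh_red) carved out of the single dynamic allocation, and the scaled-zero-point region (sh_sz) is re-based on top of it, where previously those paths indexed the base sh pointer directly (see the removed lines above), presumably to avoid overlapping data the epilogue still needs. The standalone sketch below illustrates only this pointer-partitioning pattern; the partition_demo kernel, its *_words parameters, and the region sizes in main are illustrative assumptions, not part of this commit or the quanto kernel.

#include <cstdio>
#include <cuda_runtime.h>

// One dynamically sized shared-memory allocation, partitioned by pointer offsets.
extern __shared__ int4 sh[];

__global__ void partition_demo(int a_words, int b_words, int s_words, int red_words) {
  int4* sh_a   = sh;                  // operand A tiles
  int4* sh_b   = sh_a + a_words;      // packed quantized B tiles
  int4* sh_s   = sh_b + b_words;      // per-group scales
  int4* sh_red = sh_s + s_words;      // dedicated partial-sum reduction buffer (the fix)
  int4* sh_sz  = sh_red + red_words;  // scaled zero points now start after the reduction buffer
  if (threadIdx.x == 0 && blockIdx.x == 0) {
    printf("word offsets: a=%ld b=%ld s=%ld red=%ld sz=%ld\n",
           (long)(sh_a - sh), (long)(sh_b - sh), (long)(sh_s - sh),
           (long)(sh_red - sh), (long)(sh_sz - sh));
  }
}

int main() {
  // Illustrative region sizes in int4 words; the real kernel derives them from its tile shapes.
  int a_words = 64, b_words = 64, s_words = 8, red_words = 32, sz_words = 8;
  size_t smem_bytes = (size_t)(a_words + b_words + s_words + red_words + sz_words) * sizeof(int4);
  partition_demo<<<1, 32, smem_bytes>>>(a_words, b_words, s_words, red_words);
  cudaDeviceSynchronize();
  return 0;
}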
