From bf0c7b76fe1ca18413784ee0a51c77ff4fe3341f Mon Sep 17 00:00:00 2001 From: nhamanasu Date: Fri, 10 Jan 2025 02:14:27 +0900 Subject: [PATCH] fix mean subtraction in layer norm kernels --- src/liger_kernel/ops/layer_norm.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/liger_kernel/ops/layer_norm.py b/src/liger_kernel/ops/layer_norm.py index 6d527c7ee..9c816f671 100644 --- a/src/liger_kernel/ops/layer_norm.py +++ b/src/liger_kernel/ops/layer_norm.py @@ -57,13 +57,14 @@ def _layer_norm_forward_kernel( B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0) mean = tl.sum(X_row, axis=0) / n_cols - var = tl.sum((X_row - mean) * (X_row - mean), axis=0) / n_cols + Xmm = tl.where(mask, X_row - mean, 0) + var = tl.sum(Xmm * Xmm, axis=0) / n_cols rstd = rsqrt(var + eps) tl.store(Mean_ptr, mean) tl.store(RSTD_ptr, rstd) - Y_row = (X_row - mean) * rstd * W_row + B_row + Y_row = Xmm * rstd * W_row + B_row tl.store(Y_ptr + col_offsets, Y_row, mask=mask) @@ -118,7 +119,8 @@ def _layer_norm_backward_kernel( mean = tl.load(Mean_ptr) rstd = tl.load(RSTD_ptr) - x_hat = (x - mean) * rstd + xmm = tl.where(mask, x - mean, 0) + x_hat = xmm * rstd wdy = w * dy c1 = tl.sum(x_hat * wdy, axis=0) / n_cols c2 = tl.sum(wdy, axis=0) / n_cols