add fuse_layer_norm_grad functor and op #10612

Open · wants to merge 4 commits into master
50 changes: 34 additions & 16 deletions oneflow/core/autograd/gradient_funcs/layer_norm.cpp
@@ -18,6 +18,9 @@ limitations under the License.
#include "oneflow/core/functional/functional.h"

namespace oneflow {

DEFINE_ENV_BOOL(ONEFLOW_USE_FUSE_LAYER_NORM_GRAD, false);

namespace one {

struct LayerNormCaptureState : public AutoGradCaptureState {
@@ -107,22 +110,37 @@ Maybe<void> LayerNorm::Apply(const LayerNormCaptureState* ctx, const TensorTuple
std::shared_ptr<Tensor> mean = saved_tensors.at(ctx->mean_index);
std::shared_ptr<Tensor> inv_variance = saved_tensors.at(ctx->inv_variance_index);

if (ctx->has_affine) {
// Use LayerNormParamGrad(Tensor dy, Tensor x, Tensor mean, Tensor inv_variance,
// Int64 begin_params_axis)
const auto& results =
JUST(functional::LayerNormParamGrad(dy, x, mean, inv_variance, begin_params_axis));
in_grads->at(1) = results->at(0); // For gamma.
in_grads->at(2) = results->at(1); // For beta.
}
if (ctx->x_requires_grad) {
if (ctx->scale) {
std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
in_grads->at(0) = JUST(functional::LayerNormAffineGrad(dy, x, mean, inv_variance, gamma,
begin_norm_axis, ctx->epsilon));
} else {
in_grads->at(0) =
JUST(functional::LayerNormGrad(dy, x, mean, inv_variance, begin_norm_axis, ctx->epsilon));
if (EnvBool<ONEFLOW_USE_FUSE_LAYER_NORM_GRAD>()) {
// Fused path: compute dx, gamma_diff and beta_diff with a single op. Currently intended for NPU only.
CHECK(ctx->has_affine) << "LayerNorm::Apply requires has_affine when ONEFLOW_USE_FUSE_LAYER_NORM_GRAD is enabled (used by the NPU GPT-2 test)";
if (ctx->x_requires_grad) {
if (ctx->scale) {
std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
*in_grads = *JUST(functional::FuseLayerNormAffineGrad(
dy, x, mean, inv_variance, gamma, begin_norm_axis, begin_params_axis, ctx->epsilon));
} else {
*in_grads = *JUST(functional::FuseLayerNormGrad(dy, x, mean, inv_variance, begin_norm_axis,
begin_params_axis, ctx->epsilon));
}
}
} else {
if (ctx->has_affine) {
// Use LayerNormParamGrad(Tensor dy, Tensor x, Tensor mean, Tensor inv_variance,
// Int64 begin_params_axis)
const auto& results =
JUST(functional::LayerNormParamGrad(dy, x, mean, inv_variance, begin_params_axis));
in_grads->at(1) = results->at(0); // For gamma.
in_grads->at(2) = results->at(1); // For beta.
}
if (ctx->x_requires_grad) {
if (ctx->scale) {
std::shared_ptr<Tensor> gamma = saved_tensors.at(ctx->gamma_index);
in_grads->at(0) = JUST(functional::LayerNormAffineGrad(dy, x, mean, inv_variance, gamma,
begin_norm_axis, ctx->epsilon));
} else {
in_grads->at(0) = JUST(
functional::LayerNormGrad(dy, x, mean, inv_variance, begin_norm_axis, ctx->epsilon));
}
}
}
return Maybe<void>::Ok();
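
For review reference: with x̂ = (x − mean) · inv_variance computed over the normalized axes, the fused call is expected to return the standard layer-norm gradients. This derivation is not part of the PR; it assumes begin_norm_axis == begin_params_axis (which GetSbp below also enforces) and writes M for the number of elements in each normalized slice:

\frac{\partial L}{\partial \beta_j} = \sum_i dy_{ij}, \qquad
\frac{\partial L}{\partial \gamma_j} = \sum_i dy_{ij}\,\hat{x}_{ij}, \qquad
\frac{\partial L}{\partial x_{ij}} = \text{inv\_variance}_i\Big(\gamma_j\,dy_{ij} - \frac{1}{M}\sum_k \gamma_k\,dy_{ik} - \frac{\hat{x}_{ij}}{M}\sum_k \gamma_k\,dy_{ik}\,\hat{x}_{ik}\Big)

In both fused branches the TensorTuple produced by the functional is copied straight into in_grads, so its entries must line up as (dx, gamma_diff, beta_diff), matching the explicit index assignments in the non-fused branch below.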
8 changes: 8 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -1558,6 +1558,14 @@
signature: "Tensor (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Int64 begin_norm_axis, Double epsilon) => LayerNormAffineGrad"
bind_python: False

- name: "fuse_layer_norm_grad"
signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Int64 begin_norm_axis, Int64 begin_params_axis, Double epsilon) => FuseLayerNormGrad"
bind_python: False

- name: "fuse_layer_norm_affine_grad"
signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Tensor gamma, Int64 begin_norm_axis, Int64 begin_params_axis, Double epsilon) => FuseLayerNormAffineGrad"
bind_python: False

- name: "layer_norm_param_grad"
signature: "TensorTuple (Tensor dy, Tensor x, Tensor mean, Tensor inv_variance, Int64 begin_params_axis) => LayerNormParamGrad"
bind_python: False
60 changes: 60 additions & 0 deletions oneflow/core/functional/impl/nn_grad_functor.cpp
@@ -983,6 +983,64 @@ class LayerNormAffineGradFunctor {
std::shared_ptr<OpExpr> op_;
};

class FuseLayerNormGradFunctor {
public:
FuseLayerNormGradFunctor() {
op_ = CHECK_JUST(one::OpBuilder("fuse_layer_norm_grad")
.Input("dy")
.Input("x")
.Input("mean")
.Input("inv_variance")
.Output("dx")
.Output("gamma_diff")
.Output("beta_diff")
.Build());
}
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,
const std::shared_ptr<one::Tensor>& x,
const std::shared_ptr<one::Tensor>& mean,
const std::shared_ptr<one::Tensor>& inv_variance,
const int64_t& begin_norm_axis, const int64_t& begin_params_axis,
const double& epsilon) const {
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "begin_params_axis", "epsilon");
attrs.SetAllAttrs(begin_norm_axis, begin_params_axis, epsilon);
return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, x, mean, inv_variance}, attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};

class FuseLayerNormAffineGradFunctor {
public:
FuseLayerNormAffineGradFunctor() {
op_ = CHECK_JUST(one::OpBuilder("fuse_layer_norm_grad")
.Input("dy")
.Input("x")
.Input("mean")
.Input("inv_variance")
.Input("gamma")
.Output("dx")
.Output("gamma_diff")
.Output("beta_diff")
.Build());
}
Maybe<TensorTuple> operator()(const std::shared_ptr<one::Tensor>& dy,
const std::shared_ptr<one::Tensor>& x,
const std::shared_ptr<one::Tensor>& mean,
const std::shared_ptr<one::Tensor>& inv_variance,
const std::shared_ptr<one::Tensor>& gamma,
const int64_t& begin_norm_axis, const int64_t& begin_params_axis,
const double& epsilon) const {
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("begin_norm_axis", "begin_params_axis", "epsilon");
attrs.SetAllAttrs(begin_norm_axis, begin_params_axis, epsilon);
return OpInterpUtil::Dispatch<TensorTuple>(*op_, {dy, x, mean, inv_variance, gamma}, attrs);
}

private:
std::shared_ptr<OpExpr> op_;
};

class LayerNormParamGradFunctor {
public:
LayerNormParamGradFunctor() {
@@ -1707,6 +1765,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::LayerNormGradFunctor>("LayerNormGrad");
m.add_functor<impl::LayerNormAffineGradFunctor>("LayerNormAffineGrad");
m.add_functor<impl::LayerNormParamGradFunctor>("LayerNormParamGrad");
m.add_functor<impl::FuseLayerNormGradFunctor>("FuseLayerNormGrad");
m.add_functor<impl::FuseLayerNormAffineGradFunctor>("FuseLayerNormAffineGrad");
m.add_functor<impl::GroupNormGradFunctor>("GroupNormGrad");
m.add_functor<impl::GroupNormParamGradFunctor>("GroupNormParamGrad");
m.add_functor<impl::BroadcastMatmulGradBFunctor>("BroadcastMatmulGradB");
29 changes: 29 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -7071,6 +7071,35 @@ def OneFlow_LayerNormGradOp : OneFlow_BaseOp<"layer_norm_grad", [NoMemoryEffect,
let has_data_type_infer_fn = 1;
}

def OneFlow_FuseLayerNormGradOp : OneFlow_BaseOp<"fuse_layer_norm_grad", [NoMemoryEffect, AttrSizedOperandSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$dy,
OneFlow_Tensor:$x,
OneFlow_Tensor:$mean,
OneFlow_Tensor:$inv_variance,
Optional<OneFlow_Tensor>:$gamma,
Optional<OneFlow_Tensor>:$_add_to_output
);
let output = (outs
OneFlow_Tensor:$dx,
OneFlow_Tensor:$gamma_diff,
OneFlow_Tensor:$beta_diff
);
let attrs = (ins
DefaultValuedAttr<SI64Attr, "0">:$begin_norm_axis,
DefaultValuedAttr<SI64Attr, "0">:$begin_params_axis,
DefaultValuedAttr<F64Attr, "0.">:$epsilon
);
let trait_attrs = (ins
DenseI32ArrayAttr:$operand_segment_sizes,
DenseI32ArrayAttr:$result_segment_sizes
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_LayerNormParamGradOp : OneFlow_BaseOp<"layer_norm_param_grad", [NoMemoryEffect, AttrSizedResultSegments, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$dy,
40 changes: 40 additions & 0 deletions oneflow/user/kernels/fuse_layer_norm_cpu_kernel.cpp
@@ -0,0 +1,40 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"

namespace oneflow {

template<typename T>
class FuseLayerNormGradCpuKernel final : public user_op::OpKernel {
public:
FuseLayerNormGradCpuKernel() = default;
~FuseLayerNormGradCpuKernel() = default;

private:
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
// The CPU computation is not implemented; this kernel only reserves the registration for the CPU device.
void Compute(user_op::KernelComputeContext* ctx) const override { TODO(); }
};

#define REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(dtype) \
REGISTER_USER_KERNEL("fuse_layer_norm_grad") \
.SetCreateFn<FuseLayerNormGradCpuKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCPU) \
&& (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value));

REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(float)
REGISTER_FUSE_LAYER_NORM_GRAD_CPU_KERNEL(double)

} // namespace oneflow
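
Since the CPU Compute above is left as TODO(), here is a standalone naive sketch of what the fused backward has to produce, for review purposes only. It is plain C++ over std::vector, covers the 2-D case with begin_norm_axis == begin_params_axis == 1, and every name in it (FusedLayerNormGradResult, FusedLayerNormGradRef, rows, cols) is illustrative rather than part of this PR or of the user_op kernel API:

// Standalone reference sketch, not the kernel above: naive CPU computation of
// dx, gamma_diff and beta_diff for a rows x cols input, normalizing over cols.
#include <cstdint>
#include <vector>

struct FusedLayerNormGradResult {
  std::vector<float> dx;          // same shape as x: rows x cols
  std::vector<float> gamma_diff;  // cols
  std::vector<float> beta_diff;   // cols
};

FusedLayerNormGradResult FusedLayerNormGradRef(const std::vector<float>& dy,
                                               const std::vector<float>& x,
                                               const std::vector<float>& mean,
                                               const std::vector<float>& inv_variance,
                                               const std::vector<float>& gamma,
                                               int64_t rows, int64_t cols) {
  FusedLayerNormGradResult r{std::vector<float>(rows * cols, 0.f),
                             std::vector<float>(cols, 0.f),
                             std::vector<float>(cols, 0.f)};
  for (int64_t i = 0; i < rows; ++i) {
    // Per-row reductions over the normalized (= parameter) axis.
    float sum_a = 0.f;  // sum_j gamma_j * dy_ij
    float sum_b = 0.f;  // sum_j gamma_j * dy_ij * x_hat_ij
    for (int64_t j = 0; j < cols; ++j) {
      const float x_hat = (x[i * cols + j] - mean[i]) * inv_variance[i];
      const float g_dy = gamma[j] * dy[i * cols + j];
      sum_a += g_dy;
      sum_b += g_dy * x_hat;
      r.gamma_diff[j] += dy[i * cols + j] * x_hat;  // accumulated over rows
      r.beta_diff[j] += dy[i * cols + j];
    }
    for (int64_t j = 0; j < cols; ++j) {
      const float x_hat = (x[i * cols + j] - mean[i]) * inv_variance[i];
      const float g_dy = gamma[j] * dy[i * cols + j];
      r.dx[i * cols + j] =
          inv_variance[i] * (g_dy - (sum_a + x_hat * sum_b) / static_cast<float>(cols));
    }
  }
  return r;
}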
123 changes: 123 additions & 0 deletions oneflow/user/ops/layer_norm_op.cpp
@@ -268,4 +268,127 @@ oneflow::DataType InferBnParamDataType(const DataType x_data_type) {
return Maybe<void>::Ok();
}

/* static */ Maybe<void> FuseLayerNormGradOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0);
const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0);
const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0);
user_op::TensorDesc* dx = ctx->MutOutputTensorDesc("dx", 0);
CHECK_EQ_OR_RETURN(dy.shape(), x.shape());
const int64_t begin_norm_axis = ctx->Attr<int64_t>("begin_norm_axis");
CHECK_GT_OR_RETURN(begin_norm_axis, 0);
const Shape& bn_param_shape = InferBnParamShape(x.shape(), begin_norm_axis);
CHECK_EQ_OR_RETURN(mean.shape(), bn_param_shape);
CHECK_EQ_OR_RETURN(inv_variance.shape(), bn_param_shape);
dx->set_shape(dy.shape());
dx->set_is_dynamic(dy.is_dynamic());
if (ctx->has_input("_add_to_output", 0)) {
const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0);
CHECK_EQ_OR_RETURN(add_to_output.shape(), dx->shape());
}

auto has_tensor = [ctx](const std::string& bn) -> bool {
bool ret = false;
for (const auto& t : ctx->inputs()) {
if (bn == t.first) { return true; }
}
for (const auto& t : ctx->outputs()) {
if (bn == t.first) { return true; }
}
return ret;
};
const int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis");
const bool has_beta_diff = has_tensor("beta_diff");
const bool has_gamma_diff = has_tensor("gamma_diff");
CHECK_GE_OR_RETURN(begin_params_axis, 1);
CHECK_LT_OR_RETURN(begin_params_axis, dy.shape().NumAxes());
DimVector param_shape_dim_vec;
param_shape_dim_vec.insert(param_shape_dim_vec.end(),
dy.shape().dim_vec().cbegin() + begin_params_axis,
dy.shape().dim_vec().cend());
const Shape param_shape(param_shape_dim_vec);
if (has_beta_diff) {
user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc("beta_diff", 0);
beta_diff->set_shape(param_shape);
}
if (has_gamma_diff) {
user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc("gamma_diff", 0);
gamma_diff->set_shape(param_shape);
}
return Maybe<void>::Ok();
}
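
A worked shape example may make the two groups of shape checks above easier to review; the concrete ranks and names below are illustrative only, not taken from this PR.

// Example: x of shape (N, S, H) with begin_norm_axis = 2 and begin_params_axis = 2.
//   mean, inv_variance    : (N, S)    -- InferBnParamShape keeps the axes before begin_norm_axis
//   dx                    : (N, S, H) -- same shape as dy and x
//   gamma_diff, beta_diff : (H)       -- the axes from begin_params_axis to the end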

/*static*/ Maybe<void> FuseLayerNormGradOp::InferPhysicalTensorDesc(user_op::InferContext* ctx) {
return InferLogicalTensorDesc(ctx);
}

/* static */ Maybe<void> FuseLayerNormGradOp::GetSbp(user_op::SbpContext* ctx) {
std::vector<user_op::OpArg> broadcast_args;
if (ctx->user_op_conf().has_input("gamma", 0)) {
broadcast_args.emplace_back(user_op::OpArg("gamma", 0));
}
int64_t begin_norm_axis = ctx->Attr<int64_t>("begin_norm_axis");
int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis");
CHECK_EQ(begin_norm_axis, begin_params_axis)
<< "begin_norm_axis and begin_params_axis must be equal, but got " << begin_norm_axis
<< " and " << begin_params_axis;
for (int i = 0; i < begin_norm_axis; ++i) {
ctx->NewBuilder()
.Split(ctx->inputs(), i)
.Split(user_op::OpArg("dx", 0), i)
.PartialSum(user_op::OpArg("gamma_diff", 0))
.PartialSum(user_op::OpArg("beta_diff", 0))
.Broadcast(broadcast_args)
.Build();
}
return Maybe<void>::Ok();
}

/* static */ Maybe<void> FuseLayerNormGradOp::InferDataType(user_op::InferContext* ctx) {
const user_op::TensorDesc& dy = ctx->InputTensorDesc("dy", 0);
const user_op::TensorDesc& x = ctx->InputTensorDesc("x", 0);
CHECK_EQ_OR_RETURN(dy.data_type(), x.data_type())
<< "InferDataType Failed. Expected " << DataType_Name(x.data_type()) << ", but got "
<< DataType_Name(dy.data_type());
const user_op::TensorDesc& mean = ctx->InputTensorDesc("mean", 0);
const user_op::TensorDesc& inv_variance = ctx->InputTensorDesc("inv_variance", 0);
DataType bn_param_data_type = InferBnParamDataType(x.data_type());
CHECK_EQ_OR_RETURN(mean.data_type(), bn_param_data_type)
<< "InferDataType Failed. Expected " << DataType_Name(bn_param_data_type) << ", but got "
<< DataType_Name(mean.data_type());
CHECK_EQ_OR_RETURN(inv_variance.data_type(), bn_param_data_type)
<< "InferDataType Failed. Expected " << DataType_Name(bn_param_data_type) << ", but got "
<< DataType_Name(inv_variance.data_type());
user_op::TensorDesc* dx = ctx->MutOutputTensorDesc("dx", 0);
dx->set_data_type(dy.data_type());
if (ctx->has_input("_add_to_output", 0)) {
const auto& add_to_output = ctx->InputTensorDesc("_add_to_output", 0);
CHECK_EQ_OR_RETURN(add_to_output.data_type(), dx->data_type())
<< "InferDataType Failed. Expected " << DataType_Name(dx->data_type()) << ", but got "
<< DataType_Name(add_to_output.data_type());
}

auto has_tensor = [ctx](const std::string& bn) -> bool {
bool ret = false;
for (auto& t : ctx->inputs()) {
if (bn == t.first) { return true; }
}
for (auto& t : ctx->outputs()) {
if (bn == t.first) { return true; }
}
return ret;
};
const bool has_beta_diff = has_tensor("beta_diff");
const bool has_gamma_diff = has_tensor("gamma_diff");
if (has_beta_diff) {
user_op::TensorDesc* beta_diff = ctx->MutOutputTensorDesc("beta_diff", 0);
beta_diff->set_data_type(dy.data_type());
}
if (has_gamma_diff) {
user_op::TensorDesc* gamma_diff = ctx->MutOutputTensorDesc("gamma_diff", 0);
gamma_diff->set_data_type(dy.data_type());
}
return Maybe<void>::Ok();
}

} // namespace oneflow