Enable conv + (bias/bn) + leakyrelu fusion
This commit is contained in:
parent
89cbc0882f
commit
4d022d6e2c
@ -334,6 +334,8 @@ bool IsImmutableConst(const NodeDef& node) {
|
||||
|
||||
// Returns true when the op type of `node` is "InvGrad".
bool IsInvGrad(const NodeDef& node) {
  const auto& op_type = node.op();
  return op_type == "InvGrad";
}
|
||||
|
||||
// Returns true when the op type of `node` is "LeakyRelu".
bool IsLeakyRelu(const NodeDef& node) {
  const auto& op_type = node.op();
  return op_type == "LeakyRelu";
}
|
||||
|
||||
// Returns true when the op type of `node` is "Less".
bool IsLess(const NodeDef& node) {
  const auto& op_type = node.op();
  return op_type == "Less";
}
|
||||
|
||||
// Returns true when the op type of `node` is "LessEqual".
bool IsLessEqual(const NodeDef& node) {
  const auto& op_type = node.op();
  return op_type == "LessEqual";
}
|
||||
|
@ -99,6 +99,7 @@ bool IsIgammac(const NodeDef& node);
|
||||
bool IsImag(const NodeDef& node);
|
||||
bool IsImmutableConst(const NodeDef& node);
|
||||
bool IsInvGrad(const NodeDef& node);
|
||||
bool IsLeakyRelu(const NodeDef& node);
|
||||
bool IsLess(const NodeDef& node);
|
||||
bool IsLessEqual(const NodeDef& node);
|
||||
bool IsLog(const NodeDef& node);
|
||||
|
@ -880,6 +880,7 @@ tf_cuda_cc_test(
|
||||
deps = [
|
||||
":remapper",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:cc_ops_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
|
@ -361,7 +361,12 @@ bool IsDeviceCompatible(const RemapperContext& ctx, Pattern& matched) {
|
||||
}
|
||||
|
||||
bool IsSupportedActivation(const NodeDef& node) {
|
||||
// Disable LeakyRelu temporarily before MKL PR is merged.
|
||||
#ifndef INTEL_MKL
|
||||
return IsRelu(node) || IsRelu6(node) || IsElu(node) || IsLeakyRelu(node);
|
||||
#else
|
||||
return IsRelu(node) || IsRelu6(node) || IsElu(node);
|
||||
#endif // !INTEL_MKL
|
||||
}
|
||||
|
||||
inline bool HasControlFaninOrFanout(const utils::MutableNodeView& node_view) {
|
||||
@ -450,6 +455,14 @@ bool FindContractionWithBiasAndActivation(
|
||||
IsInPreserveSet(ctx, bias_add_node_def))
|
||||
return false;
|
||||
|
||||
// Get the contraction node
|
||||
const auto* contraction_node_view =
|
||||
bias_add_node_view->GetRegularFanin(0).node_view();
|
||||
const auto* contraction_node_def = contraction_node_view->node();
|
||||
|
||||
// LeakyRelu is currently fused only when the contraction is Conv2D
// (Conv2D + BiasAdd + LeakyRelu); reject it for any other contraction.
|
||||
if (!IsConv2D(*contraction_node_def) && IsLeakyRelu(*node_def)) return false;
|
||||
|
||||
// Check that data type and data format are supported on assigned device.
|
||||
const ContractionWithBiasAddAndActivation pattern{base.contraction,
|
||||
base.bias_add, node_index};
|
||||
@ -719,6 +732,16 @@ bool FindContractionWithBiasAndAddActivation(
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get the contraction node
|
||||
const auto* bias_add_node_view =
|
||||
add_node_view->GetRegularFanin(base.port_id).node_view();
|
||||
const auto* contraction_node_view =
|
||||
bias_add_node_view->GetRegularFanin(0).node_view();
|
||||
const auto* contraction_node_def = contraction_node_view->node();
|
||||
|
||||
// LeakyRelu is currently fused only when the contraction is Conv2D
// (Conv2D + BiasAdd + Add + LeakyRelu); reject it for any other contraction.
|
||||
if (!IsConv2D(*contraction_node_def) && IsLeakyRelu(*node_def)) return false;
|
||||
|
||||
// We successfully found a Conv2D+BiasAdd+AddN+activation pattern.
|
||||
const ContractionWithBiasAndAddActivation pattern{
|
||||
base.contraction, base.bias_add, base.add, base.port_id, node_index};
|
||||
@ -919,7 +942,8 @@ bool FindFusedBatchNormEx(const RemapperContext& ctx, int node_index,
|
||||
return false;
|
||||
}
|
||||
|
||||
void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d) {
|
||||
void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d,
|
||||
const NodeDef* activation = nullptr) {
|
||||
DCHECK(IsConv2D(conv2d)) << "Input node must be a Conv2D";
|
||||
|
||||
auto* attr = fused_conv2d->mutable_attr();
|
||||
@ -932,10 +956,16 @@ void CopyConv2DAttributes(const NodeDef& conv2d, NodeDef* fused_conv2d) {
|
||||
(*attr)["dilations"] = src_attr.at("dilations");
|
||||
(*attr)["data_format"] = src_attr.at("data_format");
|
||||
(*attr)["use_cudnn_on_gpu"] = src_attr.at("use_cudnn_on_gpu");
|
||||
// Copy LeakyRelu's attr alpha to FusedConv2D's attr leakyrelu_alpha
|
||||
if (activation != nullptr && IsLeakyRelu(*activation)) {
|
||||
auto& activation_attr = activation->attr();
|
||||
(*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
|
||||
}
|
||||
}
|
||||
|
||||
void CopyDepthwiseConv2dNativeAttributes(const NodeDef& dw_conv2d,
|
||||
NodeDef* fused_dw_conv2d) {
|
||||
NodeDef* fused_dw_conv2d,
|
||||
const NodeDef* activation = nullptr) {
|
||||
DCHECK(IsDepthwiseConv2dNative(dw_conv2d))
|
||||
<< "Input node must be a DepthwiseConv2dNative";
|
||||
|
||||
@ -947,6 +977,11 @@ void CopyDepthwiseConv2dNativeAttributes(const NodeDef& dw_conv2d,
|
||||
(*attr)["padding"] = src_attr.at("padding");
|
||||
(*attr)["dilations"] = src_attr.at("dilations");
|
||||
(*attr)["data_format"] = src_attr.at("data_format");
|
||||
// Copy LeakyRelu's attr alpha to FusedDepthwiseConv2d's attr leakyrelu_alpha
|
||||
if (activation != nullptr && IsLeakyRelu(*activation)) {
|
||||
auto& activation_attr = activation->attr();
|
||||
(*attr)["leakyrelu_alpha"] = activation_attr.at("alpha");
|
||||
}
|
||||
}
|
||||
|
||||
void CopyFusedBatchNormAttributes(const NodeDef& fused_batch_norm,
|
||||
@ -1049,6 +1084,7 @@ Status AddFusedContractionNode(
|
||||
const NodeDef& contraction = graph->node(matched.contraction);
|
||||
const NodeDef& bias_add = graph->node(matched.bias_add);
|
||||
const NodeDef& activation = graph->node(matched.activation);
|
||||
|
||||
VLOG(2) << "Fuse " << contraction.op() << " with BiasAdd and "
|
||||
<< activation.op() << ":"
|
||||
<< " activation=" << activation.name()
|
||||
@ -1064,7 +1100,8 @@ Status AddFusedContractionNode(
|
||||
|
||||
if (IsConv2D(contraction)) {
|
||||
fused_op.set_op(kFusedConv2D);
|
||||
CopyConv2DAttributes(contraction, &fused_op);
|
||||
// leaky relu has a special attribute alpha
|
||||
CopyConv2DAttributes(contraction, &fused_op, &activation);
|
||||
} else if (IsDepthwiseConv2dNative(contraction)) {
|
||||
fused_op.set_op(kFusedDepthwiseConv2dNative);
|
||||
CopyDepthwiseConv2dNativeAttributes(contraction, &fused_op);
|
||||
@ -1284,7 +1321,7 @@ Status AddFusedContractionNode(
|
||||
fused_conv2d.add_input(add.input(1 - matched.port_id));
|
||||
|
||||
CopyConv2DAttributes(contraction, &fused_conv2d);
|
||||
SetFusedOpAttributes(&fused_conv2d, {"BiasAdd", "Add", "Relu"}, 2);
|
||||
SetFusedOpAttributes(&fused_conv2d, {"BiasAdd", "Add", activation.op()}, 2);
|
||||
|
||||
utils::Mutation* mutation = ctx->graph_view.GetMutationBuilder();
|
||||
Status status;
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/remapper.h"
|
||||
|
||||
#include "tensorflow/cc/ops/nn_ops_internal.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/tensor_testutil.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
@ -541,7 +542,7 @@ TEST_F(RemapperTest, DISABLED_FuseConv2DWithBiasAndActivationOnGPU) {
|
||||
TEST_F(RemapperTest, FuseConv2DWithBiasAndActivation) {
|
||||
using ::tensorflow::ops::Placeholder;
|
||||
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
|
||||
auto input_shape = Placeholder::Shape({8, 32, 32, 3});
|
||||
@ -567,6 +568,9 @@ TEST_F(RemapperTest, FuseConv2DWithBiasAndActivation) {
|
||||
return ops::Identity(fetch, ops::Relu6(activate, bias_add));
|
||||
} else if (activation == "Elu") {
|
||||
return ops::Identity(fetch, ops::Elu(activate, bias_add));
|
||||
} else if (activation == "LeakyRelu") {
|
||||
return ops::Identity(fetch,
|
||||
ops::internal::LeakyRelu(activate, bias_add));
|
||||
}
|
||||
|
||||
return ops::Identity(fetch, bias);
|
||||
@ -795,7 +799,7 @@ TEST_F(RemapperTest, FuseConv2DWithBatchNorm) {
|
||||
TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) {
|
||||
using ops::Placeholder;
|
||||
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
|
||||
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
|
||||
|
||||
auto input_shape = ops::Placeholder::Shape({8, 32, 32, 3});
|
||||
@ -828,6 +832,9 @@ TEST_F(RemapperTest, FuseConv2DWithBatchNormAndActivation) {
|
||||
return ops::Identity(fetch, ops::Relu6(activate, batch_norm.y));
|
||||
} else if (activation == "Elu") {
|
||||
return ops::Identity(fetch, ops::Elu(activate, batch_norm.y));
|
||||
} else if (activation == "LeakyRelu") {
|
||||
return ops::Identity(fetch,
|
||||
ops::internal::LeakyRelu(activate, batch_norm.y));
|
||||
}
|
||||
|
||||
return ops::Identity(fetch, batch_norm.y);
|
||||
|
@ -1662,6 +1662,7 @@ tf_cuda_cc_test(
|
||||
":ops_testutil",
|
||||
":ops_util",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:cc_ops_internal",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
|
@ -57,13 +57,13 @@ limitations under the License.
|
||||
#include "tensorflow/core/util/use_cudnn.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "third_party/gpus/cudnn/cudnn.h"
|
||||
#include "tensorflow/core/kernels/conv_ops_gpu.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#include "tensorflow/core/util/proto/proto_utils.h"
|
||||
#include "tensorflow/stream_executor/gpu/gpu_asm_opts.h"
|
||||
#include "tensorflow/stream_executor/gpu/redzone_allocator.h"
|
||||
#include "tensorflow/stream_executor/tf_allocator_adapter.h"
|
||||
#include "third_party/gpus/cudnn/cudnn.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
@ -185,14 +185,26 @@ struct LaunchFusedConv2DOp<CPUDevice, T> {
|
||||
|
||||
BiasAddArgs<T> bias_add_args;
|
||||
if (BiasAddArgs<T>::IsSupported(fusion)) {
|
||||
OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args));
|
||||
if (fusion == FusedComputationType::kBiasAddWithLeakyRelu) {
|
||||
OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args,
|
||||
&fusion_args.leakyrelu_alpha));
|
||||
} else {
|
||||
OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args));
|
||||
}
|
||||
}
|
||||
|
||||
FusedBatchNormArgs<T> fused_batch_norm_args;
|
||||
if (FusedBatchNormArgs<T>::IsSupported(fusion)) {
|
||||
OP_REQUIRES_OK(context,
|
||||
InitFusedBatchNormArgs(context, fusion_args.epsilon,
|
||||
&fused_batch_norm_args));
|
||||
if (fusion == FusedComputationType::kFusedBatchNormWithLeakyRelu) {
|
||||
OP_REQUIRES_OK(context,
|
||||
InitFusedBatchNormArgs(context, fusion_args.epsilon,
|
||||
&fused_batch_norm_args,
|
||||
&fusion_args.leakyrelu_alpha));
|
||||
} else {
|
||||
OP_REQUIRES_OK(context,
|
||||
InitFusedBatchNormArgs(context, fusion_args.epsilon,
|
||||
&fused_batch_norm_args));
|
||||
}
|
||||
}
|
||||
|
||||
LaunchFusedConv2DWithOutputKernel<T> conv2d(
|
||||
@ -215,6 +227,10 @@ struct LaunchFusedConv2DOp<CPUDevice, T> {
|
||||
conv2d(WithBiasAddAndRelu6<T>(bias_add_args), context, input, filter,
|
||||
output);
|
||||
break;
|
||||
case FusedComputationType::kBiasAddWithLeakyRelu:
|
||||
conv2d(WithBiasAddAndLeakyRelu<T>(bias_add_args), context, input,
|
||||
filter, output);
|
||||
break;
|
||||
case FusedComputationType::kBiasAddWithElu:
|
||||
conv2d(WithBiasAddAndElu<T>(bias_add_args), context, input, filter,
|
||||
output);
|
||||
@ -234,6 +250,11 @@ struct LaunchFusedConv2DOp<CPUDevice, T> {
|
||||
fused_batch_norm_args),
|
||||
context, input, filter, output);
|
||||
break;
|
||||
case FusedComputationType::kFusedBatchNormWithLeakyRelu:
|
||||
conv2d(WithFusedBatchNormAndLeakyRelu<T>(fusion_args.epsilon,
|
||||
fused_batch_norm_args),
|
||||
context, input, filter, output);
|
||||
break;
|
||||
case FusedComputationType::kFusedBatchNormWithElu:
|
||||
conv2d(WithFusedBatchNormAndElu<T>(fusion_args.epsilon,
|
||||
fused_batch_norm_args),
|
||||
@ -681,10 +702,12 @@ class FusedConv2DOp : public OpKernel {
|
||||
{FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
|
||||
{FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
|
||||
{FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}},
|
||||
{FCT::kBiasAddWithLeakyRelu, {"BiasAdd", "LeakyRelu"}},
|
||||
{FCT::kFusedBatchNorm, {"FusedBatchNorm"}},
|
||||
{FCT::kFusedBatchNormWithRelu, {"FusedBatchNorm", "Relu"}},
|
||||
{FCT::kFusedBatchNormWithRelu6, {"FusedBatchNorm", "Relu6"}},
|
||||
{FCT::kFusedBatchNormWithElu, {"FusedBatchNorm", "Elu"}},
|
||||
{FCT::kFusedBatchNormWithLeakyRelu, {"FusedBatchNorm", "LeakyRelu"}},
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
#include "tensorflow/cc/ops/const_op.h"
|
||||
#include "tensorflow/cc/ops/image_ops.h"
|
||||
#include "tensorflow/cc/ops/nn_ops.h"
|
||||
#include "tensorflow/cc/ops/nn_ops_internal.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||
#include "tensorflow/core/framework/fake_input.h"
|
||||
@ -652,6 +653,8 @@ class FusedConv2DOpTest : public OpsTestBase {
|
||||
ops::Relu6(root.WithOpName("with_activation"), with_bias);
|
||||
} else if (activation_type == "Elu") {
|
||||
ops::Elu(root.WithOpName("with_activation"), with_bias);
|
||||
} else if (activation_type == "LeakyRelu") {
|
||||
ops::internal::LeakyRelu(root.WithOpName("with_activation"), with_bias);
|
||||
} else {
|
||||
ops::Identity(root.WithOpName("with_activation"), with_bias);
|
||||
}
|
||||
@ -721,6 +724,9 @@ class FusedConv2DOpTest : public OpsTestBase {
|
||||
ops::Relu6(root.WithOpName("with_activation"), with_fused_batch_norm.y);
|
||||
} else if (activation_type == "Elu") {
|
||||
ops::Elu(root.WithOpName("with_activation"), with_fused_batch_norm.y);
|
||||
} else if (activation_type == "LeakyRelu") {
|
||||
ops::internal::LeakyRelu(root.WithOpName("with_activation"),
|
||||
with_fused_batch_norm.y);
|
||||
} else {
|
||||
ops::Identity(root.WithOpName("with_activation"),
|
||||
with_fused_batch_norm.y);
|
||||
@ -1040,7 +1046,7 @@ TYPED_TEST_P(FusedConv2DWithBiasOpTest, ExplicitPaddingConvolution) {
|
||||
TYPED_TEST_P(FusedConv2DWithBiasOpTest, OneByOneConvolutionAndActivation) {
|
||||
const int filter_size = 1;
|
||||
const int filter_count = 12;
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
|
||||
this->VerifyConv2DWithBiasAndActivation(activation, filter_size,
|
||||
filter_count);
|
||||
}
|
||||
@ -1049,7 +1055,7 @@ TYPED_TEST_P(FusedConv2DWithBiasOpTest, OneByOneConvolutionAndActivation) {
|
||||
TYPED_TEST_P(FusedConv2DWithBiasOpTest, ImageSizeConvolutionAndActivation) {
|
||||
const int filter_size = TestFixture::kImageWidth;
|
||||
const int filter_count = 12;
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
|
||||
this->VerifyConv2DWithBiasAndActivation(activation, filter_size,
|
||||
filter_count);
|
||||
}
|
||||
@ -1058,7 +1064,7 @@ TYPED_TEST_P(FusedConv2DWithBiasOpTest, ImageSizeConvolutionAndActivation) {
|
||||
TYPED_TEST_P(FusedConv2DWithBiasOpTest, SpatialConvolutionAndActivation) {
|
||||
const int filter_size = 3;
|
||||
const int filter_count = 12;
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
|
||||
this->VerifyConv2DWithBiasAndActivation(activation, filter_size,
|
||||
filter_count);
|
||||
}
|
||||
@ -1069,7 +1075,7 @@ TYPED_TEST_P(FusedConv2DWithBiasOpTest,
|
||||
ExplicitPaddingConvolutionAndActivation) {
|
||||
const int filter_size = 3;
|
||||
const int filter_count = 12;
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
|
||||
this->VerifyConv2DWithBiasAndActivation(
|
||||
activation, filter_size, filter_count,
|
||||
/*explicit_paddings=*/{0, 0, 1, 2, 3, 4, 0, 0});
|
||||
@ -1112,7 +1118,7 @@ TYPED_TEST_P(FusedConv2DWithBatchNormOpTest, ExplicitPaddingConvolution) {
|
||||
TYPED_TEST_P(FusedConv2DWithBatchNormOpTest, OneByOneConvolutionAndActivation) {
|
||||
const int filter_size = 1;
|
||||
const int filter_count = 12;
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
|
||||
this->VerifyConv2DWithBatchNormAndActivation(activation, filter_size,
|
||||
filter_count);
|
||||
}
|
||||
@ -1122,7 +1128,7 @@ TYPED_TEST_P(FusedConv2DWithBatchNormOpTest,
|
||||
ImageSizeConvolutionAndActivation) {
|
||||
const int filter_size = TestFixture::kImageWidth;
|
||||
const int filter_count = 12;
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
|
||||
this->VerifyConv2DWithBatchNormAndActivation(activation, filter_size,
|
||||
filter_count);
|
||||
}
|
||||
@ -1131,7 +1137,7 @@ TYPED_TEST_P(FusedConv2DWithBatchNormOpTest,
|
||||
TYPED_TEST_P(FusedConv2DWithBatchNormOpTest, SpatialConvolutionAndActivation) {
|
||||
const int filter_size = 3;
|
||||
const int filter_count = 12;
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
|
||||
this->VerifyConv2DWithBatchNormAndActivation(activation, filter_size,
|
||||
filter_count);
|
||||
}
|
||||
@ -1142,7 +1148,7 @@ TYPED_TEST_P(FusedConv2DWithBatchNormOpTest,
|
||||
ExplicitPaddingConvolutionAndActivation) {
|
||||
const int filter_size = 3;
|
||||
const int filter_count = 12;
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu"}) {
|
||||
for (const string& activation : {"Relu", "Relu6", "Elu", "LeakyRelu"}) {
|
||||
this->VerifyConv2DWithBatchNormAndActivation(
|
||||
activation, filter_size, filter_count,
|
||||
/*explicit_paddings=*/{0, 0, 1, 2, 3, 4, 0, 0});
|
||||
|
@ -60,18 +60,25 @@ Status InitializeFusedComputation(
|
||||
if (*fused_computation == FusedComputationType::kBiasAdd ||
|
||||
*fused_computation == FusedComputationType::kBiasAddWithRelu ||
|
||||
*fused_computation == FusedComputationType::kBiasAddWithRelu6 ||
|
||||
*fused_computation == FusedComputationType::kBiasAddWithElu) {
|
||||
*fused_computation == FusedComputationType::kBiasAddWithElu ||
|
||||
*fused_computation == FusedComputationType::kBiasAddWithLeakyRelu) {
|
||||
if (num_args != 1) {
|
||||
return errors::InvalidArgument(
|
||||
"Fused ", kernel_name,
|
||||
" with BiasAdd must have one extra argument: bias.");
|
||||
}
|
||||
if (*fused_computation == FusedComputationType::kBiasAddWithLeakyRelu) {
|
||||
TF_RETURN_IF_ERROR(context->GetAttr(
|
||||
"leakyrelu_alpha", &fused_computation_args->leakyrelu_alpha));
|
||||
}
|
||||
}
|
||||
|
||||
if (*fused_computation == FusedComputationType::kFusedBatchNorm ||
|
||||
*fused_computation == FusedComputationType::kFusedBatchNormWithRelu ||
|
||||
*fused_computation == FusedComputationType::kFusedBatchNormWithRelu6 ||
|
||||
*fused_computation == FusedComputationType::kFusedBatchNormWithElu) {
|
||||
*fused_computation == FusedComputationType::kFusedBatchNormWithElu ||
|
||||
*fused_computation ==
|
||||
FusedComputationType::kFusedBatchNormWithLeakyRelu) {
|
||||
if (num_args != 4) {
|
||||
return errors::InvalidArgument(
|
||||
"Fused ", kernel_name,
|
||||
@ -80,6 +87,11 @@ Status InitializeFusedComputation(
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
context->GetAttr("epsilon", &fused_computation_args->epsilon));
|
||||
if (*fused_computation ==
|
||||
FusedComputationType::kFusedBatchNormWithLeakyRelu) {
|
||||
TF_RETURN_IF_ERROR(context->GetAttr(
|
||||
"leakyrelu_alpha", &fused_computation_args->leakyrelu_alpha));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
@ -26,10 +26,10 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_KERNELS_FUSED_EIGEN_OUTPUT_KERNELS_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_FUSED_EIGEN_OUTPUT_KERNELS_H_
|
||||
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor_types.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
@ -39,15 +39,18 @@ enum class FusedComputationType {
|
||||
kBiasAddWithRelu,
|
||||
kBiasAddWithRelu6,
|
||||
kBiasAddWithElu,
|
||||
kBiasAddWithLeakyRelu,
|
||||
kFusedBatchNorm,
|
||||
kFusedBatchNormWithRelu,
|
||||
kFusedBatchNormWithRelu6,
|
||||
kFusedBatchNormWithElu
|
||||
kFusedBatchNormWithElu,
|
||||
kFusedBatchNormWithLeakyRelu
|
||||
};
|
||||
|
||||
// We have to pass around additional arguments for all possible fusion types.
// Each field is only meaningful for the fusion named in its comment; both
// default to zero.
struct FusedComputationArgs {
  float epsilon = 0.0;          // Used by `FusedBatchNorm` fusion only
  float leakyrelu_alpha = 0.0;  // Used by `LeakyRelu` fusion only
};
|
||||
|
||||
struct FusedComputationPattern {
|
||||
@ -111,15 +114,32 @@ struct Elu {
|
||||
};
|
||||
};
|
||||
|
||||
// Applies `LeakyRelu` to the passed input expression.
//
// Elements of `expr` that are strictly less than zero are multiplied by
// `leakyrelu_alpha` (cast to the expression's scalar type); all other elements
// are passed through unchanged.
struct LeakyRelu {
  template <typename XprType>
  static auto apply(XprType expr, const float leakyrelu_alpha) -> decltype(
      (expr < std::declval<typename XprType::Scalar>())
          .select(expr *
                      expr.constant(std::declval<typename XprType::Scalar>()),
                  expr)) {
    using Scalar = typename XprType::Scalar;
    const Scalar zero = static_cast<Scalar>(0);
    const Scalar alpha = static_cast<Scalar>(leakyrelu_alpha);
    // select(then, else): take the scaled value where the mask is set.
    return (expr < zero).select(expr * expr.constant(alpha), expr);
  }
};
|
||||
|
||||
template <typename T>
|
||||
struct BiasAddArgs {
|
||||
const T* bias_add_data = nullptr;
|
||||
float leakyrelu_alpha;
|
||||
|
||||
static bool IsSupported(FusedComputationType fusion) {
|
||||
return fusion == FusedComputationType::kBiasAdd ||
|
||||
fusion == FusedComputationType::kBiasAddWithRelu ||
|
||||
fusion == FusedComputationType::kBiasAddWithRelu6 ||
|
||||
fusion == FusedComputationType::kBiasAddWithElu;
|
||||
fusion == FusedComputationType::kBiasAddWithElu ||
|
||||
fusion == FusedComputationType::kBiasAddWithLeakyRelu;
|
||||
}
|
||||
};
|
||||
|
||||
@ -134,11 +154,14 @@ struct FusedBatchNormArgs {
|
||||
// scaling_factor = (estimated_variance + epsilon).rsqrt() * scale
|
||||
Eigen::Tensor<T, 1, Eigen::RowMajor> scaling_factor;
|
||||
|
||||
float leakyrelu_alpha;
|
||||
|
||||
static bool IsSupported(FusedComputationType fusion) {
|
||||
return fusion == FusedComputationType::kFusedBatchNorm ||
|
||||
fusion == FusedComputationType::kFusedBatchNormWithRelu ||
|
||||
fusion == FusedComputationType::kFusedBatchNormWithRelu6 ||
|
||||
fusion == FusedComputationType::kFusedBatchNormWithElu;
|
||||
fusion == FusedComputationType::kFusedBatchNormWithElu ||
|
||||
fusion == FusedComputationType::kFusedBatchNormWithLeakyRelu;
|
||||
}
|
||||
};
|
||||
|
||||
@ -203,6 +226,34 @@ struct BiasAddOutputKernel {
|
||||
const T* bias_data;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct BiasAddOutputKernel<T, LeakyRelu> {
|
||||
explicit BiasAddOutputKernel(const BiasAddArgs<T>& args)
|
||||
: bias_data(args.bias_add_data), leakyrelu_alpha(args.leakyrelu_alpha) {}
|
||||
|
||||
template <typename StorageIndex, typename Scalar>
|
||||
EIGEN_ALWAYS_INLINE void operator()(
|
||||
const ContractionOutputMapper<Scalar, StorageIndex>& output_mapper,
|
||||
const Eigen::TensorContractionParams& params, StorageIndex i,
|
||||
StorageIndex j, StorageIndex num_rows, StorageIndex num_cols) const {
|
||||
DCHECK(params.swapped_arguments);
|
||||
|
||||
const T* bias_base = bias_data + i;
|
||||
typename TTypes<T>::UnalignedConstTensor bias(bias_base, num_rows);
|
||||
|
||||
for (int col = 0; col < num_cols; ++col) {
|
||||
T* output_base = &output_mapper(0, col);
|
||||
typename TTypes<T>::UnalignedTensor output(output_base, num_rows);
|
||||
const auto expr = output + bias;
|
||||
output = LeakyRelu::template apply<decltype(expr)>(expr, leakyrelu_alpha);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
const T* bias_data;
|
||||
float leakyrelu_alpha;
|
||||
};
|
||||
|
||||
// Output kernel that fuses FusedBatchNorm operation into the output of tensor
|
||||
// contraction + activation function defined by Activation.
|
||||
template <typename T, typename Activation = Identity>
|
||||
@ -247,6 +298,51 @@ struct FusedBatchNormOutputKernel {
|
||||
const T* estimated_mean_data;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct FusedBatchNormOutputKernel<T, LeakyRelu> {
|
||||
FusedBatchNormOutputKernel(T epsilon, const FusedBatchNormArgs<T>& args)
|
||||
: epsilon(epsilon),
|
||||
scaling_factor_data(args.scaling_factor.data()),
|
||||
offset_data(args.offset_data),
|
||||
estimated_mean_data(args.estimated_mean_data),
|
||||
leakyrelu_alpha(args.leakyrelu_alpha) {}
|
||||
|
||||
template <typename StorageIndex, typename Scalar>
|
||||
EIGEN_ALWAYS_INLINE void operator()(
|
||||
const ContractionOutputMapper<Scalar, StorageIndex>& output_mapper,
|
||||
const Eigen::TensorContractionParams& params, StorageIndex i,
|
||||
StorageIndex j, StorageIndex num_rows, StorageIndex num_cols) const {
|
||||
DCHECK(params.swapped_arguments);
|
||||
|
||||
const T* scaling_factor_base = scaling_factor_data + i;
|
||||
const T* offset_base = offset_data + i;
|
||||
const T* mean_base = estimated_mean_data + i;
|
||||
|
||||
typename TTypes<T>::UnalignedConstTensor scaling_factor(scaling_factor_base,
|
||||
num_rows);
|
||||
typename TTypes<T>::UnalignedConstTensor offset(offset_base, num_rows);
|
||||
typename TTypes<T>::UnalignedConstTensor mean(mean_base, num_rows);
|
||||
|
||||
for (int col = 0; col < num_cols; ++col) {
|
||||
T* output_base = &output_mapper(0, col);
|
||||
typename TTypes<T>::UnalignedTensor output(output_base, num_rows);
|
||||
|
||||
auto scaled = (output - mean) * scaling_factor;
|
||||
auto shifted = scaled + offset;
|
||||
|
||||
output = LeakyRelu::template apply<decltype(shifted)>(shifted,
|
||||
leakyrelu_alpha);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
T epsilon;
|
||||
const T* scaling_factor_data;
|
||||
const T* offset_data;
|
||||
const T* estimated_mean_data;
|
||||
float leakyrelu_alpha;
|
||||
};
|
||||
|
||||
// Type aliases for the output kernels, purely for the sake of better launch
|
||||
// dispatching code readability.
|
||||
template <typename T>
|
||||
@ -258,6 +354,8 @@ using WithBiasAddAndRelu6 = BiasAddOutputKernel<T, Relu6>;
|
||||
template <typename T>
|
||||
using WithBiasAddAndElu = BiasAddOutputKernel<T, Elu>;
|
||||
template <typename T>
|
||||
using WithBiasAddAndLeakyRelu = BiasAddOutputKernel<T, LeakyRelu>;
|
||||
template <typename T>
|
||||
using WithFusedBatchNorm = FusedBatchNormOutputKernel<T>;
|
||||
template <typename T>
|
||||
using WithFusedBatchNormAndRelu = FusedBatchNormOutputKernel<T, Relu>;
|
||||
@ -265,9 +363,12 @@ template <typename T>
|
||||
using WithFusedBatchNormAndRelu6 = FusedBatchNormOutputKernel<T, Relu6>;
|
||||
template <typename T>
|
||||
using WithFusedBatchNormAndElu = FusedBatchNormOutputKernel<T, Elu>;
|
||||
template <typename T>
|
||||
using WithFusedBatchNormAndLeakyRelu = FusedBatchNormOutputKernel<T, LeakyRelu>;
|
||||
|
||||
template <typename T>
|
||||
Status InitBiasAddArgs(OpKernelContext* context, BiasAddArgs<T>* args) {
|
||||
Status InitBiasAddArgs(OpKernelContext* context, BiasAddArgs<T>* args,
|
||||
const float* leakyrelu_alpha = nullptr) {
|
||||
// Bias of the following dimensions: [ output_depth ]
|
||||
const Tensor& bias = context->input(2);
|
||||
|
||||
@ -281,12 +382,17 @@ Status InitBiasAddArgs(OpKernelContext* context, BiasAddArgs<T>* args) {
|
||||
|
||||
args->bias_add_data = data_ptr(bias);
|
||||
|
||||
if (leakyrelu_alpha) {
|
||||
args->leakyrelu_alpha = *leakyrelu_alpha;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status InitFusedBatchNormArgs(OpKernelContext* context, float epsilon,
|
||||
FusedBatchNormArgs<T>* args) {
|
||||
FusedBatchNormArgs<T>* args,
|
||||
const float* leakyrelu_alpha = nullptr) {
|
||||
const Tensor& scale = context->input(2);
|
||||
const Tensor& offset = context->input(3);
|
||||
const Tensor& estimated_mean = context->input(4);
|
||||
@ -319,6 +425,10 @@ Status InitFusedBatchNormArgs(OpKernelContext* context, float epsilon,
|
||||
(estimated_variance.flat<T>() + static_cast<T>(epsilon)).rsqrt() *
|
||||
scale.flat<T>();
|
||||
|
||||
if (leakyrelu_alpha) {
|
||||
args->leakyrelu_alpha = *leakyrelu_alpha;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -404,6 +404,8 @@ REGISTER_OP("_FusedConv2D")
|
||||
.Attr("fused_ops: list(string) = []")
|
||||
// Attributes for the FusedBatchNorm ------------------------------------ //
|
||||
.Attr("epsilon: float = 0.0001")
|
||||
// Attributes for the LeakyRelu ----------------------------------------- //
|
||||
.Attr("leakyrelu_alpha: float = 0.2")
|
||||
// ---------------------------------------------------------------------- //
|
||||
.SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
|
||||
.Doc(R"doc(
|
||||
@ -633,7 +635,10 @@ REGISTER_OP("_FusedDepthwiseConv2dNative")
|
||||
.Attr("fused_ops: list(string) = []")
|
||||
// Attributes for the FusedBatchNorm ------------------------------------ //
|
||||
.Attr("epsilon: float = 0.0001")
|
||||
// Attributes for the LeakyRelu ----------------------------------------- //
|
||||
.Attr("leakyrelu_alpha: float = 0.2")
|
||||
// ---------------------------------------------------------------------- //
|
||||
|
||||
.SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
|
||||
|
||||
// --------------------------------------------------------------------------
|
||||
|
Loading…
Reference in New Issue
Block a user