Merge pull request #42173 from Intel-tensorflow:yunfeimao/matmul_tanh_fusion

PiperOrigin-RevId: 331209886
Change-Id: Ic244af866e6c6cc0967558b51e3e146ee6c3abd3
TensorFlower Gardener committed 2020-09-11 13:38:01 -07:00
commit cccbb3dd4b
7 changed files with 44 additions and 7 deletions

View File

@@ -581,6 +581,8 @@ bool IsSymbolicGradient(const NodeDef& node) {
return node.op() == "SymbolicGradient";
}
bool IsTanh(const NodeDef& node) { return node.op() == "Tanh"; }
bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }
bool IsTensorArray(const NodeDef& node) {

View File

@@ -189,6 +189,7 @@ bool IsSub(const NodeDef& node);
bool IsSum(const NodeDef& node);
bool IsSwitch(const NodeDef& node);
bool IsSymbolicGradient(const NodeDef& node);
bool IsTanh(const NodeDef& node);
bool IsTanhGrad(const NodeDef& node);
bool IsTensorArray(const NodeDef& node);
bool IsTile(const NodeDef& node);
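These predicates match only on the node's op string, so they can be exercised on a bare NodeDef. A minimal sketch (not part of the commit; the main() wrapper is purely illustrative):

```cpp
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/op_types.h"

// Sketch: the grappler predicates inspect NodeDef::op() only, so a bare
// NodeDef with the op name set is enough to exercise them.
int main() {
  tensorflow::NodeDef node;
  node.set_op("Tanh");
  const bool is_tanh = tensorflow::grappler::IsTanh(node);           // true
  const bool is_tanh_grad = tensorflow::grappler::IsTanhGrad(node);  // false
  return (is_tanh && !is_tanh_grad) ? 0 : 1;
}
```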

View File

@@ -365,12 +365,12 @@ bool IsDeviceCompatible(const RemapperContext& ctx, Pattern& matched) {
}
bool IsSupportedActivation(const NodeDef& node) {
// Disable LeakyRelu temporarily before MKL PR is merged.
#ifndef INTEL_MKL
return IsRelu(node) || IsRelu6(node) || IsElu(node) || IsLeakyRelu(node);
#ifdef INTEL_MKL
return IsRelu(node) || IsRelu6(node) || IsElu(node) || IsTanh(node);
#else
return IsRelu(node) || IsRelu6(node) || IsElu(node);
#endif // !INTEL_MKL
// Disable LeakyRelu temporarily before MKL PR is merged.
return IsRelu(node) || IsRelu6(node) || IsElu(node) || IsLeakyRelu(node);
#endif
}
inline bool HasControlFaninOrFanout(const utils::MutableNodeView& node_view) {
@@ -464,6 +464,9 @@ bool FindContractionWithBiasAndActivation(
bias_add_node_view->GetRegularFanin(0).node_view();
const auto* contraction_node_def = contraction_node_view->node();
// Currently, only matmul + bias + tanh is enabled
if (!IsMatMul(*contraction_node_def) && IsTanh(*node_def)) return false;
// Currently, only conv + bias + leakyrelu is enabled
if (!IsConv2D(*contraction_node_def) && IsLeakyRelu(*node_def)) return false;
@@ -716,6 +719,9 @@ bool FindContractionWithBiasAndAddActivation(
if (node_def == nullptr) return false;
if (!IsSupportedActivation(*node_def)) return false;
// Currently, Contraction + Bias + Add + Tanh pattern is not supported
if (IsTanh(*node_def)) return false;
#ifdef ENABLE_INTEL_MKL_BFLOAT16
// MKL activation op only supports float and bfloat16 data types.
if (!HasDataType(node_def, DT_FLOAT) && !HasDataType(node_def, DT_BFLOAT16))
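The pattern these remapper changes target is the MatMul → BiasAdd → Tanh chain, which on MKL builds is expected to be rewritten into a single _FusedMatMul node carrying fused_ops = {"BiasAdd", "Tanh"}. A minimal sketch of that unfused subgraph using the C++ client ops API (scope wiring and node names here are illustrative, not taken from the commit):

```cpp
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"

// Sketch of the unfused subgraph the new remapper check targets:
//   MatMul -> BiasAdd -> Tanh
// With INTEL_MKL enabled, the remapper is expected to collapse this chain
// into one _FusedMatMul node with fused_ops = {"BiasAdd", "Tanh"}.
void BuildMatMulBiasTanh(const tensorflow::Scope& root) {
  using namespace tensorflow;
  auto x = ops::Placeholder(root.WithOpName("x"), DT_FLOAT);
  auto w = ops::Placeholder(root.WithOpName("w"), DT_FLOAT);
  auto b = ops::Placeholder(root.WithOpName("b"), DT_FLOAT);
  auto matmul = ops::MatMul(root.WithOpName("matmul"), x, w);
  auto bias_add = ops::BiasAdd(root.WithOpName("bias_add"), matmul, b);
  auto activation = ops::Tanh(root.WithOpName("activation"), bias_add);
}
```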

View File

@@ -416,6 +416,7 @@ tf_cc_test_mkl(
"//tensorflow/core/kernels:pad_op",
"//tensorflow/core/kernels:relu_op",
"//tensorflow/core/kernels/image:image",
"//tensorflow/core:tensorflow",
] + MKL_TEST_DEPS,
)

View File

@@ -876,6 +876,12 @@ class MklFusedMatMulOpTest : public OpsTestBase {
next_op = ops::Elu(root.WithOpName(last_op), next_op);
}
if (std::find(fused_ops.begin(), fused_ops.end(), "Tanh") !=
fused_ops.end()) {
last_op = "with_tanh";
next_op = ops::Tanh(root.WithOpName(last_op), next_op);
}
CommonTestUtilities<T>::RunAndFetch(root, last_op, output);
};
@@ -963,11 +969,21 @@ TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndElu) {
{"BiasAdd", "Elu"});
}
TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndTanh) {
const int batch = 3;
const int input_channel = 4;
const int output_channel = 5;
this->VerifyFusedMatMul(batch, input_channel, output_channel,
{"BiasAdd", "Tanh"});
}
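The new WithBiasAndTanh case relies on VerifyFusedMatMul to compare the fused kernel against an unfused reference graph. The math the fused path must reproduce is simply Y = tanh(X * W + b); a standalone Eigen-based sketch of that reference, for illustration only (not part of the test harness):

```cpp
#include <Eigen/Dense>

// Reference math for the {"BiasAdd", "Tanh"} fusion (sketch):
//   Y = tanh(X * W + b), with the bias broadcast over the batch (row) dimension.
Eigen::MatrixXf MatMulBiasTanhReference(const Eigen::MatrixXf& X,
                                        const Eigen::MatrixXf& W,
                                        const Eigen::VectorXf& b) {
  Eigen::MatrixXf y = X * W;         // [batch, output_channel]
  y.rowwise() += b.transpose();      // BiasAdd
  return y.array().tanh().matrix();  // Tanh post-op
}
```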
REGISTER_TYPED_TEST_SUITE_P(MklFusedMatMulOpTest, //
WithBias, //
WithBiasAndRelu, //
WithBiasAndRelu6, //
WithBiasAndElu);
WithBiasAndElu, //
WithBiasAndTanh);
using MklFusedMatMulDataTypes = ::testing::Types<float>;
INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedMatMulOpTest,

View File

@@ -226,6 +226,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
params.post_op_params.push_back({"relu6", {1.0, 6.0, 0.0}});
} else if (post_op == "Elu") {
params.post_op_params.push_back({"elu", {1.0, 1.0, 0.0}});
} else if (post_op == "Tanh") {
params.post_op_params.push_back({"tanh", {1.0, 0.0, 0.0}});
} else {
OP_REQUIRES_OK(
ctx, errors::InvalidArgument(

View File

@@ -247,6 +247,13 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
float op_beta = post_op_param.param[2];
post_ops.append_eltwise(op_scale, ALGORITHM::eltwise_elu, op_alpha,
op_beta);
} else if (post_op_param.name == "tanh") {
DCHECK_EQ(post_op_param.param.size(), 3);
float op_scale = post_op_param.param[0];
float op_alpha = post_op_param.param[1];
float op_beta = post_op_param.param[2];
post_ops.append_eltwise(op_scale, ALGORITHM::eltwise_tanh, op_alpha,
op_beta);
} else if (post_op_param.name == "output_scale") {
DCHECK_EQ(post_op_param.param.size(), 1);
std::vector<float> scales;
@@ -256,6 +263,7 @@ class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
DCHECK((post_op_param.name == "relu") ||
(post_op_param.name == "relu6") ||
(post_op_param.name == "elu") ||
(post_op_param.name == "tanh") ||
(post_op_param.name == "output_scale"));
}
}
@@ -359,11 +367,12 @@ class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.bias_dims);
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dst_dims);
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dtypes);
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_format);
// Generate keys for post-ops
for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) {
if (post_op_param.name == "relu" || post_op_param.name == "relu6" ||
post_op_param.name == "elu") {
post_op_param.name == "elu" || post_op_param.name == "tanh") {
DCHECK_EQ(post_op_param.param.size(), 3);
key_creator.AddAsKey(post_op_param.name);
key_creator.AddAsKey(post_op_param.param[0]);
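On the oneDNN (formerly MKL-DNN) side, each {"tanh", {scale, alpha, beta}} entry becomes an eltwise post-op appended to the matmul primitive's attributes, and the same name and parameters are folded into the primitive cache key so a tanh-fused primitive is never reused for a relu- or elu-fused one. A minimal standalone sketch, assuming the oneDNN 1.x-era append_eltwise signature this code path uses (the ALGORITHM macro in the TF source maps to dnnl::algorithm; the helper name is illustrative):

```cpp
#include "dnnl.hpp"

// Sketch: attach a tanh eltwise post-op to a primitive_attr the way the
// fused-matmul primitive does. For eltwise_tanh the alpha/beta values are
// ignored, which matches the {1.0, 0.0, 0.0} triple pushed by the kernel.
dnnl::primitive_attr MakeAttrWithTanhPostOp(float scale, float alpha, float beta) {
  dnnl::post_ops ops;
  ops.append_eltwise(scale, dnnl::algorithm::eltwise_tanh, alpha, beta);
  dnnl::primitive_attr attr;
  attr.set_post_ops(ops);
  return attr;
}
```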