Merge pull request #42173 from Intel-tensorflow:yunfeimao/matmul_tanh_fusion
PiperOrigin-RevId: 331209886 Change-Id: Ic244af866e6c6cc0967558b51e3e146ee6c3abd3
This commit is contained in:
commit
cccbb3dd4b
@ -581,6 +581,8 @@ bool IsSymbolicGradient(const NodeDef& node) {
|
||||
return node.op() == "SymbolicGradient";
|
||||
}
|
||||
|
||||
bool IsTanh(const NodeDef& node) { return node.op() == "Tanh"; }
|
||||
|
||||
bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }
|
||||
|
||||
bool IsTensorArray(const NodeDef& node) {
|
||||
|
@ -189,6 +189,7 @@ bool IsSub(const NodeDef& node);
|
||||
bool IsSum(const NodeDef& node);
|
||||
bool IsSwitch(const NodeDef& node);
|
||||
bool IsSymbolicGradient(const NodeDef& node);
|
||||
bool IsTanh(const NodeDef& node);
|
||||
bool IsTanhGrad(const NodeDef& node);
|
||||
bool IsTensorArray(const NodeDef& node);
|
||||
bool IsTile(const NodeDef& node);
|
||||
|
@ -365,12 +365,12 @@ bool IsDeviceCompatible(const RemapperContext& ctx, Pattern& matched) {
|
||||
}
|
||||
|
||||
bool IsSupportedActivation(const NodeDef& node) {
|
||||
// Disable LeakyRelu temporarily before MKL PR is merged.
|
||||
#ifndef INTEL_MKL
|
||||
return IsRelu(node) || IsRelu6(node) || IsElu(node) || IsLeakyRelu(node);
|
||||
#ifdef INTEL_MKL
|
||||
return IsRelu(node) || IsRelu6(node) || IsElu(node) || IsTanh(node);
|
||||
#else
|
||||
return IsRelu(node) || IsRelu6(node) || IsElu(node);
|
||||
#endif // !INTEL_MKL
|
||||
// Disable LeakyRelu temporarily before MKL PR is merged.
|
||||
return IsRelu(node) || IsRelu6(node) || IsElu(node) || IsLeakyRelu(node);
|
||||
#endif
|
||||
}
|
||||
|
||||
inline bool HasControlFaninOrFanout(const utils::MutableNodeView& node_view) {
|
||||
@ -464,6 +464,9 @@ bool FindContractionWithBiasAndActivation(
|
||||
bias_add_node_view->GetRegularFanin(0).node_view();
|
||||
const auto* contraction_node_def = contraction_node_view->node();
|
||||
|
||||
// Currently, only matmul + bias + tanh is enable
|
||||
if (!IsMatMul(*contraction_node_def) && IsTanh(*node_def)) return false;
|
||||
|
||||
// Currently, only conv + bias + leakyrelu is enabled
|
||||
if (!IsConv2D(*contraction_node_def) && IsLeakyRelu(*node_def)) return false;
|
||||
|
||||
@ -716,6 +719,9 @@ bool FindContractionWithBiasAndAddActivation(
|
||||
if (node_def == nullptr) return false;
|
||||
if (!IsSupportedActivation(*node_def)) return false;
|
||||
|
||||
// Currently, Contraction + Bias + Add + Tanh pattern is not supported
|
||||
if (IsTanh(*node_def)) return false;
|
||||
|
||||
#ifdef ENABLE_INTEL_MKL_BFLOAT16
|
||||
// MKL activation op only supports float and bfloat16 data types.
|
||||
if (!HasDataType(node_def, DT_FLOAT) && !HasDataType(node_def, DT_BFLOAT16))
|
||||
|
@ -416,6 +416,7 @@ tf_cc_test_mkl(
|
||||
"//tensorflow/core/kernels:pad_op",
|
||||
"//tensorflow/core/kernels:relu_op",
|
||||
"//tensorflow/core/kernels/image:image",
|
||||
"//tensorflow/core:tensorflow",
|
||||
] + MKL_TEST_DEPS,
|
||||
)
|
||||
|
||||
|
@ -876,6 +876,12 @@ class MklFusedMatMulOpTest : public OpsTestBase {
|
||||
next_op = ops::Elu(root.WithOpName(last_op), next_op);
|
||||
}
|
||||
|
||||
if (std::find(fused_ops.begin(), fused_ops.end(), "Tanh") !=
|
||||
fused_ops.end()) {
|
||||
last_op = "with_tanh";
|
||||
next_op = ops::Tanh(root.WithOpName(last_op), next_op);
|
||||
}
|
||||
|
||||
CommonTestUtilities<T>::RunAndFetch(root, last_op, output);
|
||||
};
|
||||
|
||||
@ -963,11 +969,21 @@ TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndElu) {
|
||||
{"BiasAdd", "Elu"});
|
||||
}
|
||||
|
||||
TYPED_TEST_P(MklFusedMatMulOpTest, WithBiasAndTanh) {
|
||||
const int batch = 3;
|
||||
const int input_channel = 4;
|
||||
const int output_channel = 5;
|
||||
|
||||
this->VerifyFusedMatMul(batch, input_channel, output_channel,
|
||||
{"BiasAdd", "Tanh"});
|
||||
}
|
||||
|
||||
REGISTER_TYPED_TEST_SUITE_P(MklFusedMatMulOpTest, //
|
||||
WithBias, //
|
||||
WithBiasAndRelu, //
|
||||
WithBiasAndRelu6, //
|
||||
WithBiasAndElu);
|
||||
WithBiasAndElu, //
|
||||
WithBiasAndTanh);
|
||||
|
||||
using MklFusedMatMulDataTypes = ::testing::Types<float>;
|
||||
INSTANTIATE_TYPED_TEST_SUITE_P(Test, MklFusedMatMulOpTest,
|
||||
|
@ -226,6 +226,8 @@ class MklFusedMatMulOp : public MklDnnMatMulOpBase<T, T> {
|
||||
params.post_op_params.push_back({"relu6", {1.0, 6.0, 0.0}});
|
||||
} else if (post_op == "Elu") {
|
||||
params.post_op_params.push_back({"elu", {1.0, 1.0, 0.0}});
|
||||
} else if (post_op == "Tanh") {
|
||||
params.post_op_params.push_back({"tanh", {1.0, 0.0, 0.0}});
|
||||
} else {
|
||||
OP_REQUIRES_OK(
|
||||
ctx, errors::InvalidArgument(
|
||||
|
@ -247,6 +247,13 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
|
||||
float op_beta = post_op_param.param[2];
|
||||
post_ops.append_eltwise(op_scale, ALGORITHM::eltwise_elu, op_alpha,
|
||||
op_beta);
|
||||
} else if (post_op_param.name == "tanh") {
|
||||
DCHECK_EQ(post_op_param.param.size(), 3);
|
||||
float op_scale = post_op_param.param[0];
|
||||
float op_alpha = post_op_param.param[1];
|
||||
float op_beta = post_op_param.param[2];
|
||||
post_ops.append_eltwise(op_scale, ALGORITHM::eltwise_tanh, op_alpha,
|
||||
op_beta);
|
||||
} else if (post_op_param.name == "output_scale") {
|
||||
DCHECK_EQ(post_op_param.param.size(), 1);
|
||||
std::vector<float> scales;
|
||||
@ -256,6 +263,7 @@ class MklDnnMatMulFwdPrimitive : public MklPrimitive {
|
||||
DCHECK((post_op_param.name == "relu") ||
|
||||
(post_op_param.name == "relu6") ||
|
||||
(post_op_param.name == "elu") ||
|
||||
(post_op_param.name == "tanh") ||
|
||||
(post_op_param.name == "output_scale"));
|
||||
}
|
||||
}
|
||||
@ -359,11 +367,12 @@ class MklDnnMatMulFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
|
||||
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.bias_dims);
|
||||
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dst_dims);
|
||||
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.dtypes);
|
||||
key_creator.AddAsKey(mkldnn_matmul_fwd_dims.weight_format);
|
||||
|
||||
// Generate keys for post-ops
|
||||
for (auto const& post_op_param : mkldnn_matmul_fwd_dims.post_op_params) {
|
||||
if (post_op_param.name == "relu" || post_op_param.name == "relu6" ||
|
||||
post_op_param.name == "elu") {
|
||||
post_op_param.name == "elu" || post_op_param.name == "tanh") {
|
||||
DCHECK_EQ(post_op_param.param.size(), 3);
|
||||
key_creator.AddAsKey(post_op_param.name);
|
||||
key_creator.AddAsKey(post_op_param.param[0]);
|
||||
|
Loading…
x
Reference in New Issue
Block a user