From 199d79d32e068efbb22031d1063b6f651f5a5f8d Mon Sep 17 00:00:00 2001
From: AG Ramesh
Date: Fri, 30 Aug 2019 13:24:19 -0700
Subject: [PATCH] Add support for AddV2

---
 tensorflow/core/graph/mkl_graph_util.h        |  1 +
 tensorflow/core/graph/mkl_layout_pass.cc      |  6 ++
 tensorflow/core/graph/mkl_layout_pass_test.cc | 59 +++++++++++++++++++
 .../core/kernels/mkl_cwise_ops_common.cc      |  2 +
 tensorflow/core/ops/math_ops.cc               | 16 +++++
 5 files changed, 84 insertions(+)

diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h
index cb4afabcb07..1cce1bc425b 100644
--- a/tensorflow/core/graph/mkl_graph_util.h
+++ b/tensorflow/core/graph/mkl_graph_util.h
@@ -206,6 +206,7 @@ static inline bool IsMklElementWiseOp(const string& op_name, DataType T) {
     return false;
   }
   bool result = (0 == op_name.compare(GetMklOpName("Add")) ||
+                 0 == op_name.compare(GetMklOpName("AddV2")) ||
                  0 == op_name.compare(GetMklOpName("Sub")) ||
                  0 == op_name.compare(GetMklOpName("Mul")) ||
                  0 == op_name.compare(GetMklOpName("Maximum")) ||
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 41f519e8de2..287ff2a6750 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -350,6 +350,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the
     // MklInputConversion op is added before it.
     csinfo_.add = "Add";
+    csinfo_.add_v2 = "AddV2";
     csinfo_.maximum = "Maximum";
     csinfo_.mul = "Mul";
     csinfo_.squared_difference = "SquaredDifference";
@@ -364,6 +365,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
                       CopyAttrsAll, RewriteIfAtleastOneMklInput,
                       kRewriteForLayoutPropagation});
+    rinfo_.push_back({csinfo_.add_v2,
+                      mkl_op_registry::GetMklOpName(csinfo_.add_v2),
+                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
+                      kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.avg_pool, mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
          CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
@@ -867,6 +872,7 @@ rinfo_.push_back({csinfo_.tanh_grad,
   typedef struct {
     string addn;
     string add;
+    string add_v2;
     string avg_pool;
     string avg_pool_grad;
     string avg_pool3d;
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index b69a30e8274..9d57f1eddfb 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -3776,6 +3776,65 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Slice_DeviceTest) {
             "B->D:1;C->D:2;D->E:1");
 }
 
+// The following positive and negative tests check the rewrite of Add and
+// AddV2 to their MKL versions. These operators are rewritten only if at
+// least one of the inputs comes from another MKL operator.
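+//
+// In the expected strings, "name(op)" lists a node and "src->dst:port" an
+// edge: "B->N:1" feeds output 0 of B into input 1 of N, and "M:1->N:2"
+// feeds output 1 of M into input 2 of N. The DMT/_* Const nodes are the
+// dummy MKL tensors the pass wires up for inputs that carry no MKL
+// metadata; the ":control" edges are the control dependencies it adds
+// alongside them.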
+TEST_F(MklLayoutPassTest, PositiveRewriteAdd) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'M' op: 'Relu'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A']}"
+      "node { name: 'N' op: 'Add'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['M', 'B']}");
+  EXPECT_EQ(
+      DoMklLayoutOptimizationPass(),
+      "A(Input);B(Input);DMT/_0(Const);DMT/_1(Const);M(_MklRelu);N(_MklAdd)"
+      "|A->M;A:control->DMT/_0:control;B->N:1;DMT/_0->M:1;DMT/_1->N:3;M->N;"
+      "M:1->N:2;M:control->DMT/_1:control");
+}
+
+TEST_F(MklLayoutPassTest, NegativeRewriteAdd) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'N' op: 'Add'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'B']}");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);N(Add)|A->N;B->N:1");
+}
+
+TEST_F(MklLayoutPassTest, PositiveRewriteAddV2) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'M' op: 'Relu'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A']}"
+      "node { name: 'N' op: 'AddV2'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['M', 'B']}");
+  EXPECT_EQ(
+      DoMklLayoutOptimizationPass(),
+      "A(Input);B(Input);DMT/_0(Const);DMT/_1(Const);M(_MklRelu);N(_MklAddV2)"
+      "|A->M;A:control->DMT/_0:control;B->N:1;DMT/_0->M:1;DMT/_1->N:3;M->N;"
+      "M:1->N:2;M:control->DMT/_1:control");
+}
+
+TEST_F(MklLayoutPassTest, NegativeRewriteAddV2) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'N' op: 'AddV2'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'B']}");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);N(AddV2)|A->N;B->N:1");
+}
+
 /////////////////////////////////////////////////////////////////////
 //         Post-rewrite fixup pass test
 /////////////////////////////////////////////////////////////////////
diff --git a/tensorflow/core/kernels/mkl_cwise_ops_common.cc b/tensorflow/core/kernels/mkl_cwise_ops_common.cc
index 9c6a9c48bab..e332d530e3a 100644
--- a/tensorflow/core/kernels/mkl_cwise_ops_common.cc
+++ b/tensorflow/core/kernels/mkl_cwise_ops_common.cc
@@ -70,6 +70,8 @@ class MklBinaryOp : public BinaryOp {
 
 REGISTER6(MklBinaryOp, CPU, "_MklAdd", functor::add, float, Eigen::half,
           double, int32, int64, bfloat16);
+REGISTER6(MklBinaryOp, CPU, "_MklAddV2", functor::add, float, Eigen::half,
+          double, int32, int64, bfloat16);
 REGISTER8(MklBinaryOp, CPU, "_MklSub", functor::sub, float, Eigen::half,
           double, int32, int64, complex64, complex128, bfloat16);
 REGISTER6(MklBinaryOp, CPU, "_MklMul", functor::mul, float, Eigen::half, double,
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 06e53517724..665b71025f3 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -397,6 +397,7 @@ REGISTER_OP("AddV2")
     .SetIsAggregate()
     .SetIsCommutative();
 
+#ifdef INTEL_MKL
 REGISTER_OP("_MklAdd")
     .Input("x: T")
     .Input("y: T")
@@ -415,6 +416,21 @@ Returns `x` + `y` element-wise.
 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
)doc"); +REGISTER_OP("_MklAddV2") + .Input("x: T") + .Input("y: T") + .Input("mkl_x: uint8") + .Input("mkl_y: uint8") + .Output("z: T") + .Output("mkl_z: uint8") + .Attr( + "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, " + "complex64, complex128}") + .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) + .SetIsAggregate() + .SetIsCommutative(); +#endif // INTEL_MKL + REGISTER_OP("Sub").BINARY_MORE().SetShapeFn( shape_inference::BroadcastBinaryOpShapeFn);
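
Note (not part of the patch): a quick end-to-end sanity check for this rewrite
is to run a small graph on an MKL-enabled build and look for an _MklAddV2 node
in the partition graphs. The sketch below is a minimal illustration using the
TF 1.x Python API; it assumes a build configured with --config=mkl and that
tf.raw_ops is available, and it uses tf.raw_ops.AddV2 to pin the op type
(whether tf.math.add emits Add or AddV2 depends on the TF version).

    import numpy as np
    import tensorflow as tf  # TF 1.x graph-mode API assumed

    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.float32, shape=[2, 2], name="A")
        b = tf.placeholder(tf.float32, shape=[2, 2], name="B")
        m = tf.nn.relu(a, name="M")  # rewritten to _MklRelu by the layout pass
        # One input comes from an MKL op, so N is a rewrite candidate.
        n = tf.raw_ops.AddV2(x=m, y=b, name="N")

        # Ask the runtime for the graphs actually executed, post-rewrite.
        opts = tf.RunOptions(output_partition_graphs=True)
        meta = tf.RunMetadata()
        with tf.Session() as sess:
            feed = {a: np.ones((2, 2), np.float32),
                    b: np.ones((2, 2), np.float32)}
            sess.run(n, feed_dict=feed, options=opts, run_metadata=meta)

    ops = {node.op for pg in meta.partition_graphs for node in pg.node}
    print("rewritten:", "_MklAddV2" in ops)  # expect True on an MKL build

The negative case falls out of the same sketch: feed both inputs of AddV2 from
plain placeholders and the node should stay as AddV2, matching the
NegativeRewriteAddV2 test above.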