diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h
index 2ea20d01225..ba2526b2c35 100644
--- a/tensorflow/core/graph/mkl_graph_util.h
+++ b/tensorflow/core/graph/mkl_graph_util.h
@@ -206,6 +206,7 @@ static inline bool IsMklElementWiseOp(const string& op_name, DataType T) {
     return false;
   }
   bool result = (0 == op_name.compare(GetMklOpName("Add")) ||
+                 0 == op_name.compare(GetMklOpName("AddV2")) ||
                  0 == op_name.compare(GetMklOpName("Sub")) ||
                  0 == op_name.compare(GetMklOpName("Mul")) ||
                  0 == op_name.compare(GetMklOpName("Maximum")) ||
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 1487200b4e3..8cb65926a4f 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -350,6 +350,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the
     // MklInputConversion op is added before it.
     csinfo_.add = "Add";
+    csinfo_.add_v2 = "AddV2";
     csinfo_.maximum = "Maximum";
     csinfo_.mul = "Mul";
     csinfo_.squared_difference = "SquaredDifference";
@@ -364,6 +365,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
                       CopyAttrsAll, RewriteIfAtleastOneMklInput,
                       kRewriteForLayoutPropagation});
+    rinfo_.push_back({csinfo_.add_v2,
+                      mkl_op_registry::GetMklOpName(csinfo_.add_v2),
+                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
+                      kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.avg_pool, mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
          CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
@@ -869,6 +874,7 @@ rinfo_.push_back({csinfo_.tanh_grad,
   typedef struct {
     string addn;
     string add;
+    string add_v2;
     string avg_pool;
     string avg_pool_grad;
     string avg_pool3d;
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index b69a30e8274..9d57f1eddfb 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -3776,6 +3776,65 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Slice_DeviceTest) {
             "B->D:1;C->D:2;D->E:1");
 }
 
+// The following positive and negative tests exercise the rewrite of Add and
+// AddV2 to MKL versions. The operators will be rewritten only if one of the
+// inputs comes from another MKL operator.
+TEST_F(MklLayoutPassTest, PositiveRewriteAdd) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'M' op: 'Relu'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A']}"
+      "node { name: 'N' op: 'Add'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['M', 'B']}");
+  EXPECT_EQ(
+      DoMklLayoutOptimizationPass(),
+      "A(Input);B(Input);DMT/_0(Const);DMT/_1(Const);M(_MklRelu);N(_MklAdd)"
+      "|A->M;A:control->DMT/_0:control;B->N:1;DMT/_0->M:1;DMT/_1->N:3;M->N;"
+      "M:1->N:2;M:control->DMT/_1:control");
+}
+
+TEST_F(MklLayoutPassTest, NegativeRewriteAdd) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'N' op: 'Add'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'B']}");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);N(Add)|A->N;B->N:1");
+}
+
+TEST_F(MklLayoutPassTest, PositiveRewriteAddV2) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'M' op: 'Relu'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A']}"
+      "node { name: 'N' op: 'AddV2'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['M', 'B']}");
+  EXPECT_EQ(
+      DoMklLayoutOptimizationPass(),
+      "A(Input);B(Input);DMT/_0(Const);DMT/_1(Const);M(_MklRelu);N(_MklAddV2)"
+      "|A->M;A:control->DMT/_0:control;B->N:1;DMT/_0->M:1;DMT/_1->N:3;M->N;"
+      "M:1->N:2;M:control->DMT/_1:control");
+}
+
+TEST_F(MklLayoutPassTest, NegativeRewriteAddV2) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'N' op: 'AddV2'"
+      " attr { key: 'T' value { type: DT_FLOAT } }"
+      " input: ['A', 'B']}");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);N(AddV2)|A->N;B->N:1");
+}
+
 /////////////////////////////////////////////////////////////////////
 // Post-rewrite fixup pass test
 /////////////////////////////////////////////////////////////////////
diff --git a/tensorflow/core/kernels/mkl_cwise_ops_common.cc b/tensorflow/core/kernels/mkl_cwise_ops_common.cc
index 9c6a9c48bab..e332d530e3a 100644
--- a/tensorflow/core/kernels/mkl_cwise_ops_common.cc
+++ b/tensorflow/core/kernels/mkl_cwise_ops_common.cc
@@ -70,6 +70,8 @@ class MklBinaryOp : public BinaryOp<Device, Functor> {
 
 REGISTER6(MklBinaryOp, CPU, "_MklAdd", functor::add, float, Eigen::half,
           double, int32, int64, bfloat16);
+REGISTER6(MklBinaryOp, CPU, "_MklAddV2", functor::add, float, Eigen::half,
+          double, int32, int64, bfloat16);
 REGISTER8(MklBinaryOp, CPU, "_MklSub", functor::sub, float, Eigen::half,
           double, int32, int64, complex64, complex128, bfloat16);
 REGISTER6(MklBinaryOp, CPU, "_MklMul", functor::mul, float, Eigen::half, double,
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 5d25e92bae0..b453a8a534d 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -66,8 +66,8 @@ REGISTER_OP("AddN")
       } else if (shapes_and_types && shapes_and_types_i) {
         if (shapes_and_types_i->size() != shapes_and_types->size()) {
           return errors::InvalidArgument(
-              "shapes_and_types[", i,
-              "].size() == ", shapes_and_types_i->size(),
+              "shapes_and_types[", i, "].size() == ",
+              shapes_and_types_i->size(),
               " != shapes_and_types[0].size() == ",
               shapes_and_types->size());
         }
@@ -395,6 +395,7 @@ REGISTER_OP("AddV2")
     .SetIsAggregate()
     .SetIsCommutative();
 
+#ifdef INTEL_MKL
 REGISTER_OP("_MklAdd")
     .Input("x: T")
     .Input("y: T")
@@ -413,6 +414,21 @@ Returns `x` + `y` element-wise.
 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
 )doc");
 
+REGISTER_OP("_MklAddV2")
+    .Input("x: T")
+    .Input("y: T")
+    .Input("mkl_x: uint8")
+    .Input("mkl_y: uint8")
+    .Output("z: T")
+    .Output("mkl_z: uint8")
+    .Attr(
+        "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
+        "complex64, complex128}")
+    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
+    .SetIsAggregate()
+    .SetIsCommutative();
+#endif  // INTEL_MKL
+
 REGISTER_OP("Sub").BINARY_MORE().SetShapeFn(
     shape_inference::BroadcastBinaryOpShapeFn);
 
@@ -1366,12 +1382,12 @@ Status RangeSize(const Tensor* start_t, const Tensor* limit_t,
   T limit = limit_t->scalar<T>()();
   T delta = delta_t->scalar<T>()();
   if (start > limit && delta > 0) {
-    return errors::InvalidArgument(
-        "Requires start <= limit when delta > 0: ", start, "/", limit);
+    return errors::InvalidArgument("Requires start <= limit when delta > 0: ",
+                                   start, "/", limit);
   }
   if (start < limit && delta < 0) {
-    return errors::InvalidArgument(
-        "Requires start >= limit when delta < 0: ", start, "/", limit);
+    return errors::InvalidArgument("Requires start >= limit when delta < 0: ",
+                                   start, "/", limit);
   }
   if (delta == 0) {
     return errors::InvalidArgument("Requires delta != 0");
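
Reviewer note: the crux of this patch is the RewriteIfAtleastOneMklInput gate registered for both Add and AddV2. It fires only when at least one producer of the node has already been rewritten to an MKL op, which is exactly the split between the positive and negative test pairs above. Below is a minimal standalone C++ sketch of that gating logic; Node, IsMklOpName, and ShouldRewrite are hypothetical stand-ins for illustration, not code from mkl_layout_pass.cc, and the "_Mkl" name-prefix check is an assumed convention.

// Simplified sketch of "rewrite only if at least one input is an MKL op".
// Node / IsMklOpName / ShouldRewrite are hypothetical, not TensorFlow types.
#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string op;                   // e.g. "Input", "_MklRelu", "AddV2"
  std::vector<const Node*> inputs;  // data inputs only
};

// Assumed convention: ops rewritten by the layout pass carry an "_Mkl" prefix.
bool IsMklOpName(const std::string& op) { return op.rfind("_Mkl", 0) == 0; }

// The gate: rewrite Add/AddV2 only when some producer is already an MKL op.
bool ShouldRewrite(const Node& n) {
  for (const Node* in : n.inputs) {
    if (IsMklOpName(in->op)) return true;
  }
  return false;
}

int main() {
  Node a{"Input", {}}, b{"Input", {}};
  Node relu{"_MklRelu", {&a}};         // already rewritten upstream
  Node add_pos{"AddV2", {&relu, &b}};  // one MKL input -> rewrite
  Node add_neg{"AddV2", {&a, &b}};     // no MKL input -> leave as-is
  std::cout << ShouldRewrite(add_pos) << ShouldRewrite(add_neg) << "\n";  // 10
}

Under these assumptions, add_pos mirrors the PositiveRewriteAddV2 graph (Relu already rewritten to _MklRelu, so AddV2 becomes _MklAddV2) and add_neg mirrors NegativeRewriteAddV2, where AddV2 is left untouched.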