From c02f9763589624224e441caf6375bf39d2bf75c2 Mon Sep 17 00:00:00 2001
From: AG Ramesh <ag.ramesh@intel.com>
Date: Fri, 30 Aug 2019 13:24:19 -0700
Subject: [PATCH] Add support for AddV2

---
 tensorflow/core/graph/mkl_graph_util.h        |  1 +
 tensorflow/core/graph/mkl_layout_pass.cc      |  6 ++
 tensorflow/core/graph/mkl_layout_pass_test.cc | 59 +++++++++++++++++++
 .../core/kernels/mkl_cwise_ops_common.cc      |  2 +
 tensorflow/core/ops/math_ops.cc               | 28 +++++++--
 5 files changed, 90 insertions(+), 6 deletions(-)

diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h
index 2ea20d01225..ba2526b2c35 100644
--- a/tensorflow/core/graph/mkl_graph_util.h
+++ b/tensorflow/core/graph/mkl_graph_util.h
@@ -206,6 +206,7 @@ static inline bool IsMklElementWiseOp(const string& op_name, DataType T) {
     return false;
   }
   bool result = (0 == op_name.compare(GetMklOpName("Add")) ||
+                 0 == op_name.compare(GetMklOpName("AddV2")) ||
                  0 == op_name.compare(GetMklOpName("Sub")) ||
                  0 == op_name.compare(GetMklOpName("Mul")) ||
                  0 == op_name.compare(GetMklOpName("Maximum")) ||
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 1487200b4e3..8cb65926a4f 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -350,6 +350,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     // in the MklUtil.h (IsMklElementWiseOp method) to ensure that the
     // MklInputConversion op is added before it.
     csinfo_.add = "Add";
+    csinfo_.add_v2 = "AddV2";
     csinfo_.maximum = "Maximum";
     csinfo_.mul = "Mul";
     csinfo_.squared_difference = "SquaredDifference";
@@ -364,6 +365,10 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
     rinfo_.push_back({csinfo_.add, mkl_op_registry::GetMklOpName(csinfo_.add),
                       CopyAttrsAll, RewriteIfAtleastOneMklInput,
                       kRewriteForLayoutPropagation});
+    rinfo_.push_back({csinfo_.add_v2,
+                      mkl_op_registry::GetMklOpName(csinfo_.add_v2),
+                      CopyAttrsAll, RewriteIfAtleastOneMklInput,
+                      kRewriteForLayoutPropagation});
     rinfo_.push_back(
         {csinfo_.avg_pool, mkl_op_registry::GetMklOpName(csinfo_.avg_pool),
          CopyAttrsAll, AlwaysRewrite, kRewriteForLayoutPropagation});
@@ -869,6 +874,7 @@ rinfo_.push_back({csinfo_.tanh_grad,
   typedef struct {
     string addn;
     string add;
+    string add_v2;
     string avg_pool;
     string avg_pool_grad;
     string avg_pool3d;
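Each rinfo_ entry above pairs an op name with its MKL counterpart, an
attribute-copying function, a rewrite predicate, and a rewrite cause. Add and
AddV2 both use RewriteIfAtleastOneMklInput, so a node is rewritten only when
MKL layout can actually propagate into it. A standalone sketch of that policy
(FakeNode and its flags are hypothetical stand-ins, not TensorFlow types):

    #include <vector>
    // Hypothetical stand-in for a graph node: one flag per data input,
    // true when that input is produced by an MKL op.
    struct FakeNode {
      std::vector<bool> input_is_mkl;
    };
    // Rewrite only if at least one input already carries MKL layout;
    // converting a node with no MKL inputs would buy nothing.
    static bool RewriteIfAtleastOneMklInput(const FakeNode& n) {
      for (bool is_mkl : n.input_is_mkl)
        if (is_mkl) return true;
      return false;
    }
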
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index b69a30e8274..9d57f1eddfb 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -3776,6 +3776,65 @@ TEST_F(MklLayoutPassTest, NodeRewrite_Slice_DeviceTest) {
             "B->D:1;C->D:2;D->E:1");
 }
 
+// The following positive and negative tests check the rewrite of Add and
+// AddV2 to their MKL versions. An operator is rewritten only if at least
+// one of its inputs comes from another MKL operator.
+TEST_F(MklLayoutPassTest, PositiveRewriteAdd) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'M' op: 'Relu'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " input: ['A']}"
+      "node { name: 'N' op: 'Add'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " input: ['M', 'B']}");
+  EXPECT_EQ(
+      DoMklLayoutOptimizationPass(),
+      "A(Input);B(Input);DMT/_0(Const);DMT/_1(Const);M(_MklRelu);N(_MklAdd)"
+      "|A->M;A:control->DMT/_0:control;B->N:1;DMT/_0->M:1;DMT/_1->N:3;M->N;"
+      "M:1->N:2;M:control->DMT/_1:control");
+}
+
+TEST_F(MklLayoutPassTest, NegativeRewriteAdd) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'N' op: 'Add'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " input: ['A', 'B']}");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);N(Add)|A->N;B->N:1");
+}
+
+TEST_F(MklLayoutPassTest, PositiveRewriteAddV2) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'M' op: 'Relu'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " input: ['A']}"
+      "node { name: 'N' op: 'AddV2'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " input: ['M', 'B']}");
+  EXPECT_EQ(
+      DoMklLayoutOptimizationPass(),
+      "A(Input);B(Input);DMT/_0(Const);DMT/_1(Const);M(_MklRelu);N(_MklAddV2)"
+      "|A->M;A:control->DMT/_0:control;B->N:1;DMT/_0->M:1;DMT/_1->N:3;M->N;"
+      "M:1->N:2;M:control->DMT/_1:control");
+}
+
+TEST_F(MklLayoutPassTest, NegativeRewriteAddV2) {
+  InitGraph(
+      "node { name: 'A' op: 'Input'}"
+      "node { name: 'B' op: 'Input'}"
+      "node { name: 'N' op: 'AddV2'"
+      " attr { key: 'T'                value { type: DT_FLOAT } }"
+      " input: ['A', 'B']}");
+  EXPECT_EQ(DoMklLayoutOptimizationPass(),
+            "A(Input);B(Input);N(AddV2)|A->N;B->N:1");
+}
+
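The expected-graph strings in these tests encode the MKL metadata plumbing:
a rewritten op consumes one uint8 metadata tensor per data tensor, data
inputs first, matching the _MklAddV2 registration later in this patch. In
PositiveRewriteAddV2, N(_MklAddV2) therefore takes M on port 0, B on port 1,
M's metadata output on port 2 (M:1->N:2), and a dummy metadata constant for
the non-MKL input B on port 3 (DMT/_1->N:3).

    // Input port ordering assumed by the expected strings above:
    // data tensors first, then their metadata tensors in the same order.
    enum MklAddV2InputPort { kX = 0, kY = 1, kMklX = 2, kMklY = 3 };
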
 /////////////////////////////////////////////////////////////////////
 //         Post-rewrite fixup pass test
 /////////////////////////////////////////////////////////////////////
diff --git a/tensorflow/core/kernels/mkl_cwise_ops_common.cc b/tensorflow/core/kernels/mkl_cwise_ops_common.cc
index 9c6a9c48bab..e332d530e3a 100644
--- a/tensorflow/core/kernels/mkl_cwise_ops_common.cc
+++ b/tensorflow/core/kernels/mkl_cwise_ops_common.cc
@@ -70,6 +70,8 @@ class MklBinaryOp : public BinaryOp<Device, Functor> {
 
 REGISTER6(MklBinaryOp, CPU, "_MklAdd", functor::add, float, Eigen::half, double,
           int32, int64, bfloat16);
+REGISTER6(MklBinaryOp, CPU, "_MklAddV2", functor::add, float, Eigen::half,
+          double, int32, int64, bfloat16);
 REGISTER8(MklBinaryOp, CPU, "_MklSub", functor::sub, float, Eigen::half, double,
           int32, int64, complex64, complex128, bfloat16);
 REGISTER6(MklBinaryOp, CPU, "_MklMul", functor::mul, float, Eigen::half, double,
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 5d25e92bae0..b453a8a534d 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -66,8 +66,8 @@ REGISTER_OP("AddN")
           } else if (shapes_and_types && shapes_and_types_i) {
             if (shapes_and_types_i->size() != shapes_and_types->size()) {
               return errors::InvalidArgument(
-                  "shapes_and_types[", i,
-                  "].size() == ", shapes_and_types_i->size(),
+                  "shapes_and_types[", i, "].size() == ",
+                  shapes_and_types_i->size(),
                   " != shapes_and_types[0].size() == ",
                   shapes_and_types->size());
             }
@@ -395,6 +395,7 @@ REGISTER_OP("AddV2")
     .SetIsAggregate()
     .SetIsCommutative();
 
+#ifdef INTEL_MKL
 REGISTER_OP("_MklAdd")
     .Input("x: T")
     .Input("y: T")
@@ -413,6 +414,21 @@ Returns `x` + `y` element-wise.
 [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
 )doc");
 
+REGISTER_OP("_MklAddV2")
+    .Input("x: T")
+    .Input("y: T")
+    .Input("mkl_x: uint8")
+    .Input("mkl_y: uint8")
+    .Output("z: T")
+    .Output("mkl_z: uint8")
+    .Attr(
+        "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
+        "complex64, complex128}")
+    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
+    .SetIsAggregate()
+    .SetIsCommutative();
+#endif  // INTEL_MKL
+
 REGISTER_OP("Sub").BINARY_MORE().SetShapeFn(
     shape_inference::BroadcastBinaryOpShapeFn);
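The _MklAdd, _MklAddV2, and Sub registrations above all route shape inference
through shape_inference::BroadcastBinaryOpShapeFn, i.e. NumPy-style
broadcasting for z = x + y. A self-contained sketch of that rule (dimensions
matched from the right, a size-1 dimension stretching to the other operand's
size; TensorFlow's real shape function also handles unknown dims, which this
sketch omits):

    #include <cassert>
    #include <vector>
    std::vector<int> BroadcastShape(const std::vector<int>& a,
                                    const std::vector<int>& b) {
      std::vector<int> out;
      auto ia = a.rbegin(), ib = b.rbegin();
      while (ia != a.rend() || ib != b.rend()) {
        int da = (ia != a.rend()) ? *ia++ : 1;   // missing dims act as 1
        int db = (ib != b.rend()) ? *ib++ : 1;
        assert(da == db || da == 1 || db == 1);  // else: incompatible
        out.insert(out.begin(), da == 1 ? db : da);
      }
      return out;
    }
    // BroadcastShape({2, 1, 3}, {4, 3}) == {2, 4, 3}
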
 
@@ -1366,12 +1382,12 @@ Status RangeSize(const Tensor* start_t, const Tensor* limit_t,
   T limit = limit_t->scalar<T>()();
   T delta = delta_t->scalar<T>()();
   if (start > limit && delta > 0) {
-    return errors::InvalidArgument(
-        "Requires start <= limit when delta > 0: ", start, "/", limit);
+    return errors::InvalidArgument("Requires start <= limit when delta > 0: ",
+                                   start, "/", limit);
   }
   if (start < limit && delta < 0) {
-    return errors::InvalidArgument(
-        "Requires start >= limit when delta < 0: ", start, "/", limit);
+    return errors::InvalidArgument("Requires start >= limit when delta < 0: ",
+                                   start, "/", limit);
   }
   if (delta == 0) {
     return errors::InvalidArgument("Requires delta != 0");