From 75156b5981f01f5e492a6ee7642aa975a2ca67fb Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 19 Apr 2019 14:28:52 -0700
Subject: [PATCH] Remove .SetIsCommutative() from ops that are not actually
 commutative w.r.t. all their inputs.

PiperOrigin-RevId: 244423062
---
 tensorflow/core/ops/bitwise_ops.cc | 13 ++++++++++---
 tensorflow/core/ops/math_ops.cc    | 14 ++++++--------
 2 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/tensorflow/core/ops/bitwise_ops.cc b/tensorflow/core/ops/bitwise_ops.cc
index 39acf5f358b..8d04d97fd1e 100644
--- a/tensorflow/core/ops/bitwise_ops.cc
+++ b/tensorflow/core/ops/bitwise_ops.cc
@@ -27,6 +27,13 @@ REGISTER_OP("Invert")
     .SetShapeFn(shape_inference::UnchangedShape);
 
 #define BINARY_BITWISE()                                                     \
+  Input("x: T")                                                              \
+      .Input("y: T")                                                         \
+      .Output("z: T")                                                        \
+      .Attr("T: {int8, int16, int32, int64, uint8, uint16, uint32, uint64}") \
+      .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
+
+#define BINARY_BITWISE_COMMUTATIVE()                                         \
   Input("x: T")                                                              \
       .Input("y: T")                                                         \
       .Output("z: T")                                                        \
@@ -40,11 +47,11 @@ REGISTER_OP("PopulationCount")
     .Attr("T: {int8, int16, int32, int64, uint8, uint16, uint32, uint64}")
     .SetShapeFn(shape_inference::UnchangedShape);
 
-REGISTER_OP("BitwiseAnd").BINARY_BITWISE();
+REGISTER_OP("BitwiseAnd").BINARY_BITWISE_COMMUTATIVE();
 
-REGISTER_OP("BitwiseOr").BINARY_BITWISE();
+REGISTER_OP("BitwiseOr").BINARY_BITWISE_COMMUTATIVE();
 
-REGISTER_OP("BitwiseXor").BINARY_BITWISE();
+REGISTER_OP("BitwiseXor").BINARY_BITWISE_COMMUTATIVE();
 
 REGISTER_OP("LeftShift").BINARY_BITWISE();
 
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index dd94e66cc0f..849836b4c5d 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -447,15 +447,14 @@ REGISTER_OP("MulNoNan")
     .Input("y: T")
     .Output("z: T")
     .Attr("T: {half, float, double, complex64, complex128}")
-    .SetIsCommutative()
     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
 
+// Note: This op is not commutative w.r.t. to all its inputs.
 REGISTER_OP("_MklMul")
     .BINARY_MORE()
     .Input("mkl_x: uint8")
     .Input("mkl_y: uint8")
     .Output("mkl_z: uint8")
-    .SetIsCommutative()
     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
     .Doc(R"doc(
 Returns x * y element-wise.
@@ -490,12 +489,12 @@ REGISTER_OP("SquaredDifference")
     .SetIsCommutative()
     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
 
+// Note: This op is not commutative w.r.t. to all its inputs.
 REGISTER_OP("_MklSquaredDifference")
     .BINARY_FEWER()
     .Input("mkl_x: uint8")
     .Input("mkl_y: uint8")
     .Output("mkl_z: uint8")
-    .SetIsCommutative()
     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
     .Doc(R"doc(
 Returns (x - y)(x - y) element-wise.
@@ -529,6 +528,7 @@ REGISTER_OP("Maximum")
     .SetIsCommutative()
     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
 
+// Note: This op is not commutative w.r.t. to all its inputs.
 REGISTER_OP("_MklMaximum")
     .Input("x: T")
     .Input("y: T")
@@ -537,7 +537,6 @@ REGISTER_OP("_MklMaximum")
     .Output("z: T")
     .Output("mkl_z: uint8")
     .Attr("T: {half, float, double, int32, int64, bfloat16}")
-    .SetIsCommutative()
     .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
     .Doc(R"doc(
 Returns the max of x and y (i.e. x > y ? x : y) element-wise.
@@ -1619,6 +1618,7 @@ REGISTER_OP("QuantizedMatMul")
       return Status::OK();
     });
 
+// Note: This op is not commutative w.r.t. to all its inputs.
 REGISTER_OP("QuantizedMul")
     .Input("x: T1")
     .Input("y: T2")
@@ -1632,7 +1632,6 @@ REGISTER_OP("QuantizedMul")
     .Attr("T1: quantizedtype")
     .Attr("T2: quantizedtype")
     .Attr("Toutput: quantizedtype = DT_QINT32")
-    .SetIsCommutative()
     .SetShapeFn([](InferenceContext* c) {
       TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
       c->set_output(1, c->Scalar());
@@ -1640,6 +1639,7 @@ REGISTER_OP("QuantizedMul")
       return Status::OK();
     });
 
+// Note: This op is not commutative w.r.t. to all its inputs.
 REGISTER_OP("QuantizedAdd")
     .Input("x: T1")
     .Input("y: T2")
@@ -1653,7 +1653,6 @@ REGISTER_OP("QuantizedAdd")
     .Attr("T1: quantizedtype")
     .Attr("T2: quantizedtype")
     .Attr("Toutput: quantizedtype = DT_QINT32")
-    .SetIsCommutative()
     .SetShapeFn([](InferenceContext* c) {
       TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
       // min_x, max_x, min_y, max_y should be scalar.
@@ -1770,6 +1769,7 @@ REGISTER_OP("ClipByValue")
     .SetShapeFn(shape_inference::UnchangedShape);
 
 #ifdef INTEL_MKL
+// Note: This op is not commutative w.r.t. to all its inputs.
 REGISTER_OP("_MklAddN")
     .Input("inputs: N * T")
     .Input("mkl_input: N * uint8")
@@ -1777,8 +1777,6 @@ REGISTER_OP("_MklAddN")
     .Output("mkl_sum: uint8")
     .Attr("N: int >= 1")
     .Attr("T: numbertype")
-    .SetIsCommutative()
-    .SetIsAggregate()
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle cur = c->input(c->num_inputs() - 1);
       for (int i = c->num_inputs() - 2; i >= 0; --i) {