diff --git a/tensorflow/core/ops/bitwise_ops.cc b/tensorflow/core/ops/bitwise_ops.cc index 39acf5f358b..8d04d97fd1e 100644 --- a/tensorflow/core/ops/bitwise_ops.cc +++ b/tensorflow/core/ops/bitwise_ops.cc @@ -27,6 +27,13 @@ REGISTER_OP("Invert") .SetShapeFn(shape_inference::UnchangedShape); #define BINARY_BITWISE() \ + Input("x: T") \ + .Input("y: T") \ + .Output("z: T") \ + .Attr("T: {int8, int16, int32, int64, uint8, uint16, uint32, uint64}") \ + .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) + +#define BINARY_BITWISE_COMMUTATIVE() \ Input("x: T") \ .Input("y: T") \ .Output("z: T") \ @@ -40,11 +47,11 @@ REGISTER_OP("PopulationCount") .Attr("T: {int8, int16, int32, int64, uint8, uint16, uint32, uint64}") .SetShapeFn(shape_inference::UnchangedShape); -REGISTER_OP("BitwiseAnd").BINARY_BITWISE(); +REGISTER_OP("BitwiseAnd").BINARY_BITWISE_COMMUTATIVE(); -REGISTER_OP("BitwiseOr").BINARY_BITWISE(); +REGISTER_OP("BitwiseOr").BINARY_BITWISE_COMMUTATIVE(); -REGISTER_OP("BitwiseXor").BINARY_BITWISE(); +REGISTER_OP("BitwiseXor").BINARY_BITWISE_COMMUTATIVE(); REGISTER_OP("LeftShift").BINARY_BITWISE(); diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index dd94e66cc0f..849836b4c5d 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -447,15 +447,14 @@ REGISTER_OP("MulNoNan") .Input("y: T") .Output("z: T") .Attr("T: {half, float, double, complex64, complex128}") - .SetIsCommutative() .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); +// Note: This op is not commutative w.r.t. to all its inputs. REGISTER_OP("_MklMul") .BINARY_MORE() .Input("mkl_x: uint8") .Input("mkl_y: uint8") .Output("mkl_z: uint8") - .SetIsCommutative() .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) .Doc(R"doc( Returns x * y element-wise. @@ -490,12 +489,12 @@ REGISTER_OP("SquaredDifference") .SetIsCommutative() .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); +// Note: This op is not commutative w.r.t. to all its inputs. REGISTER_OP("_MklSquaredDifference") .BINARY_FEWER() .Input("mkl_x: uint8") .Input("mkl_y: uint8") .Output("mkl_z: uint8") - .SetIsCommutative() .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) .Doc(R"doc( Returns (x - y)(x - y) element-wise. @@ -529,6 +528,7 @@ REGISTER_OP("Maximum") .SetIsCommutative() .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); +// Note: This op is not commutative w.r.t. to all its inputs. REGISTER_OP("_MklMaximum") .Input("x: T") .Input("y: T") @@ -537,7 +537,6 @@ REGISTER_OP("_MklMaximum") .Output("z: T") .Output("mkl_z: uint8") .Attr("T: {half, float, double, int32, int64, bfloat16}") - .SetIsCommutative() .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn) .Doc(R"doc( Returns the max of x and y (i.e. x > y ? x : y) element-wise. @@ -1619,6 +1618,7 @@ REGISTER_OP("QuantizedMatMul") return Status::OK(); }); +// Note: This op is not commutative w.r.t. to all its inputs. REGISTER_OP("QuantizedMul") .Input("x: T1") .Input("y: T2") @@ -1632,7 +1632,6 @@ REGISTER_OP("QuantizedMul") .Attr("T1: quantizedtype") .Attr("T2: quantizedtype") .Attr("Toutput: quantizedtype = DT_QINT32") - .SetIsCommutative() .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c)); c->set_output(1, c->Scalar()); @@ -1640,6 +1639,7 @@ REGISTER_OP("QuantizedMul") return Status::OK(); }); +// Note: This op is not commutative w.r.t. to all its inputs. REGISTER_OP("QuantizedAdd") .Input("x: T1") .Input("y: T2") @@ -1653,7 +1653,6 @@ REGISTER_OP("QuantizedAdd") .Attr("T1: quantizedtype") .Attr("T2: quantizedtype") .Attr("Toutput: quantizedtype = DT_QINT32") - .SetIsCommutative() .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c)); // min_x, max_x, min_y, max_y should be scalar. @@ -1770,6 +1769,7 @@ REGISTER_OP("ClipByValue") .SetShapeFn(shape_inference::UnchangedShape); #ifdef INTEL_MKL +// Note: This op is not commutative w.r.t. to all its inputs. REGISTER_OP("_MklAddN") .Input("inputs: N * T") .Input("mkl_input: N * uint8") @@ -1777,8 +1777,6 @@ REGISTER_OP("_MklAddN") .Output("mkl_sum: uint8") .Attr("N: int >= 1") .Attr("T: numbertype") - .SetIsCommutative() - .SetIsAggregate() .SetShapeFn([](InferenceContext* c) { ShapeHandle cur = c->input(c->num_inputs() - 1); for (int i = c->num_inputs() - 2; i >= 0; --i) {