Remove .SetIsCommutative() from ops that are not actually commutative w.r.t. all their inputs.

PiperOrigin-RevId: 244423062
This commit is contained in:
A. Unique TensorFlower 2019-04-19 14:28:52 -07:00 committed by TensorFlower Gardener
parent ee82131dbc
commit 75156b5981
2 changed files with 16 additions and 11 deletions

View File

@@ -27,6 +27,13 @@ REGISTER_OP("Invert")
.SetShapeFn(shape_inference::UnchangedShape);
#define BINARY_BITWISE() \
Input("x: T") \
.Input("y: T") \
.Output("z: T") \
.Attr("T: {int8, int16, int32, int64, uint8, uint16, uint32, uint64}") \
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
#define BINARY_BITWISE_COMMUTATIVE() \
Input("x: T") \
.Input("y: T") \
.Output("z: T") \
@@ -40,11 +47,11 @@ REGISTER_OP("PopulationCount")
.Attr("T: {int8, int16, int32, int64, uint8, uint16, uint32, uint64}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("BitwiseAnd").BINARY_BITWISE();
REGISTER_OP("BitwiseAnd").BINARY_BITWISE_COMMUTATIVE();
REGISTER_OP("BitwiseOr").BINARY_BITWISE();
REGISTER_OP("BitwiseOr").BINARY_BITWISE_COMMUTATIVE();
REGISTER_OP("BitwiseXor").BINARY_BITWISE();
REGISTER_OP("BitwiseXor").BINARY_BITWISE_COMMUTATIVE();
REGISTER_OP("LeftShift").BINARY_BITWISE();

View File

@@ -447,15 +447,14 @@ REGISTER_OP("MulNoNan")
.Input("y: T")
.Output("z: T")
.Attr("T: {half, float, double, complex64, complex128}")
.SetIsCommutative()
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
// Note: This op is not commutative w.r.t. all its inputs.
REGISTER_OP("_MklMul")
.BINARY_MORE()
.Input("mkl_x: uint8")
.Input("mkl_y: uint8")
.Output("mkl_z: uint8")
.SetIsCommutative()
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Doc(R"doc(
Returns x * y element-wise.
@@ -490,12 +489,12 @@ REGISTER_OP("SquaredDifference")
.SetIsCommutative()
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
// Note: This op is not commutative w.r.t. all its inputs.
REGISTER_OP("_MklSquaredDifference")
.BINARY_FEWER()
.Input("mkl_x: uint8")
.Input("mkl_y: uint8")
.Output("mkl_z: uint8")
.SetIsCommutative()
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Doc(R"doc(
Returns (x - y)(x - y) element-wise.
@@ -529,6 +528,7 @@ REGISTER_OP("Maximum")
.SetIsCommutative()
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
// Note: This op is not commutative w.r.t. all its inputs.
REGISTER_OP("_MklMaximum")
.Input("x: T")
.Input("y: T")
@@ -537,7 +537,6 @@ REGISTER_OP("_MklMaximum")
.Output("z: T")
.Output("mkl_z: uint8")
.Attr("T: {half, float, double, int32, int64, bfloat16}")
.SetIsCommutative()
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Doc(R"doc(
Returns the max of x and y (i.e. x > y ? x : y) element-wise.
@@ -1619,6 +1618,7 @@ REGISTER_OP("QuantizedMatMul")
return Status::OK();
});
// Note: This op is not commutative w.r.t. all its inputs.
REGISTER_OP("QuantizedMul")
.Input("x: T1")
.Input("y: T2")
@@ -1632,7 +1632,6 @@ REGISTER_OP("QuantizedMul")
.Attr("T1: quantizedtype")
.Attr("T2: quantizedtype")
.Attr("Toutput: quantizedtype = DT_QINT32")
.SetIsCommutative()
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
c->set_output(1, c->Scalar());
@@ -1640,6 +1639,7 @@ REGISTER_OP("QuantizedMul")
return Status::OK();
});
// Note: This op is not commutative w.r.t. all its inputs.
REGISTER_OP("QuantizedAdd")
.Input("x: T1")
.Input("y: T2")
@@ -1653,7 +1653,6 @@ REGISTER_OP("QuantizedAdd")
.Attr("T1: quantizedtype")
.Attr("T2: quantizedtype")
.Attr("Toutput: quantizedtype = DT_QINT32")
.SetIsCommutative()
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
// min_x, max_x, min_y, max_y should be scalar.
@@ -1770,6 +1769,7 @@ REGISTER_OP("ClipByValue")
.SetShapeFn(shape_inference::UnchangedShape);
#ifdef INTEL_MKL
// Note: This op is not commutative w.r.t. all its inputs.
REGISTER_OP("_MklAddN")
.Input("inputs: N * T")
.Input("mkl_input: N * uint8")
@@ -1777,8 +1777,6 @@ REGISTER_OP("_MklAddN")
.Output("mkl_sum: uint8")
.Attr("N: int >= 1")
.Attr("T: numbertype")
.SetIsCommutative()
.SetIsAggregate()
.SetShapeFn([](InferenceContext* c) {
ShapeHandle cur = c->input(c->num_inputs() - 1);
for (int i = c->num_inputs() - 2; i >= 0; --i) {