Remove .SetIsCommutative() from ops that are not actually commutative w.r.t. all their inputs.
PiperOrigin-RevId: 244423062
This commit is contained in:
parent
ee82131dbc
commit
75156b5981
@ -27,6 +27,13 @@ REGISTER_OP("Invert")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
#define BINARY_BITWISE() \
|
||||
Input("x: T") \
|
||||
.Input("y: T") \
|
||||
.Output("z: T") \
|
||||
.Attr("T: {int8, int16, int32, int64, uint8, uint16, uint32, uint64}") \
|
||||
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
|
||||
|
||||
#define BINARY_BITWISE_COMMUTATIVE() \
|
||||
Input("x: T") \
|
||||
.Input("y: T") \
|
||||
.Output("z: T") \
|
||||
@ -40,11 +47,11 @@ REGISTER_OP("PopulationCount")
|
||||
.Attr("T: {int8, int16, int32, int64, uint8, uint16, uint32, uint64}")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
REGISTER_OP("BitwiseAnd").BINARY_BITWISE();
|
||||
REGISTER_OP("BitwiseAnd").BINARY_BITWISE_COMMUTATIVE();
|
||||
|
||||
REGISTER_OP("BitwiseOr").BINARY_BITWISE();
|
||||
REGISTER_OP("BitwiseOr").BINARY_BITWISE_COMMUTATIVE();
|
||||
|
||||
REGISTER_OP("BitwiseXor").BINARY_BITWISE();
|
||||
REGISTER_OP("BitwiseXor").BINARY_BITWISE_COMMUTATIVE();
|
||||
|
||||
REGISTER_OP("LeftShift").BINARY_BITWISE();
|
||||
|
||||
|
@ -447,15 +447,14 @@ REGISTER_OP("MulNoNan")
|
||||
.Input("y: T")
|
||||
.Output("z: T")
|
||||
.Attr("T: {half, float, double, complex64, complex128}")
|
||||
.SetIsCommutative()
|
||||
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
|
||||
|
||||
// Note: This op is not commutative w.r.t. to all its inputs.
|
||||
REGISTER_OP("_MklMul")
|
||||
.BINARY_MORE()
|
||||
.Input("mkl_x: uint8")
|
||||
.Input("mkl_y: uint8")
|
||||
.Output("mkl_z: uint8")
|
||||
.SetIsCommutative()
|
||||
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
|
||||
.Doc(R"doc(
|
||||
Returns x * y element-wise.
|
||||
@ -490,12 +489,12 @@ REGISTER_OP("SquaredDifference")
|
||||
.SetIsCommutative()
|
||||
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
|
||||
|
||||
// Note: This op is not commutative w.r.t. to all its inputs.
|
||||
REGISTER_OP("_MklSquaredDifference")
|
||||
.BINARY_FEWER()
|
||||
.Input("mkl_x: uint8")
|
||||
.Input("mkl_y: uint8")
|
||||
.Output("mkl_z: uint8")
|
||||
.SetIsCommutative()
|
||||
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
|
||||
.Doc(R"doc(
|
||||
Returns (x - y)(x - y) element-wise.
|
||||
@ -529,6 +528,7 @@ REGISTER_OP("Maximum")
|
||||
.SetIsCommutative()
|
||||
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
|
||||
|
||||
// Note: This op is not commutative w.r.t. to all its inputs.
|
||||
REGISTER_OP("_MklMaximum")
|
||||
.Input("x: T")
|
||||
.Input("y: T")
|
||||
@ -537,7 +537,6 @@ REGISTER_OP("_MklMaximum")
|
||||
.Output("z: T")
|
||||
.Output("mkl_z: uint8")
|
||||
.Attr("T: {half, float, double, int32, int64, bfloat16}")
|
||||
.SetIsCommutative()
|
||||
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
|
||||
.Doc(R"doc(
|
||||
Returns the max of x and y (i.e. x > y ? x : y) element-wise.
|
||||
@ -1619,6 +1618,7 @@ REGISTER_OP("QuantizedMatMul")
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
// Note: This op is not commutative w.r.t. to all its inputs.
|
||||
REGISTER_OP("QuantizedMul")
|
||||
.Input("x: T1")
|
||||
.Input("y: T2")
|
||||
@ -1632,7 +1632,6 @@ REGISTER_OP("QuantizedMul")
|
||||
.Attr("T1: quantizedtype")
|
||||
.Attr("T2: quantizedtype")
|
||||
.Attr("Toutput: quantizedtype = DT_QINT32")
|
||||
.SetIsCommutative()
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
|
||||
c->set_output(1, c->Scalar());
|
||||
@ -1640,6 +1639,7 @@ REGISTER_OP("QuantizedMul")
|
||||
return Status::OK();
|
||||
});
|
||||
|
||||
// Note: This op is not commutative w.r.t. to all its inputs.
|
||||
REGISTER_OP("QuantizedAdd")
|
||||
.Input("x: T1")
|
||||
.Input("y: T2")
|
||||
@ -1653,7 +1653,6 @@ REGISTER_OP("QuantizedAdd")
|
||||
.Attr("T1: quantizedtype")
|
||||
.Attr("T2: quantizedtype")
|
||||
.Attr("Toutput: quantizedtype = DT_QINT32")
|
||||
.SetIsCommutative()
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
|
||||
// min_x, max_x, min_y, max_y should be scalar.
|
||||
@ -1770,6 +1769,7 @@ REGISTER_OP("ClipByValue")
|
||||
.SetShapeFn(shape_inference::UnchangedShape);
|
||||
|
||||
#ifdef INTEL_MKL
|
||||
// Note: This op is not commutative w.r.t. to all its inputs.
|
||||
REGISTER_OP("_MklAddN")
|
||||
.Input("inputs: N * T")
|
||||
.Input("mkl_input: N * uint8")
|
||||
@ -1777,8 +1777,6 @@ REGISTER_OP("_MklAddN")
|
||||
.Output("mkl_sum: uint8")
|
||||
.Attr("N: int >= 1")
|
||||
.Attr("T: numbertype")
|
||||
.SetIsCommutative()
|
||||
.SetIsAggregate()
|
||||
.SetShapeFn([](InferenceContext* c) {
|
||||
ShapeHandle cur = c->input(c->num_inputs() - 1);
|
||||
for (int i = c->num_inputs() - 2; i >= 0; --i) {
|
||||
|
Loading…
Reference in New Issue
Block a user