Merge pull request #41795 from yongtang:26069-equal-qint8-qint16-quint8-quint16
PiperOrigin-RevId: 333868852 Change-Id: I584cbb81c18e596af4dd363b1f16432260d432f1
This commit is contained in:
commit
065f6f6b59
tensorflow
compiler/mlir/tensorflow/ir
core
python/kernel_tests
@ -3229,8 +3229,8 @@ tf.math.equal(x, y) ==> array([True, True])
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
|
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
|
||||||
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y,
|
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y,
|
||||||
|
|
||||||
DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
|
DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
|
||||||
);
|
);
|
||||||
@ -7083,8 +7083,8 @@ def TF_NotEqualOp : TF_Op<"NotEqual", [Commutative, NoSideEffect]> {
|
|||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins
|
let arguments = (ins
|
||||||
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
|
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$x,
|
||||||
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y,
|
TensorOf<[TF_Bfloat16, TF_Bool, TF_Complex128, TF_Complex64, TF_Float16, TF_Float32, TF_Float64, TF_Int16, TF_Int32, TF_Int64, TF_Int8, TF_Qint16, TF_Quint16, TF_Qint32, TF_Qint8, TF_Quint8, TF_Str, TF_Uint16, TF_Uint32, TF_Uint64, TF_Uint8]>:$y,
|
||||||
|
|
||||||
DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
|
DefaultValuedAttr<BoolAttr, "true">:$incompatible_shape_error
|
||||||
);
|
);
|
||||||
|
@ -18,7 +18,8 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
REGISTER7(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half, double,
|
REGISTER7(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half, double,
|
||||||
uint8, int8, int16, bfloat16);
|
uint8, int8, int16, bfloat16);
|
||||||
REGISTER3(BinaryOp, CPU, "Equal", functor::equal_to, uint16, uint32, uint64);
|
REGISTER7(BinaryOp, CPU, "Equal", functor::equal_to, uint16, uint32, uint64,
|
||||||
|
qint8, qint16, quint8, quint16);
|
||||||
REGISTER_KERNEL_BUILDER(
|
REGISTER_KERNEL_BUILDER(
|
||||||
Name("ApproximateEqual").Device(DEVICE_CPU).TypeConstraint<float>("T"),
|
Name("ApproximateEqual").Device(DEVICE_CPU).TypeConstraint<float>("T"),
|
||||||
ApproximateEqualOp<CPUDevice, float>);
|
ApproximateEqualOp<CPUDevice, float>);
|
||||||
|
@ -18,8 +18,8 @@ limitations under the License.
|
|||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
REGISTER7(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
|
REGISTER7(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
|
||||||
double, uint8, int8, int16, bfloat16);
|
double, uint8, int8, int16, bfloat16);
|
||||||
REGISTER3(BinaryOp, CPU, "NotEqual", functor::not_equal_to, uint16, uint32,
|
REGISTER7(BinaryOp, CPU, "NotEqual", functor::not_equal_to, uint16, uint32,
|
||||||
uint64);
|
uint64, qint8, qint16, quint8, quint16);
|
||||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||||
REGISTER4(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
|
REGISTER4(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
|
||||||
double, uint8);
|
double, uint8);
|
||||||
|
@ -703,27 +703,24 @@ REGISTER_OP("GreaterEqual").COMPARISON();
|
|||||||
|
|
||||||
// --------------------------------------------------------------------------
|
// --------------------------------------------------------------------------
|
||||||
|
|
||||||
#define EQUALITY_COMPARISON() \
|
#define EQUALITY_COMPARISON() \
|
||||||
Input("x: T") \
|
Input("x: T") \
|
||||||
.Input("y: T") \
|
.Input("y: T") \
|
||||||
.Output("z: bool") \
|
.Output("z: bool") \
|
||||||
.SetIsCommutative() \
|
.SetIsCommutative() \
|
||||||
.Attr( \
|
.Attr("T: type") \
|
||||||
"T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \
|
.Attr("incompatible_shape_error: bool = true") \
|
||||||
"int64, uint16, uint32, uint64, complex64, " \
|
.SetShapeFn([](InferenceContext* c) { \
|
||||||
"quint8, qint8, qint32, string, bool, complex128}") \
|
ShapeHandle x = c->input(0); \
|
||||||
.Attr("incompatible_shape_error: bool = true") \
|
ShapeHandle y = c->input(1); \
|
||||||
.SetShapeFn([](InferenceContext* c) { \
|
ShapeHandle output; \
|
||||||
ShapeHandle x = c->input(0); \
|
bool incompatible_shape_error; \
|
||||||
ShapeHandle y = c->input(1); \
|
TF_RETURN_IF_ERROR(c->GetAttr("incompatible_shape_error", \
|
||||||
ShapeHandle output; \
|
&incompatible_shape_error)); \
|
||||||
bool incompatible_shape_error; \
|
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( \
|
||||||
TF_RETURN_IF_ERROR(c->GetAttr("incompatible_shape_error", \
|
c, x, y, incompatible_shape_error, &output)); \
|
||||||
&incompatible_shape_error)); \
|
c->set_output(0, output); \
|
||||||
TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper( \
|
return Status::OK(); \
|
||||||
c, x, y, incompatible_shape_error, &output)); \
|
|
||||||
c->set_output(0, output); \
|
|
||||||
return Status::OK(); \
|
|
||||||
})
|
})
|
||||||
|
|
||||||
REGISTER_OP("Equal").EQUALITY_COMPARISON();
|
REGISTER_OP("Equal").EQUALITY_COMPARISON();
|
||||||
|
@ -991,6 +991,25 @@ class ComparisonOpTest(test.TestCase):
|
|||||||
[[True, True, True, True, True], [False, False, False, False, False]],
|
[[True, True, True, True, True], [False, False, False, False, False]],
|
||||||
values)
|
values)
|
||||||
|
|
||||||
|
def testEqualQuantizeDType(self):
|
||||||
|
dtypes = [
|
||||||
|
dtypes_lib.qint8,
|
||||||
|
dtypes_lib.qint16,
|
||||||
|
dtypes_lib.quint8,
|
||||||
|
dtypes_lib.quint16,
|
||||||
|
]
|
||||||
|
x = np.asarray([0, 1, 2, 3, 4])
|
||||||
|
y = np.asarray([0, 1, 2, 3, 4])
|
||||||
|
for dtype in dtypes:
|
||||||
|
xt = x.astype(dtype.as_numpy_dtype)
|
||||||
|
yt = y.astype(dtype.as_numpy_dtype)
|
||||||
|
cmp_eq = math_ops.equal(xt, yt)
|
||||||
|
cmp_ne = math_ops.not_equal(xt, yt)
|
||||||
|
values = self.evaluate([cmp_eq, cmp_ne])
|
||||||
|
self.assertAllEqual(
|
||||||
|
[[True, True, True, True, True], [False, False, False, False, False]],
|
||||||
|
values)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user