Add qint8/qint16/quint8/quint16 support for tf.math.equal/tf/math.not_equal
This PR add qint8/qint16/quint8/quint16 support for tf.math.equal/tf/math.not_equal, as was requested in 26069. Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
7eb5039543
commit
3a7368d06f
@ -18,7 +18,8 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
REGISTER7(BinaryOp, CPU, "Equal", functor::equal_to, float, Eigen::half, double,
|
||||
uint8, int8, int16, bfloat16);
|
||||
REGISTER3(BinaryOp, CPU, "Equal", functor::equal_to, uint16, uint32, uint64);
|
||||
REGISTER7(BinaryOp, CPU, "Equal", functor::equal_to, uint16, uint32, uint64,
|
||||
qint8, qint16, quint8, quint16);
|
||||
REGISTER_KERNEL_BUILDER(
|
||||
Name("ApproximateEqual").Device(DEVICE_CPU).TypeConstraint<float>("T"),
|
||||
ApproximateEqualOp<CPUDevice, float>);
|
||||
|
@ -18,8 +18,8 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
REGISTER7(BinaryOp, CPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
|
||||
double, uint8, int8, int16, bfloat16);
|
||||
REGISTER3(BinaryOp, CPU, "NotEqual", functor::not_equal_to, uint16, uint32,
|
||||
uint64);
|
||||
REGISTER7(BinaryOp, CPU, "NotEqual", functor::not_equal_to, uint16, uint32,
|
||||
uint64, qint8, qint16, quint8, quint16);
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
REGISTER4(BinaryOp, GPU, "NotEqual", functor::not_equal_to, float, Eigen::half,
|
||||
double, uint8);
|
||||
|
@ -708,10 +708,7 @@ REGISTER_OP("GreaterEqual").COMPARISON();
|
||||
.Input("y: T") \
|
||||
.Output("z: bool") \
|
||||
.SetIsCommutative() \
|
||||
.Attr( \
|
||||
"T: {bfloat16, half, float, double, uint8, int8, int16, int32, " \
|
||||
"int64, uint16, uint32, uint64, complex64, " \
|
||||
"quint8, qint8, qint32, string, bool, complex128}") \
|
||||
.Attr("T: type") \
|
||||
.Attr("incompatible_shape_error: bool = true") \
|
||||
.SetShapeFn([](InferenceContext* c) { \
|
||||
ShapeHandle x = c->input(0); \
|
||||
|
@ -991,6 +991,25 @@ class ComparisonOpTest(test.TestCase):
|
||||
[[True, True, True, True, True], [False, False, False, False, False]],
|
||||
values)
|
||||
|
||||
def testEqualQuantizeDType(self):
|
||||
dtypes = [
|
||||
dtypes_lib.qint8,
|
||||
dtypes_lib.qint16,
|
||||
dtypes_lib.quint8,
|
||||
dtypes_lib.quint16,
|
||||
]
|
||||
x = np.asarray([0, 1, 2, 3, 4])
|
||||
y = np.asarray([0, 1, 2, 3, 4])
|
||||
for dtype in dtypes:
|
||||
xt = x.astype(dtype.as_numpy_dtype)
|
||||
yt = y.astype(dtype.as_numpy_dtype)
|
||||
cmp_eq = math_ops.equal(xt, yt)
|
||||
cmp_ne = math_ops.not_equal(xt, yt)
|
||||
values = self.evaluate([cmp_eq, cmp_ne])
|
||||
self.assertAllEqual(
|
||||
[[True, True, True, True, True], [False, False, False, False, False]],
|
||||
values)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user