Add test cases of uint16, uint32, uint64 support for tf.math.[equal|not_equal]

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2020-04-06 22:53:09 +00:00
parent 67ef513917
commit 42b0014216
1 changed files with 25 additions and 0 deletions

View File

@ -948,6 +948,31 @@ class ComparisonOpTest(test.TestCase):
"Incompatible shapes|Dimensions must be equal"): "Incompatible shapes|Dimensions must be equal"):
f(x.astype(t), y.astype(t)) f(x.astype(t), y.astype(t))
def testEqualDType(self):
dtypes = [np.float16, np.float32, np.float64, np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64]
x = np.asarray([0, 1, 2, 3, 4])
y = np.asarray([0, 1, 2, 3, 4])
for dtype in dtypes:
xt = x.astype(dtype)
yt = y.astype(dtype)
cmp_eq = math_ops.equal(xt, yt)
cmp_ne = math_ops.not_equal(xt, yt)
values = self.evaluate([cmp_eq, cmp_ne])
self.assertAllEqual(
[[True, True, True, True, True],
[False, False, False, False, False]], values)
for dtype in [np.complex64, np.complex128]:
xt = x.astype(dtype)
xt -= 1j * xt
yt = y.astype(dtype)
yt -= 1j * yt
cmp_eq = math_ops.equal(xt, yt)
cmp_ne = math_ops.not_equal(xt, yt)
values = self.evaluate([cmp_eq, cmp_ne])
self.assertAllEqual(
[[True, True, True, True, True],
[False, False, False, False, False]], values)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()