Add test cases of uint16, uint32, uint64 support for tf.math.[equal|not_equal]
Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
parent
67ef513917
commit
42b0014216
|
@ -948,6 +948,31 @@ class ComparisonOpTest(test.TestCase):
|
|||
"Incompatible shapes|Dimensions must be equal"):
|
||||
f(x.astype(t), y.astype(t))
|
||||
|
||||
def testEqualDType(self):
|
||||
dtypes = [np.float16, np.float32, np.float64, np.int8, np.int16, np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64]
|
||||
x = np.asarray([0, 1, 2, 3, 4])
|
||||
y = np.asarray([0, 1, 2, 3, 4])
|
||||
for dtype in dtypes:
|
||||
xt = x.astype(dtype)
|
||||
yt = y.astype(dtype)
|
||||
cmp_eq = math_ops.equal(xt, yt)
|
||||
cmp_ne = math_ops.not_equal(xt, yt)
|
||||
values = self.evaluate([cmp_eq, cmp_ne])
|
||||
self.assertAllEqual(
|
||||
[[True, True, True, True, True],
|
||||
[False, False, False, False, False]], values)
|
||||
for dtype in [np.complex64, np.complex128]:
|
||||
xt = x.astype(dtype)
|
||||
xt -= 1j * xt
|
||||
yt = y.astype(dtype)
|
||||
yt -= 1j * yt
|
||||
cmp_eq = math_ops.equal(xt, yt)
|
||||
cmp_ne = math_ops.not_equal(xt, yt)
|
||||
values = self.evaluate([cmp_eq, cmp_ne])
|
||||
self.assertAllEqual(
|
||||
[[True, True, True, True, True],
|
||||
[False, False, False, False, False]], values)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
|
Loading…
Reference in New Issue