Add complex tensor support for tf.debugging.assert_near

This PR tries to address the issue raised in 39815 where
tf.debugging.assert_near does not support complex tensors as was specified
in docstring.

This PR adds complex tensor support for tf.debugging.assert_near.

This PR fixes 39815.

Signed-off-by: Yong Tang <yong.tang.github@outlook.com>
This commit is contained in:
Yong Tang 2020-05-23 22:56:38 +00:00
parent ae76544efc
commit 7738aca0dc
2 changed files with 16 additions and 3 deletions

View File

@ -528,6 +528,16 @@ class AssertAllCloseTest(test.TestCase):
x = check_ops.assert_near(t1, t2)
assert x is None
@test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_complex(self):
x = constant_op.constant(1. + 0.1j, name="x")
y = constant_op.constant(1.1 + 0.1j, name="y")
with ops.control_dependencies(
[check_ops.assert_near(x, y, atol=0., rtol=0.5,
message="failure message")]):
out = array_ops.identity(x)
self.evaluate(out)
class AssertLessTest(test.TestCase):

View File

@ -812,12 +812,15 @@ def assert_near(
x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y', dtype=x.dtype)
eps = np.finfo(x.dtype.as_numpy_dtype).eps
dtype = x.dtype
if dtype.is_complex:
dtype = dtype.real_dtype
eps = np.finfo(dtype.as_numpy_dtype).eps
rtol = 10 * eps if rtol is None else rtol
atol = 10 * eps if atol is None else atol
rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=x.dtype)
atol = ops.convert_to_tensor(atol, name='atol', dtype=x.dtype)
rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=dtype)
atol = ops.convert_to_tensor(atol, name='atol', dtype=dtype)
if context.executing_eagerly():
x_name = _shape_and_dtype_str(x)