From 7738aca0dcf9f2d2d27b7c3bb1b17c0fb41bbb10 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 23 May 2020 22:56:38 +0000 Subject: [PATCH] Add complex tensor support for tf.debugging.assert_near This PR tries to address the issue raised in 39815 where tf.debugging.assert_near does not support complex tensors as was specified in docstring. This PR adds complex tensor support for tf.debugging.assert_near. This PR fixes 39815. Signed-off-by: Yong Tang --- tensorflow/python/kernel_tests/check_ops_test.py | 10 ++++++++++ tensorflow/python/ops/check_ops.py | 9 ++++++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index 47f392d7438..6a1b5c1f952 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -528,6 +528,16 @@ class AssertAllCloseTest(test.TestCase): x = check_ops.assert_near(t1, t2) assert x is None + @test_util.run_in_graph_and_eager_modes + def test_doesnt_raise_complex(self): + x = constant_op.constant(1. + 0.1j, name="x") + y = constant_op.constant(1.1 + 0.1j, name="y") + with ops.control_dependencies( + [check_ops.assert_near(x, y, atol=0., rtol=0.5, + message="failure message")]): + out = array_ops.identity(x) + self.evaluate(out) + class AssertLessTest(test.TestCase): diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index 3085e05eaf6..c1a17bc13ab 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -812,12 +812,15 @@ def assert_near( x = ops.convert_to_tensor(x, name='x') y = ops.convert_to_tensor(y, name='y', dtype=x.dtype) - eps = np.finfo(x.dtype.as_numpy_dtype).eps + dtype = x.dtype + if dtype.is_complex: + dtype = dtype.real_dtype + eps = np.finfo(dtype.as_numpy_dtype).eps rtol = 10 * eps if rtol is None else rtol atol = 10 * eps if atol is None else atol - rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=x.dtype) - atol = ops.convert_to_tensor(atol, name='atol', dtype=x.dtype) + rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=dtype) + atol = ops.convert_to_tensor(atol, name='atol', dtype=dtype) if context.executing_eagerly(): x_name = _shape_and_dtype_str(x)