diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index 37ee8d38f53..9bade548849 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -528,6 +528,17 @@ class AssertAllCloseTest(test.TestCase): x = check_ops.assert_near(t1, t2) assert x is None + @test_util.run_in_graph_and_eager_modes + def test_doesnt_raise_complex(self): + x = constant_op.constant(1. + 0.1j, name="x") + y = constant_op.constant(1.1 + 0.1j, name="y") + with ops.control_dependencies([ + check_ops.assert_near( + x, y, atol=0., rtol=0.5, message="failure message") + ]): + out = array_ops.identity(x) + self.evaluate(out) + class AssertLessTest(test.TestCase): diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index b50313753d6..680796df48d 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -828,12 +828,15 @@ def assert_near( x = ops.convert_to_tensor(x, name='x') y = ops.convert_to_tensor(y, name='y', dtype=x.dtype) - eps = np.finfo(x.dtype.as_numpy_dtype).eps + dtype = x.dtype + if dtype.is_complex: + dtype = dtype.real_dtype + eps = np.finfo(dtype.as_numpy_dtype).eps rtol = 10 * eps if rtol is None else rtol atol = 10 * eps if atol is None else atol - rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=x.dtype) - atol = ops.convert_to_tensor(atol, name='atol', dtype=x.dtype) + rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=dtype) + atol = ops.convert_to_tensor(atol, name='atol', dtype=dtype) if context.executing_eagerly(): x_name = _shape_and_dtype_str(x)