Merge pull request #39825 from yongtang:39815-tf.debugging.assert_near-complex

PiperOrigin-RevId: 313402180
Change-Id: Iedbffb47293e8315919723069b82ed5c9c91cdfd
This commit is contained in:
TensorFlower Gardener 2020-05-27 09:23:29 -07:00
commit 788647217d
2 changed files with 17 additions and 3 deletions

View File

@ -528,6 +528,17 @@ class AssertAllCloseTest(test.TestCase):
x = check_ops.assert_near(t1, t2) x = check_ops.assert_near(t1, t2)
assert x is None assert x is None
@test_util.run_in_graph_and_eager_modes
def test_doesnt_raise_complex(self):
x = constant_op.constant(1. + 0.1j, name="x")
y = constant_op.constant(1.1 + 0.1j, name="y")
with ops.control_dependencies([
check_ops.assert_near(
x, y, atol=0., rtol=0.5, message="failure message")
]):
out = array_ops.identity(x)
self.evaluate(out)
class AssertLessTest(test.TestCase): class AssertLessTest(test.TestCase):

View File

@ -828,12 +828,15 @@ def assert_near(
x = ops.convert_to_tensor(x, name='x') x = ops.convert_to_tensor(x, name='x')
y = ops.convert_to_tensor(y, name='y', dtype=x.dtype) y = ops.convert_to_tensor(y, name='y', dtype=x.dtype)
eps = np.finfo(x.dtype.as_numpy_dtype).eps dtype = x.dtype
if dtype.is_complex:
dtype = dtype.real_dtype
eps = np.finfo(dtype.as_numpy_dtype).eps
rtol = 10 * eps if rtol is None else rtol rtol = 10 * eps if rtol is None else rtol
atol = 10 * eps if atol is None else atol atol = 10 * eps if atol is None else atol
rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=x.dtype) rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=dtype)
atol = ops.convert_to_tensor(atol, name='atol', dtype=x.dtype) atol = ops.convert_to_tensor(atol, name='atol', dtype=dtype)
if context.executing_eagerly(): if context.executing_eagerly():
x_name = _shape_and_dtype_str(x) x_name = _shape_and_dtype_str(x)