Merge pull request #39825 from yongtang:39815-tf.debugging.assert_near-complex
PiperOrigin-RevId: 313402180 Change-Id: Iedbffb47293e8315919723069b82ed5c9c91cdfd
This commit is contained in:
commit
788647217d
@ -528,6 +528,17 @@ class AssertAllCloseTest(test.TestCase):
|
|||||||
x = check_ops.assert_near(t1, t2)
|
x = check_ops.assert_near(t1, t2)
|
||||||
assert x is None
|
assert x is None
|
||||||
|
|
||||||
|
@test_util.run_in_graph_and_eager_modes
|
||||||
|
def test_doesnt_raise_complex(self):
|
||||||
|
x = constant_op.constant(1. + 0.1j, name="x")
|
||||||
|
y = constant_op.constant(1.1 + 0.1j, name="y")
|
||||||
|
with ops.control_dependencies([
|
||||||
|
check_ops.assert_near(
|
||||||
|
x, y, atol=0., rtol=0.5, message="failure message")
|
||||||
|
]):
|
||||||
|
out = array_ops.identity(x)
|
||||||
|
self.evaluate(out)
|
||||||
|
|
||||||
|
|
||||||
class AssertLessTest(test.TestCase):
|
class AssertLessTest(test.TestCase):
|
||||||
|
|
||||||
|
@ -828,12 +828,15 @@ def assert_near(
|
|||||||
x = ops.convert_to_tensor(x, name='x')
|
x = ops.convert_to_tensor(x, name='x')
|
||||||
y = ops.convert_to_tensor(y, name='y', dtype=x.dtype)
|
y = ops.convert_to_tensor(y, name='y', dtype=x.dtype)
|
||||||
|
|
||||||
eps = np.finfo(x.dtype.as_numpy_dtype).eps
|
dtype = x.dtype
|
||||||
|
if dtype.is_complex:
|
||||||
|
dtype = dtype.real_dtype
|
||||||
|
eps = np.finfo(dtype.as_numpy_dtype).eps
|
||||||
rtol = 10 * eps if rtol is None else rtol
|
rtol = 10 * eps if rtol is None else rtol
|
||||||
atol = 10 * eps if atol is None else atol
|
atol = 10 * eps if atol is None else atol
|
||||||
|
|
||||||
rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=x.dtype)
|
rtol = ops.convert_to_tensor(rtol, name='rtol', dtype=dtype)
|
||||||
atol = ops.convert_to_tensor(atol, name='atol', dtype=x.dtype)
|
atol = ops.convert_to_tensor(atol, name='atol', dtype=dtype)
|
||||||
|
|
||||||
if context.executing_eagerly():
|
if context.executing_eagerly():
|
||||||
x_name = _shape_and_dtype_str(x)
|
x_name = _shape_and_dtype_str(x)
|
||||||
|
Loading…
Reference in New Issue
Block a user