Enable 1st order gradient tests for tf.linalg.svd in eager mode.
PiperOrigin-RevId: 312756858
Change-Id: I20d73e8972014b96bc90952949820390ae77e08d
parent 60fb5dcc7d
commit e312350702
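The diff below replaces the v1, session-based gradient_checker with gradient_checker_v2, whose callable-plus-inputs API also runs under eager execution. As a rough sketch of that pattern (mine, not part of the commit; it assumes a real float64 input, while the test itself sweeps dtypes, shapes, compute_uv and full_matrices):

# Minimal sketch (assumption, not from this commit): numerically checking
# the first-order gradient of tf.linalg.svd in eager mode.
import numpy as np
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import linalg_ops


def singular_values(a):
  # compute_uv=False returns only the singular values, avoiding the
  # sign/phase ambiguity of the U and V factors.
  return linalg_ops.svd(a, compute_uv=False)


x = np.random.uniform(low=-1.0, high=1.0, size=(4, 3)).astype(np.float64)
# compute_gradient evaluates the Jacobian both analytically and by central
# differences; the two results should agree to within tolerance.
theoretical, numerical = gradient_checker_v2.compute_gradient(
    singular_values, [x])
np.testing.assert_allclose(theoretical, numerical, atol=1e-6, rtol=1e-6)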
@@ -29,6 +29,7 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradient_checker_v2
 from tensorflow.python.ops import gradients_impl
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
@@ -225,45 +226,41 @@ def _NormalizingSvd(tf_a, full_matrices_):
 
 
 def _GetSvdGradOpTest(dtype_, shape_, compute_uv_, full_matrices_):
 
-  @test_util.run_v1_only("b/120545219")
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def Test(self):
-    np.random.seed(42)
-    a = np.random.uniform(low=-1.0, high=1.0, size=shape_).astype(dtype_)
-    if dtype_ in [np.complex64, np.complex128]:
-      a += 1j * np.random.uniform(
-          low=-1.0, high=1.0, size=shape_).astype(dtype_)
+
+    def RandomInput():
+      np.random.seed(42)
+      a = np.random.uniform(low=-1.0, high=1.0, size=shape_).astype(dtype_)
+      if dtype_ in [np.complex64, np.complex128]:
+        a += 1j * np.random.uniform(
+            low=-1.0, high=1.0, size=shape_).astype(dtype_)
+      return a
+
     # Optimal stepsize for central difference is O(epsilon^{1/3}).
     # See Equation (21) in:
     # http://www.karenkopecky.net/Teaching/eco613614/Notes_NumericalDifferentiation.pdf
     # TODO(rmlarsen): Move step size control to gradient checker.
     epsilon = np.finfo(dtype_).eps
-    delta = 0.1 * epsilon**(1.0 / 3.0)
+    delta = 0.25 * epsilon**(1.0 / 3.0)
     if dtype_ in [np.float32, np.complex64]:
       tol = 3e-2
     else:
       tol = 1e-6
-    with self.session(use_gpu=True):
-      tf_a = constant_op.constant(a)
-      if compute_uv_:
-        tf_s, tf_u, tf_v = _NormalizingSvd(tf_a, full_matrices_)
-        outputs = [tf_s, tf_u, tf_v]
-      else:
-        tf_s = linalg_ops.svd(tf_a, compute_uv=False)
-        outputs = [tf_s]
-      for b in outputs:
-        x_init = np.random.uniform(
-            low=-1.0, high=1.0, size=shape_).astype(dtype_)
-        if dtype_ in [np.complex64, np.complex128]:
-          x_init += 1j * np.random.uniform(
-              low=-1.0, high=1.0, size=shape_).astype(dtype_)
-        theoretical, numerical = gradient_checker.compute_gradient(
-            tf_a,
-            tf_a.get_shape().as_list(),
-            b,
-            b.get_shape().as_list(),
-            x_init_value=x_init,
-            delta=delta)
-        self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)
+    if compute_uv_:
+      funcs = [
+          lambda a: _NormalizingSvd(a, full_matrices_)[0],
+          lambda a: _NormalizingSvd(a, full_matrices_)[1],
+          lambda a: _NormalizingSvd(a, full_matrices_)[2]
+      ]
+    else:
+      funcs = [lambda a: linalg_ops.svd(a, compute_uv=False)]
+    for f in funcs:
+      theoretical, numerical = gradient_checker_v2.compute_gradient(
+          f, [RandomInput()], delta=delta)
+      self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)
 
   return Test
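A note on the design (my reading, not stated in the commit): gradient_checker.compute_gradient needs graph tensors, explicit shapes and a session, whereas gradient_checker_v2.compute_gradient takes a plain function and concrete inputs, which is what lets the same test body run under run_in_graph_and_eager_modes; the three separate lambdas appear to be needed because each compute_gradient call checks one output tensor at a time. For users, the gradient exercised here is the one tf.GradientTape exposes in eager mode, roughly like this (illustration only, assuming compute_uv=False):

import tensorflow as tf

a = tf.constant([[2.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
with tf.GradientTape() as tape:
  tape.watch(a)                           # a is a constant, so watch it explicitly
  s = tf.linalg.svd(a, compute_uv=False)  # singular values only
  loss = tf.reduce_sum(s)                 # scalar function of the singular values
grad = tape.gradient(loss, a)             # first-order gradient w.r.t. the input matrix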