Update svd_op_test to run (non-gradient) tests in eager as well as graph mode.

PiperOrigin-RevId: 311640894
Change-Id: I39b4666c461c64ffe3f33992bb536961a266abd7
This commit is contained in:
A. Unique TensorFlower 2020-05-14 17:47:38 -07:00 committed by TensorFlower Gardener
parent d5e0f468cd
commit a2ef8b5a06
2 changed files with 81 additions and 83 deletions

View File

@ -3468,7 +3468,7 @@ cuda_py_test(
name = "svd_op_test",
size = "medium",
srcs = ["svd_op_test.py"],
shard_count = 20,
shard_count = 30,
tags = [
"no_oss", # b/117185141.
"nomsan", # TODO(b/117236102): Re-enable in msan build.

View File

@ -20,8 +20,8 @@ from __future__ import print_function
import numpy as np
from tensorflow.python import tf2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
@ -31,7 +31,7 @@ from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
@ -58,35 +58,31 @@ class SvdOpTest(test.TestCase):
"Shape must be at least rank 2 but is rank 1"):
linalg_ops.svd(vector)
@test_util.run_v1_only("b/120545219")
def testConcurrentExecutesWithoutError(self):
with self.session(use_gpu=True) as sess:
all_ops = []
for compute_uv_ in True, False:
for full_matrices_ in True, False:
matrix1 = random_ops.random_normal([5, 5], seed=42)
matrix2 = random_ops.random_normal([5, 5], seed=42)
if compute_uv_:
s1, u1, v1 = linalg_ops.svd(
matrix1, compute_uv=compute_uv_, full_matrices=full_matrices_)
s2, u2, v2 = linalg_ops.svd(
matrix2, compute_uv=compute_uv_, full_matrices=full_matrices_)
all_ops += [s1, u1, v1, s2, u2, v2]
else:
s1 = linalg_ops.svd(
matrix1, compute_uv=compute_uv_, full_matrices=full_matrices_)
s2 = linalg_ops.svd(
matrix2, compute_uv=compute_uv_, full_matrices=full_matrices_)
all_ops += [s1, s2]
val = self.evaluate(all_ops)
for i in range(2):
s = 6 * i
self.assertAllEqual(val[s], val[s + 3]) # s1 == s2
self.assertAllEqual(val[s + 1], val[s + 4]) # u1 == u2
self.assertAllEqual(val[s + 2], val[s + 5]) # v1 == v2
for i in range(2):
s = 12 + 2 * i
self.assertAllEqual(val[s], val[s + 1]) # s1 == s2
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testExecuteMultipleWithoutError(self):
all_ops = []
shape = [6, 5]
seed = [42, 24]
for compute_uv_ in True, False:
for full_matrices_ in True, False:
matrix1 = stateless_random_ops.stateless_random_normal(shape, seed)
matrix2 = stateless_random_ops.stateless_random_normal(shape, seed)
self.assertAllEqual(matrix1, matrix2)
if compute_uv_:
s1, u1, v1 = linalg_ops.svd(
matrix1, compute_uv=compute_uv_, full_matrices=full_matrices_)
s2, u2, v2 = linalg_ops.svd(
matrix2, compute_uv=compute_uv_, full_matrices=full_matrices_)
all_ops += [s1, s2, u1, u2, v1, v2]
else:
s1 = linalg_ops.svd(
matrix1, compute_uv=compute_uv_, full_matrices=full_matrices_)
s2 = linalg_ops.svd(
matrix2, compute_uv=compute_uv_, full_matrices=full_matrices_)
all_ops += [s1, s2]
val = self.evaluate(all_ops)
for i in range(0, len(val), 2):
self.assertAllEqual(val[i], val[i + 1])
def _GetSvdOpTest(dtype_, shape_, use_static_shape_, compute_uv_,
@ -136,8 +132,10 @@ def _GetSvdOpTest(dtype_, shape_, use_static_shape_, compute_uv_,
identity = array_ops.matrix_band_part(array_ops.ones_like(xx), 0, 0)
self.assertAllClose(identity, xx, atol=tol)
@test_util.run_v1_only("b/120545219")
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
def Test(self):
if not use_static_shape_ and context.executing_eagerly():
return
is_complex = dtype_ in (np.complex64, np.complex128)
is_single = dtype_ in (np.float32, np.complex64)
tol = 3e-4 if is_single else 1e-12
@ -152,48 +150,48 @@ def _GetSvdOpTest(dtype_, shape_, use_static_shape_, compute_uv_,
low=-1.0, high=1.0,
size=np.prod(shape_)).reshape(shape_).astype(dtype_)
with self.session(use_gpu=True) as sess:
if use_static_shape_:
x_tf = constant_op.constant(x_np)
else:
x_tf = array_ops.placeholder(dtype_)
if use_static_shape_:
x_tf = constant_op.constant(x_np)
else:
x_tf = array_ops.placeholder(dtype_)
if compute_uv_:
s_tf, u_tf, v_tf = linalg_ops.svd(
x_tf, compute_uv=compute_uv_, full_matrices=full_matrices_)
if use_static_shape_:
s_tf_val, u_tf_val, v_tf_val = self.evaluate([s_tf, u_tf, v_tf])
else:
if compute_uv_:
s_tf, u_tf, v_tf = linalg_ops.svd(
x_tf, compute_uv=compute_uv_, full_matrices=full_matrices_)
if use_static_shape_:
s_tf_val, u_tf_val, v_tf_val = self.evaluate([s_tf, u_tf, v_tf])
else:
with self.session(use_gpu=True) as sess:
s_tf_val, u_tf_val, v_tf_val = sess.run(
[s_tf, u_tf, v_tf], feed_dict={x_tf: x_np})
else:
s_tf = linalg_ops.svd(
x_tf, compute_uv=compute_uv_, full_matrices=full_matrices_)
if use_static_shape_:
s_tf_val = self.evaluate(s_tf)
else:
s_tf = linalg_ops.svd(
x_tf, compute_uv=compute_uv_, full_matrices=full_matrices_)
if use_static_shape_:
s_tf_val = self.evaluate(s_tf)
else:
with self.session(use_gpu=True) as sess:
s_tf_val = sess.run(s_tf, feed_dict={x_tf: x_np})
if compute_uv_:
u_np, s_np, v_np = np.linalg.svd(
x_np, compute_uv=compute_uv_, full_matrices=full_matrices_)
else:
s_np = np.linalg.svd(
x_np, compute_uv=compute_uv_, full_matrices=full_matrices_)
# We explicitly avoid the situation where numpy eliminates a first
# dimension that is equal to one.
s_np = np.reshape(s_np, s_tf_val.shape)
if compute_uv_:
u_np, s_np, v_np = np.linalg.svd(
x_np, compute_uv=compute_uv_, full_matrices=full_matrices_)
else:
s_np = np.linalg.svd(
x_np, compute_uv=compute_uv_, full_matrices=full_matrices_)
# We explicitly avoid the situation where numpy eliminates a first
# dimension that is equal to one.
s_np = np.reshape(s_np, s_tf_val.shape)
CompareSingularValues(self, s_np, s_tf_val, tol)
if compute_uv_:
CompareSingularVectors(self, u_np, u_tf_val, min(shape_[-2:]), tol)
CompareSingularVectors(self,
np.conj(np.swapaxes(v_np, -2, -1)), v_tf_val,
min(shape_[-2:]), tol)
CheckApproximation(self, x_np, u_tf_val, s_tf_val, v_tf_val,
full_matrices_, tol)
CheckUnitary(self, u_tf_val, tol)
CheckUnitary(self, v_tf_val, tol)
CompareSingularValues(self, s_np, s_tf_val, tol)
if compute_uv_:
CompareSingularVectors(self, u_np, u_tf_val, min(shape_[-2:]), tol)
CompareSingularVectors(self, np.conj(np.swapaxes(v_np, -2, -1)), v_tf_val,
min(shape_[-2:]), tol)
CheckApproximation(self, x_np, u_tf_val, s_tf_val, v_tf_val,
full_matrices_, tol)
CheckUnitary(self, u_tf_val, tol)
CheckUnitary(self, v_tf_val, tol)
return Test
@ -378,15 +376,15 @@ if __name__ == "__main__":
for rows in 0, 1, 2, 5, 10, 32, 100:
for cols in 0, 1, 2, 5, 10, 32, 100:
for batch_dims in [(), (3,)] + [(3, 2)] * (max(rows, cols) < 10):
shape = batch_dims + (rows, cols)
# TF2 does not support placeholders under eager so we skip it
for use_static_shape in set([True, tf2.enabled()]):
full_shape = batch_dims + (rows, cols)
for use_static_shape in set([True, False]):
name = "%s_%s_static_shape_%s__compute_uv_%s_full_%s" % (
dtype.__name__, "_".join(map(str, shape)), use_static_shape,
compute_uv, full_matrices)
_AddTest(SvdOpTest, "Svd", name,
_GetSvdOpTest(dtype, shape, use_static_shape,
compute_uv, full_matrices))
dtype.__name__, "_".join(map(str, full_shape)),
use_static_shape, compute_uv, full_matrices)
_AddTest(
SvdOpTest, "Svd", name,
_GetSvdOpTest(dtype, full_shape, use_static_shape,
compute_uv, full_matrices))
for compute_uv in False, True:
for full_matrices in False, True:
dtypes = ([np.float32, np.float64] + [np.complex64, np.complex128] *
@ -397,16 +395,16 @@ if __name__ == "__main__":
mat_shapes += [(5, 11), (11, 5)]
for mat_shape in mat_shapes:
for batch_dims in [(), (3,)]:
shape = batch_dims + mat_shape
name = "%s_%s_compute_uv_%s_full_%s" % (
dtype.__name__, "_".join(map(str, shape)), compute_uv,
full_matrices)
_AddTest(SvdGradOpTest, "SvdGrad", name,
_GetSvdGradOpTest(dtype, shape, compute_uv, full_matrices))
full_shape = batch_dims + mat_shape
name = "%s_%s_compute_uv_%s_full_%s" % (dtype.__name__, "_".join(
map(str, full_shape)), compute_uv, full_matrices)
_AddTest(
SvdGradOpTest, "SvdGrad", name,
_GetSvdGradOpTest(dtype, full_shape, compute_uv, full_matrices))
# The results are too inaccurate for float32.
if dtype in (np.float64, np.complex128):
_AddTest(
SvdGradGradOpTest, "SvdGradGrad", name,
_GetSvdGradGradOpTest(dtype, shape, compute_uv,
_GetSvdGradGradOpTest(dtype, full_shape, compute_uv,
full_matrices))
test.main()