Enable gradient tests for tf.linalg.cholesky in eager mode.

PiperOrigin-RevId: 312723423
Change-Id: I47d52dc14638301504ef8eccf481c7d7e3a60f48
Commit 8fdb54ea98 (parent 97528c3175), authored by A. Unique TensorFlower on 2020-05-21 12:53:51 -07:00 and committed by TensorFlower Gardener.

View File

@ -29,7 +29,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import stateless_random_ops
@ -37,7 +37,6 @@ from tensorflow.python.ops import variables
from tensorflow.python.ops.linalg import linalg
from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
# Different gradient implementations for benchmark purposes
@ -181,7 +180,7 @@ class CholeskyOpTest(test.TestCase):
self._verifyCholesky(np.empty([0, 2, 2]))
self._verifyCholesky(np.empty([2, 0, 0]))
@test_util.run_deprecated_v1
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testConcurrentExecutesWithoutError(self):
seed = [42, 24]
matrix_shape = [5, 5]
@ -196,108 +195,106 @@ class CholeskyOpTest(test.TestCase):
class CholeskyGradTest(test.TestCase):
  """Numerical-gradient tests for the Cholesky op.

  Compares the registered analytic gradient of `tf.linalg.cholesky` against
  finite differences (via `gradient_checker_v2`, which works in both graph
  and eager modes) over a range of matrix sizes, dtypes, and batch shapes.
  """

  # Block size used by the blocked Cholesky backprop algorithm. Kept small
  # here so the "one block" / "two block" tests below stay fast while still
  # crossing the blocking boundary.
  _backprop_block_size = 16

  def getShapes(self, shapeList):
    """Yield (rows, cols) pairs with cols ~ 1.2x rows for each size given."""
    return ((elem, int(np.floor(1.2 * elem))) for elem in shapeList)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testSmallMatrices(self):
    np.random.seed(0)
    shapes = self.getShapes([1, 2, 10])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.float32, dtypes_lib.float64))

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testSmallMatricesComplex(self):
    np.random.seed(0)
    shapes = self.getShapes([1, 2, 10])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.complex64, dtypes_lib.complex128))

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testOneBlockMatrices(self):
    # One past the block size: exercises the single-block path plus a
    # remainder column.
    np.random.seed(0)
    shapes = self.getShapes([self._backprop_block_size + 1])
    self.runFiniteDifferences(
        shapes,
        dtypes=(dtypes_lib.float32, dtypes_lib.float64),
        scalar_test=True)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testTwoBlockMatrixFloat(self):
    # Two blocks plus a remainder: exercises the inter-block recursion.
    np.random.seed(0)
    shapes = self.getShapes([2 * self._backprop_block_size + 1])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.float32,), scalar_test=True)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testTwoBlockMatrixDouble(self):
    np.random.seed(0)
    shapes = self.getShapes([2 * self._backprop_block_size + 1])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.float64,), scalar_test=True)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testTwoBlockMatrixComplexFloat(self):
    np.random.seed(0)
    shapes = self.getShapes([2 * self._backprop_block_size + 1])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.complex64,), scalar_test=True)

  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testTwoBlockMatrixComplexDouble(self):
    np.random.seed(0)
    shapes = self.getShapes([2 * self._backprop_block_size + 1])
    self.runFiniteDifferences(
        shapes, dtypes=(dtypes_lib.complex128,), scalar_test=True)

  def _runOneTest(self, shape, dtype, batch, scalar_test):
    """Check the Cholesky gradient for one (shape, dtype, batch) combination.

    Args:
      shape: (rows, cols) of the random matrix used to build the SPD input.
      dtype: a `dtypes_lib` dtype; tolerance is chosen per precision below.
      batch: if True, tile the input to a rank-3 batch of 2 matrices.
      scalar_test: if True, reduce the Cholesky factor to a scalar mean so
        the gradient check is cheap for large matrices.
    """
    # Looser tolerance for single-precision dtypes.
    if dtype == dtypes_lib.float64:
      tol = 1e-5
    elif dtype == dtypes_lib.complex128:
      tol = 5e-5
    else:
      tol = 5e-3
    epsilon = np.finfo(dtype.as_numpy_dtype).eps
    # Standard finite-difference step: eps**(1/3) balances truncation and
    # round-off error for central differences.
    delta = epsilon**(1.0 / 3.0)

    def RandomInput():
      a = np.random.randn(shape[0], shape[1]).astype(dtype.as_numpy_dtype)
      if dtype.is_complex:
        a += 1j * np.random.randn(shape[0], shape[1]).astype(
            dtype.as_numpy_dtype)
      return a

    def Compute(x):
      # Turn the random matrix x into a Hermitian matrix by
      # computing the quadratic form x * x^H.
      a = math_ops.matmul(x, math_ops.conj(
          array_ops.matrix_transpose(x))) / shape[0]
      if batch:
        a = array_ops.tile(array_ops.expand_dims(a, 0), [2, 1, 1])
      # Finally take the cholesky decomposition of the Hermitian matrix.
      c = linalg_ops.cholesky(a)
      if scalar_test:
        # Reduce to a single scalar output to speed up test.
        c = math_ops.reduce_mean(c)
      return c

    theoretical, numerical = gradient_checker_v2.compute_gradient(
        Compute, [RandomInput()], delta=delta)
    self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)

  def runFiniteDifferences(self,
                           shapes,
                           dtypes=(dtypes_lib.float32, dtypes_lib.float64,
                                   dtypes_lib.complex64, dtypes_lib.complex128),
                           scalar_test=False):
    """Run `_runOneTest` over the cross product of shapes, dtypes, batching."""
    for shape_ in shapes:
      for dtype_ in dtypes:
        for batch_ in False, True:
          self._runOneTest(shape_, dtype_, batch_, scalar_test)
class CholeskyBenchmark(test.Benchmark):