Enable gradient tests for tf.linalg.cholesky in eager mode.

PiperOrigin-RevId: 312723423
Change-Id: I47d52dc14638301504ef8eccf481c7d7e3a60f48
This commit is contained in:
A. Unique TensorFlower 2020-05-21 12:53:51 -07:00 committed by TensorFlower Gardener
parent 97528c3175
commit 8fdb54ea98

View File

@ -29,7 +29,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import stateless_random_ops from tensorflow.python.ops import stateless_random_ops
@ -37,7 +37,6 @@ from tensorflow.python.ops import variables
from tensorflow.python.ops.linalg import linalg from tensorflow.python.ops.linalg import linalg
from tensorflow.python.platform import benchmark from tensorflow.python.platform import benchmark
from tensorflow.python.platform import test from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
# Different gradient implementations for benchmark purposes # Different gradient implementations for benchmark purposes
@ -181,7 +180,7 @@ class CholeskyOpTest(test.TestCase):
self._verifyCholesky(np.empty([0, 2, 2])) self._verifyCholesky(np.empty([0, 2, 2]))
self._verifyCholesky(np.empty([2, 0, 0])) self._verifyCholesky(np.empty([2, 0, 0]))
@test_util.run_deprecated_v1 @test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testConcurrentExecutesWithoutError(self): def testConcurrentExecutesWithoutError(self):
seed = [42, 24] seed = [42, 24]
matrix_shape = [5, 5] matrix_shape = [5, 5]
@ -196,108 +195,106 @@ class CholeskyOpTest(test.TestCase):
class CholeskyGradTest(test.TestCase): class CholeskyGradTest(test.TestCase):
_backprop_block_size = 32 _backprop_block_size = 16
def getShapes(self, shapeList):
  """Yield (rows, cols) matrix shapes with cols = floor(1.2 * rows)."""
  return ((n, int(np.floor(1.2 * n))) for n in shapeList)
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testSmallMatrices(self):
  """Gradient check for small real matrices."""
  np.random.seed(0)
  self.runFiniteDifferences(
      self.getShapes([1, 2, 10]),
      dtypes=(dtypes_lib.float32, dtypes_lib.float64))
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testSmallMatricesComplex(self):
  """Gradient check for small complex matrices."""
  np.random.seed(0)
  self.runFiniteDifferences(
      self.getShapes([1, 2, 10]),
      dtypes=(dtypes_lib.complex64, dtypes_lib.complex128))
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testOneBlockMatrices(self):
  """Gradient check for a matrix just over one backprop block in size."""
  np.random.seed(0)
  self.runFiniteDifferences(
      self.getShapes([self._backprop_block_size + 1]),
      dtypes=(dtypes_lib.float32, dtypes_lib.float64),
      scalar_test=True)
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testTwoBlockMatrixFloat(self):
  """Gradient check for a two-block float32 matrix."""
  np.random.seed(0)
  self.runFiniteDifferences(
      self.getShapes([2 * self._backprop_block_size + 1]),
      dtypes=(dtypes_lib.float32,),
      scalar_test=True)
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testTwoBlockMatrixDouble(self):
  """Gradient check for a two-block float64 matrix."""
  np.random.seed(0)
  self.runFiniteDifferences(
      self.getShapes([2 * self._backprop_block_size + 1]),
      dtypes=(dtypes_lib.float64,),
      scalar_test=True)
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testTwoBlockMatrixComplexFloat(self):
  """Gradient check for a two-block complex64 matrix."""
  np.random.seed(0)
  self.runFiniteDifferences(
      self.getShapes([2 * self._backprop_block_size + 1]),
      dtypes=(dtypes_lib.complex64,),
      scalar_test=True)
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testTwoBlockMatrixComplexDouble(self):
  """Gradient check for a two-block complex128 matrix."""
  np.random.seed(0)
  self.runFiniteDifferences(
      self.getShapes([2 * self._backprop_block_size + 1]),
      dtypes=(dtypes_lib.complex128,),
      scalar_test=True)
def _runOneTest(self, shape, dtype, batch, scalar_test):
  """Compares the analytic and numeric Cholesky gradient for one config.

  Args:
    shape: (rows, cols) of the random input matrix to generate.
    dtype: a dtypes_lib dtype selecting the test precision.
    batch: if True, tile the Hermitian matrix into a batch of 2.
    scalar_test: if True, reduce the Cholesky factor to a scalar mean
      before differentiating (faster for larger matrices).
  """
  # Looser tolerance for the lower-precision dtypes (float32/complex64).
  if dtype == dtypes_lib.float64:
    tol = 1e-5
  elif dtype == dtypes_lib.complex128:
    tol = 5e-5
  else:
    tol = 5e-3
  # Finite-difference step; eps**(1/3) balances truncation vs. rounding
  # error for a central-difference scheme.
  epsilon = np.finfo(dtype.as_numpy_dtype).eps
  delta = epsilon**(1.0 / 3.0)

  def RandomInput():
    # Random matrix of the requested shape; add an imaginary part for
    # complex dtypes.
    a = np.random.randn(shape[0], shape[1]).astype(dtype.as_numpy_dtype)
    if dtype.is_complex:
      a += 1j * np.random.randn(shape[0], shape[1]).astype(
          dtype.as_numpy_dtype)
    return a

  def Compute(x):
    # Turn the random matrix x into a Hermitian matrix by
    # computing the quadratic form x * x^H.
    a = math_ops.matmul(x, math_ops.conj(
        array_ops.matrix_transpose(x))) / shape[0]
    if batch:
      a = array_ops.tile(array_ops.expand_dims(a, 0), [2, 1, 1])
    # Finally take the cholesky decomposition of the Hermitian matrix.
    c = linalg_ops.cholesky(a)
    if scalar_test:
      # Reduce to a single scalar output to speed up test.
      c = math_ops.reduce_mean(c)
    return c

  theoretical, numerical = gradient_checker_v2.compute_gradient(
      Compute, [RandomInput()], delta=delta)
  self.assertAllClose(theoretical, numerical, atol=tol, rtol=tol)
def runFiniteDifferences(self,
                         shapes,
                         dtypes=(dtypes_lib.float32, dtypes_lib.float64,
                                 dtypes_lib.complex64, dtypes_lib.complex128),
                         scalar_test=False):
  """Runs _runOneTest over every (shape, dtype, batch) combination.

  Args:
    shapes: iterable of (rows, cols) shapes to test.
    dtypes: dtypes to test; defaults to all four float/complex dtypes.
    scalar_test: forwarded to _runOneTest to reduce the output to a
      scalar for faster gradient checking.
  """
  for current_shape in shapes:
    for current_dtype in dtypes:
      # Exercise the unbatched path first, then the batched path.
      self._runOneTest(current_shape, current_dtype, False, scalar_test)
      self._runOneTest(current_shape, current_dtype, True, scalar_test)
class CholeskyBenchmark(test.Benchmark): class CholeskyBenchmark(test.Benchmark):