Enable tests for tf.linalg.matrix_square_root in eager mode.

PiperOrigin-RevId: 312133318
Change-Id: I541a94a21594384fba30a9198ad5a7300537c498
This commit is contained in:
A. Unique TensorFlower 2020-05-18 12:37:47 -07:00 committed by TensorFlower Gardener
parent 672e419c9f
commit 7254343a10

View File

@ -21,10 +21,11 @@ from __future__ import print_function
import numpy as np import numpy as np
from tensorflow.python.framework import constant_op from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_linalg_ops from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -89,31 +90,35 @@ class SquareRootOpTest(test.TestCase):
self._verifySquareRootReal(np.empty([0, 2, 2])) self._verifySquareRootReal(np.empty([0, 2, 2]))
self._verifySquareRootReal(np.empty([2, 0, 0])) self._verifySquareRootReal(np.empty([2, 0, 0]))
@test_util.run_v1_only("b/120545219") @test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testWrongDimensions(self): def testWrongDimensions(self):
# The input to the square root should be at least a 2-dimensional tensor. # The input to the square root should be at least a 2-dimensional tensor.
tensor = constant_op.constant([1., 2.]) tensor = constant_op.constant([1., 2.])
with self.assertRaises(ValueError): with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
gen_linalg_ops.matrix_square_root(tensor) gen_linalg_ops.matrix_square_root(tensor)
@test_util.run_v1_only("b/120545219") @test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testNotSquare(self): def testNotSquare(self):
with self.assertRaises(ValueError): with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
tensor = constant_op.constant([[1., 0., -1.], [-1., 1., 0.]]) tensor = constant_op.constant([[1., 0., -1.], [-1., 1., 0.]])
self.evaluate(gen_linalg_ops.matrix_square_root(tensor)) self.evaluate(gen_linalg_ops.matrix_square_root(tensor))
@test_util.run_v1_only("b/120545219") @test_util.run_in_graph_and_eager_modes(use_gpu=True)
def testConcurrentExecutesWithoutError(self): def testConcurrentExecutesWithoutError(self):
with test_util.use_gpu(): matrix_shape = [5, 5]
matrix1 = random_ops.random_normal([5, 5], seed=42) seed = [42, 24]
matrix2 = random_ops.random_normal([5, 5], seed=42) matrix1 = stateless_random_ops.stateless_random_normal(
square1 = math_ops.matmul(matrix1, matrix1) shape=matrix_shape, seed=seed)
square2 = math_ops.matmul(matrix2, matrix2) matrix2 = stateless_random_ops.stateless_random_normal(
sqrt1 = gen_linalg_ops.matrix_square_root(square1) shape=matrix_shape, seed=seed)
sqrt2 = gen_linalg_ops.matrix_square_root(square2) self.assertAllEqual(matrix1, matrix2)
all_ops = [sqrt1, sqrt2] square1 = math_ops.matmul(matrix1, matrix1)
sqrt = self.evaluate(all_ops) square2 = math_ops.matmul(matrix2, matrix2)
self.assertAllClose(sqrt[0], sqrt[1]) sqrt1 = gen_linalg_ops.matrix_square_root(square1)
sqrt2 = gen_linalg_ops.matrix_square_root(square2)
all_ops = [sqrt1, sqrt2]
sqrt = self.evaluate(all_ops)
self.assertAllClose(sqrt[0], sqrt[1])
if __name__ == "__main__": if __name__ == "__main__":