diff --git a/tensorflow/python/kernel_tests/matrix_square_root_op_test.py b/tensorflow/python/kernel_tests/matrix_square_root_op_test.py index c36d83e2530..6cf330ed981 100644 --- a/tensorflow/python/kernel_tests/matrix_square_root_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_square_root_op_test.py @@ -21,10 +21,11 @@ from __future__ import print_function import numpy as np from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl from tensorflow.python.framework import test_util from tensorflow.python.ops import gen_linalg_ops from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops +from tensorflow.python.ops import stateless_random_ops from tensorflow.python.platform import test @@ -89,31 +90,35 @@ class SquareRootOpTest(test.TestCase): self._verifySquareRootReal(np.empty([0, 2, 2])) self._verifySquareRootReal(np.empty([2, 0, 0])) - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testWrongDimensions(self): # The input to the square root should be at least a 2-dimensional tensor. tensor = constant_op.constant([1., 2.]) - with self.assertRaises(ValueError): + with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): gen_linalg_ops.matrix_square_root(tensor) - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testNotSquare(self): - with self.assertRaises(ValueError): + with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)): tensor = constant_op.constant([[1., 0., -1.], [-1., 1., 0.]]) self.evaluate(gen_linalg_ops.matrix_square_root(tensor)) - @test_util.run_v1_only("b/120545219") + @test_util.run_in_graph_and_eager_modes(use_gpu=True) def testConcurrentExecutesWithoutError(self): - with test_util.use_gpu(): - matrix1 = random_ops.random_normal([5, 5], seed=42) - matrix2 = random_ops.random_normal([5, 5], seed=42) - square1 = math_ops.matmul(matrix1, matrix1) - square2 = math_ops.matmul(matrix2, matrix2) - sqrt1 = gen_linalg_ops.matrix_square_root(square1) - sqrt2 = gen_linalg_ops.matrix_square_root(square2) - all_ops = [sqrt1, sqrt2] - sqrt = self.evaluate(all_ops) - self.assertAllClose(sqrt[0], sqrt[1]) + matrix_shape = [5, 5] + seed = [42, 24] + matrix1 = stateless_random_ops.stateless_random_normal( + shape=matrix_shape, seed=seed) + matrix2 = stateless_random_ops.stateless_random_normal( + shape=matrix_shape, seed=seed) + self.assertAllEqual(matrix1, matrix2) + square1 = math_ops.matmul(matrix1, matrix1) + square2 = math_ops.matmul(matrix2, matrix2) + sqrt1 = gen_linalg_ops.matrix_square_root(square1) + sqrt2 = gen_linalg_ops.matrix_square_root(square2) + all_ops = [sqrt1, sqrt2] + sqrt = self.evaluate(all_ops) + self.assertAllClose(sqrt[0], sqrt[1]) if __name__ == "__main__":