Enable tests for tf.linalg.matrix_square_root in eager mode.
PiperOrigin-RevId: 312133318 Change-Id: I541a94a21594384fba30a9198ad5a7300537c498
This commit is contained in:
parent
672e419c9f
commit
7254343a10
@ -21,10 +21,11 @@ from __future__ import print_function
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from tensorflow.python.framework import constant_op
|
from tensorflow.python.framework import constant_op
|
||||||
|
from tensorflow.python.framework import errors_impl
|
||||||
from tensorflow.python.framework import test_util
|
from tensorflow.python.framework import test_util
|
||||||
from tensorflow.python.ops import gen_linalg_ops
|
from tensorflow.python.ops import gen_linalg_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import stateless_random_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
|
|
||||||
|
|
||||||
@ -89,24 +90,28 @@ class SquareRootOpTest(test.TestCase):
|
|||||||
self._verifySquareRootReal(np.empty([0, 2, 2]))
|
self._verifySquareRootReal(np.empty([0, 2, 2]))
|
||||||
self._verifySquareRootReal(np.empty([2, 0, 0]))
|
self._verifySquareRootReal(np.empty([2, 0, 0]))
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||||
def testWrongDimensions(self):
|
def testWrongDimensions(self):
|
||||||
# The input to the square root should be at least a 2-dimensional tensor.
|
# The input to the square root should be at least a 2-dimensional tensor.
|
||||||
tensor = constant_op.constant([1., 2.])
|
tensor = constant_op.constant([1., 2.])
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
|
||||||
gen_linalg_ops.matrix_square_root(tensor)
|
gen_linalg_ops.matrix_square_root(tensor)
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||||
def testNotSquare(self):
|
def testNotSquare(self):
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
|
||||||
tensor = constant_op.constant([[1., 0., -1.], [-1., 1., 0.]])
|
tensor = constant_op.constant([[1., 0., -1.], [-1., 1., 0.]])
|
||||||
self.evaluate(gen_linalg_ops.matrix_square_root(tensor))
|
self.evaluate(gen_linalg_ops.matrix_square_root(tensor))
|
||||||
|
|
||||||
@test_util.run_v1_only("b/120545219")
|
@test_util.run_in_graph_and_eager_modes(use_gpu=True)
|
||||||
def testConcurrentExecutesWithoutError(self):
|
def testConcurrentExecutesWithoutError(self):
|
||||||
with test_util.use_gpu():
|
matrix_shape = [5, 5]
|
||||||
matrix1 = random_ops.random_normal([5, 5], seed=42)
|
seed = [42, 24]
|
||||||
matrix2 = random_ops.random_normal([5, 5], seed=42)
|
matrix1 = stateless_random_ops.stateless_random_normal(
|
||||||
|
shape=matrix_shape, seed=seed)
|
||||||
|
matrix2 = stateless_random_ops.stateless_random_normal(
|
||||||
|
shape=matrix_shape, seed=seed)
|
||||||
|
self.assertAllEqual(matrix1, matrix2)
|
||||||
square1 = math_ops.matmul(matrix1, matrix1)
|
square1 = math_ops.matmul(matrix1, matrix1)
|
||||||
square2 = math_ops.matmul(matrix2, matrix2)
|
square2 = math_ops.matmul(matrix2, matrix2)
|
||||||
sqrt1 = gen_linalg_ops.matrix_square_root(square1)
|
sqrt1 = gen_linalg_ops.matrix_square_root(square1)
|
||||||
|
Loading…
Reference in New Issue
Block a user