Enable tests for tf.linalg.matrix_logarithm in eager mode.
PiperOrigin-RevId: 312757336
Change-Id: I0323132c43830f37bbb2480be700d6c2bc65f175
parent e312350702
commit 7221ad6eda
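Note (not part of this commit): the recurring change below replaces the test_util.run_v1_only decorator with test_util.run_in_graph_and_eager_modes(use_gpu=True), so each decorated test body is exercised twice, once while building a graph and once under eager execution. A minimal sketch with a hypothetical test class; only the two decorators come from this commit:

# Sketch only. run_v1_only skips the test when v2/eager behavior is enabled;
# run_in_graph_and_eager_modes runs the body in both graph and eager mode.
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class DecoratorSketchTest(test.TestCase):

  # Before: @test_util.run_v1_only("b/120545219")
  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
  def testRunsInBothModes(self):
    x = constant_op.constant([[2.0, 0.0], [0.0, 2.0]])
    # self.evaluate works in either mode: it runs the tensor eagerly or via a
    # session, depending on which mode the decorator is currently exercising.
    self.assertAllClose(self.evaluate(math_ops.trace(x)), 4.0)


if __name__ == "__main__":
  test.main()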
@@ -23,12 +23,13 @@ import numpy as np
 from tensorflow.python.client import session
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import gen_linalg_ops
 from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import stateless_random_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.ops.linalg import linalg_impl
 from tensorflow.python.platform import benchmark
@@ -57,7 +58,7 @@ class LogarithmOpTest(test.TestCase):
     matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1])
     return matrix_batch
 
-  @test_util.run_v1_only("b/120545219")
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def testNonsymmetric(self):
     # 2x2 matrices
     matrix1 = np.array([[1., 2.], [3., 4.]])
@@ -71,7 +72,7 @@ class LogarithmOpTest(test.TestCase):
     # Complex batch
     self._verifyLogarithmComplex(self._makeBatch(matrix1, matrix2))
 
-  @test_util.run_v1_only("b/120545219")
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def testSymmetricPositiveDefinite(self):
     # 2x2 matrices
     matrix1 = np.array([[2., 1.], [1., 2.]])
@@ -85,27 +86,27 @@ class LogarithmOpTest(test.TestCase):
     # Complex batch
     self._verifyLogarithmComplex(self._makeBatch(matrix1, matrix2))
 
-  @test_util.run_v1_only("b/120545219")
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def testNonSquareMatrix(self):
     # When the logarithm of a non-square matrix is attempted we should return
     # an error
-    with self.assertRaises(ValueError):
+    with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
       gen_linalg_ops.matrix_logarithm(
           np.array([[1., 2., 3.], [3., 4., 5.]], dtype=np.complex64))
 
-  @test_util.run_v1_only("b/120545219")
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def testWrongDimensions(self):
     # The input to the logarithm should be at least a 2-dimensional tensor.
     tensor3 = constant_op.constant([1., 2.], dtype=dtypes.complex64)
-    with self.assertRaises(ValueError):
+    with self.assertRaises((ValueError, errors_impl.InvalidArgumentError)):
       gen_linalg_ops.matrix_logarithm(tensor3)
 
-  @test_util.run_v1_only("b/120545219")
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def testEmpty(self):
     self._verifyLogarithmComplex(np.empty([0, 2, 2], dtype=np.complex64))
     self._verifyLogarithmComplex(np.empty([2, 0, 0], dtype=np.complex64))
 
-  @test_util.run_v1_only("b/120545219")
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def testRandomSmallAndLargeComplex64(self):
     np.random.seed(42)
     for batch_dims in [(), (1,), (3,), (2, 2)]:
@@ -116,7 +117,7 @@ class LogarithmOpTest(test.TestCase):
             size=np.prod(shape)).reshape(shape).astype(np.complex64)
         self._verifyLogarithmComplex(matrix)
 
-  @test_util.run_v1_only("b/120545219")
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def testRandomSmallAndLargeComplex128(self):
     np.random.seed(42)
     for batch_dims in [(), (1,), (3,), (2, 2)]:
@@ -127,13 +128,17 @@ class LogarithmOpTest(test.TestCase):
             size=np.prod(shape)).reshape(shape).astype(np.complex128)
         self._verifyLogarithmComplex(matrix)
 
-  @test_util.run_v1_only("b/120545219")
+  @test_util.run_in_graph_and_eager_modes(use_gpu=True)
   def testConcurrentExecutesWithoutError(self):
-    with self.session(use_gpu=True) as sess:
-      matrix1 = math_ops.cast(
-          random_ops.random_normal([5, 5], seed=42), dtypes.complex64)
-      matrix2 = math_ops.cast(
-          random_ops.random_normal([5, 5], seed=42), dtypes.complex64)
-      logm1 = gen_linalg_ops.matrix_logarithm(matrix1)
-      logm2 = gen_linalg_ops.matrix_logarithm(matrix2)
-      logm = self.evaluate([logm1, logm2])
+    matrix_shape = [5, 5]
+    seed = [42, 24]
+    matrix1 = math_ops.cast(
+        stateless_random_ops.stateless_random_normal(matrix_shape, seed=seed),
+        dtypes.complex64)
+    matrix2 = math_ops.cast(
+        stateless_random_ops.stateless_random_normal(matrix_shape, seed=seed),
+        dtypes.complex64)
+    self.assertAllEqual(matrix1, matrix2)
+    logm1 = gen_linalg_ops.matrix_logarithm(matrix1)
+    logm2 = gen_linalg_ops.matrix_logarithm(matrix2)
+    logm = self.evaluate([logm1, logm2])
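Note (an inference, not stated in the commit message): random_ops.random_normal is stateful, so under eager execution two calls with the same op-level seed need not return the same values; stateless_random_ops.stateless_random_normal is a pure function of shape and seed, which is what lets the rewritten test assert that matrix1 and matrix2 are identical in both modes. A short sketch using only the ops that appear in the hunk above:

# Sketch: two stateless_random_normal calls with the same shape and seed
# produce identical values, in eager mode and graph mode alike.
import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import stateless_random_ops

shape = [5, 5]
seed = [42, 24]  # a stateless seed is a pair of integer scalars

a = stateless_random_ops.stateless_random_normal(shape, seed=seed)
b = stateless_random_ops.stateless_random_normal(shape, seed=seed)

# Cast exactly as the test does; equality is preserved.
a_c = math_ops.cast(a, dtypes.complex64)
b_c = math_ops.cast(b, dtypes.complex64)

# Under eager execution (the TF2 default) the tensors can be compared directly.
np.testing.assert_array_equal(a_c.numpy(), b_c.numpy())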
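Note (an inference from the assertRaises changes above, not spelled out in the commit message): with static shapes available at graph-construction time the invalid input is rejected with a ValueError, while the eager kernel reports it as errors_impl.InvalidArgumentError, so the tests now accept either exception type. Roughly:

# Sketch: the same invalid (non-square) input surfaces as a different
# exception type depending on the execution mode, hence the tuple in
# assertRaises.
import numpy as np

from tensorflow.python.framework import errors_impl
from tensorflow.python.ops import gen_linalg_ops

non_square = np.array([[1., 2., 3.], [3., 4., 5.]], dtype=np.complex64)
try:
  gen_linalg_ops.matrix_logarithm(non_square)
except ValueError:
  print("graph mode: rejected by static shape checking")
except errors_impl.InvalidArgumentError:
  print("eager mode: rejected when the kernel runs")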