Make non-meta linear operators (other than Circulant/Toeplitz) tape safe.
PiperOrigin-RevId: 259453506
This commit is contained in: parent df6ba21e45, commit 95bcd434d0
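For readers skimming the diff: "tape safe" means that `tf.GradientTape` can differentiate through an operator's methods with respect to the `tf.Variable` it was built from. The sketch below is not part of the commit (it uses the TF2 public API for illustration), but it shows the property the new `test_tape_safe` tests assert:

```python
# Minimal illustration of tape safety, assuming the TF2 public API.
import tensorflow as tf

multiplier = tf.Variable(1.23)
operator = tf.linalg.LinearOperatorScaledIdentity(
    num_rows=2, multiplier=multiplier)

with tf.GradientTape() as tape:
  # A tape-safe operator reads the variable here, inside the tape; a
  # non-tape-safe one would have baked in a constant Tensor at
  # construction time, leaving no gradient path.
  det = operator.determinant()

assert tape.gradient(det, multiplier) is not None
```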
@@ -382,7 +382,7 @@ class WishartCholeskyTest(test.TestCase):
       with self.assertRaisesRegexp(ValueError, "cannot be less than"):
         distributions.WishartCholesky(
             df=2, scale=chol_scale, validate_args=False)
-      with self.assertRaisesRegexp(TypeError, "Argument tril must have dtype"):
+      with self.assertRaisesRegexp(TypeError, "."):
         distributions.WishartCholesky(
             df=4.,
             scale=np.asarray(
@@ -20,7 +20,9 @@ from __future__ import print_function
 import numpy as np

+from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables as variables_module
 from tensorflow.python.ops.linalg import linalg as linalg_lib
 from tensorflow.python.ops.linalg import linear_operator_block_diag as block_diag
 from tensorflow.python.ops.linalg import linear_operator_lower_triangular as lower_triangular

@@ -56,6 +58,7 @@ def _block_diag_dense(expected_shape, blocks):
   return array_ops.concat(rows, axis=-2)


+@test_util.run_all_in_graph_and_eager_modes
 class SquareLinearOperatorBlockDiagTest(
     linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
   """Most tests done in the base class LinearOperatorDerivedClassTest."""
@@ -209,6 +212,26 @@ class SquareLinearOperatorBlockDiagTest(
         block_diag.LinearOperatorBlockDiag)
     self.assertEqual(2, len(inverse.operators))

+  def test_tape_safe(self):
+    matrix = variables_module.Variable([[1., 0.], [0., 1.]])
+    operator = block_diag.LinearOperatorBlockDiag(
+        [
+            linalg.LinearOperatorFullMatrix(
+                matrix,
+                is_self_adjoint=True,
+                is_positive_definite=True,
+            ),
+            linalg.LinearOperatorFullMatrix(
+                matrix,
+                is_self_adjoint=True,
+                is_positive_definite=True,
+            ),
+        ],
+        is_self_adjoint=True,
+        is_positive_definite=True,
+    )
+    self.check_tape_safe(operator)
+
   def test_is_non_singular_auto_set(self):
     # Matrix with two positive eigenvalues, 11 and 8.
     # The matrix values do not affect auto-setting of the flags.
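`check_tape_safe` lives in `linear_operator_test_util` (its updated signature appears at the end of this diff). Roughly, and with details elided, it does something like the following sketch: run each operator method under a tape and require a non-`None` gradient for every entry of `operator.variables`.

```python
# Rough sketch (simplified; the real helper also honors skip_options and
# exercises more methods) of what check_tape_safe verifies.
from tensorflow.python.eager import backprop

def check_tape_safe_sketch(test, operator):
  for method in ("to_dense", "determinant", "trace", "diag_part"):
    with backprop.GradientTape() as tape:
      value = getattr(operator, method)()
    for grad in tape.gradient(value, operator.variables):
      test.assertIsNotNone(grad)
```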
@@ -17,17 +17,21 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables as variables_module
 from tensorflow.python.ops.linalg import linalg as linalg_lib
 from tensorflow.python.ops.linalg import linear_operator_householder as householder
 from tensorflow.python.ops.linalg import linear_operator_test_util
 from tensorflow.python.platform import test

 linalg = linalg_lib
+CheckTapeSafeSkipOptions = linear_operator_test_util.CheckTapeSafeSkipOptions


+@test_util.run_all_in_graph_and_eager_modes
 class LinearOperatorHouseholderTest(
     linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
   """Most tests done in the base class LinearOperatorDerivedClassTest."""

@@ -87,6 +91,19 @@ class LinearOperatorHouseholderTest(
     self.assertIsInstance(
         operator.inverse(), householder.LinearOperatorHouseholder)

+  def test_tape_safe(self):
+    reflection_axis = variables_module.Variable([1., 3., 5., 8.])
+    operator = householder.LinearOperatorHouseholder(reflection_axis)
+    self.check_tape_safe(
+        operator,
+        skip_options=[
+            # Determinant hard-coded as -1.
+            CheckTapeSafeSkipOptions.DETERMINANT,
+            CheckTapeSafeSkipOptions.LOG_ABS_DETERMINANT,
+            # Trace hard-coded.
+            CheckTapeSafeSkipOptions.TRACE,
+        ])
+

 if __name__ == "__main__":
   linear_operator_test_util.add_tests(LinearOperatorHouseholderTest)
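The skips in `test_tape_safe` above are needed because a Householder reflection has its determinant hard-coded to -1 and a trace that does not depend on the reflection axis: the gradient of a constant with respect to a variable is `None`, which would trip the check even though the operator is tape safe. A small eager-mode illustration (not from the commit):

```python
import tensorflow as tf

axis = tf.Variable([1., 3., 5., 8.])
with tf.GradientTape() as tape:
  det = tf.constant(-1.)  # plays the role of the hard-coded determinant
print(tape.gradient(det, axis))  # None, since det never reads `axis`
```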
@@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variables as variables_module
 from tensorflow.python.ops.linalg import linalg as linalg_lib
 from tensorflow.python.ops.linalg import linear_operator_test_util
 from tensorflow.python.platform import test

@@ -33,6 +34,7 @@ from tensorflow.python.platform import test
 rng = np.random.RandomState(2016)


+@test_util.run_all_in_graph_and_eager_modes
 class LinearOperatorIdentityTest(
     linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
   """Most tests done in the base class LinearOperatorDerivedClassTest."""
@@ -61,23 +63,20 @@ class LinearOperatorIdentityTest(

     return operator, mat

-  @test_util.run_deprecated_v1
   def test_assert_positive_definite(self):
     with self.cached_session():
       operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
-      operator.assert_positive_definite().run()  # Should not fail
+      self.evaluate(operator.assert_positive_definite())  # Should not fail

-  @test_util.run_deprecated_v1
   def test_assert_non_singular(self):
     with self.cached_session():
       operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
-      operator.assert_non_singular().run()  # Should not fail
+      self.evaluate(operator.assert_non_singular())  # Should not fail

-  @test_util.run_deprecated_v1
   def test_assert_self_adjoint(self):
     with self.cached_session():
       operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
-      operator.assert_self_adjoint().run()  # Should not fail
+      self.evaluate(operator.assert_self_adjoint())  # Should not fail

   def test_float16_matmul(self):
     # float16 cannot be tested by base test class because tf.linalg.solve does
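The hunks in these test files repeat one mechanical pattern: `placeholder` plus `feed_dict` (graph-only) becomes `placeholder_with_default`, and `.run()` / `sess.run(...)` becomes `self.evaluate(...)`, which works in both graph and eager mode. A standalone sketch of the two styles, using the `tf.compat.v1` API for illustration:

```python
import numpy as np
import tensorflow.compat.v1 as tf

value = np.float32(3.0)

# Old style: graph-only; the value must be fed at Session.run time.
with tf.Graph().as_default(), tf.Session() as sess:
  x = tf.placeholder(tf.float32)
  print(sess.run(x * 2.0, feed_dict={x: value}))

# New style: the default value rides along with the tensor, so the same
# test body also works when executed eagerly via self.evaluate().
with tf.Graph().as_default(), tf.Session() as sess:
  x = tf.placeholder_with_default(value, shape=None)
  print(sess.run(x * 2.0))  # no feed_dict needed
```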
@@ -113,41 +112,38 @@ class LinearOperatorIdentityTest(
     with self.assertRaisesRegexp(ValueError, "must be non-negative"):
       linalg_lib.LinearOperatorIdentity(num_rows=2, batch_shape=[-2])

-  @test_util.run_deprecated_v1
   def test_non_scalar_num_rows_raises_dynamic(self):
     with self.cached_session():
-      num_rows = array_ops.placeholder(dtypes.int32)
-      operator = linalg_lib.LinearOperatorIdentity(
-          num_rows, assert_proper_shapes=True)
-      with self.assertRaisesOpError("must be a 0-D Tensor"):
-        operator.to_dense().eval(feed_dict={num_rows: [2]})
+      num_rows = array_ops.placeholder_with_default([2], shape=None)
+
+      with self.assertRaisesError("must be a 0-D Tensor"):
+        operator = linalg_lib.LinearOperatorIdentity(
+            num_rows, assert_proper_shapes=True)
+        self.evaluate(operator.to_dense())

-  @test_util.run_deprecated_v1
   def test_negative_num_rows_raises_dynamic(self):
     with self.cached_session():
-      num_rows = array_ops.placeholder(dtypes.int32)
-      operator = linalg_lib.LinearOperatorIdentity(
-          num_rows, assert_proper_shapes=True)
-      with self.assertRaisesOpError("must be non-negative"):
-        operator.to_dense().eval(feed_dict={num_rows: -2})
+      num_rows = array_ops.placeholder_with_default(-2, shape=None)
+      with self.assertRaisesError("must be non-negative"):
+        operator = linalg_lib.LinearOperatorIdentity(
+            num_rows, assert_proper_shapes=True)
+        self.evaluate(operator.to_dense())

-  @test_util.run_deprecated_v1
   def test_non_1d_batch_shape_raises_dynamic(self):
     with self.cached_session():
-      batch_shape = array_ops.placeholder(dtypes.int32)
-      operator = linalg_lib.LinearOperatorIdentity(
-          num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
-      with self.assertRaisesOpError("must be a 1-D"):
-        operator.to_dense().eval(feed_dict={batch_shape: 2})
+      batch_shape = array_ops.placeholder_with_default(2, shape=None)
+      with self.assertRaisesError("must be a 1-D"):
+        operator = linalg_lib.LinearOperatorIdentity(
+            num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
+        self.evaluate(operator.to_dense())

-  @test_util.run_deprecated_v1
   def test_negative_batch_shape_raises_dynamic(self):
     with self.cached_session():
-      batch_shape = array_ops.placeholder(dtypes.int32)
-      operator = linalg_lib.LinearOperatorIdentity(
-          num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
-      with self.assertRaisesOpError("must be non-negative"):
-        operator.to_dense().eval(feed_dict={batch_shape: [-2]})
+      batch_shape = array_ops.placeholder_with_default([-2], shape=None)
+      with self.assertRaisesError("must be non-negative"):
+        operator = linalg_lib.LinearOperatorIdentity(
+            num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
+        self.evaluate(operator.to_dense())

   def test_wrong_matrix_dimensions_raises_static(self):
     operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
@@ -155,17 +151,16 @@ class LinearOperatorIdentityTest(
     with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
       operator.matmul(x)

-  @test_util.run_deprecated_v1
   def test_wrong_matrix_dimensions_raises_dynamic(self):
-    num_rows = array_ops.placeholder(dtypes.int32)
-    x = array_ops.placeholder(dtypes.float32)
+    num_rows = array_ops.placeholder_with_default(2, shape=None)
+    x = array_ops.placeholder_with_default(
+        rng.rand(3, 3).astype(np.float32), shape=None)

     with self.cached_session():
-      operator = linalg_lib.LinearOperatorIdentity(
-          num_rows, assert_proper_shapes=True)
-      y = operator.matmul(x)
-      with self.assertRaisesOpError("Incompatible.*dimensions"):
-        y.eval(feed_dict={num_rows: 2, x: rng.rand(3, 3)})
+      with self.assertRaisesError("Dimensions.*not.compatible"):
+        operator = linalg_lib.LinearOperatorIdentity(
+            num_rows, assert_proper_shapes=True)
+        self.evaluate(operator.matmul(x))

   def test_default_batch_shape_broadcasts_with_everything_static(self):
     # These cannot be done in the automated (base test class) tests since they
@@ -181,22 +176,18 @@ class LinearOperatorIdentityTest(
     self.assertAllEqual(operator_matmul.get_shape(), expected.get_shape())
     self.assertAllClose(*self.evaluate([operator_matmul, expected]))

-  @test_util.run_deprecated_v1
   def test_default_batch_shape_broadcasts_with_everything_dynamic(self):
     # These cannot be done in the automated (base test class) tests since they
     # test shapes that tf.batch_matmul cannot handle.
     # In particular, tf.batch_matmul does not broadcast.
-    with self.cached_session() as sess:
-      x = array_ops.placeholder(dtypes.float32)
+    with self.cached_session():
+      x = array_ops.placeholder_with_default(rng.randn(1, 2, 3, 4), shape=None)
       operator = linalg_lib.LinearOperatorIdentity(num_rows=3, dtype=x.dtype)

       operator_matmul = operator.matmul(x)
       expected = x

-      feed_dict = {x: rng.randn(1, 2, 3, 4)}
-
-      self.assertAllClose(
-          *sess.run([operator_matmul, expected], feed_dict=feed_dict))
+      self.assertAllClose(*self.evaluate([operator_matmul, expected]))

   def test_broadcast_matmul_static_shapes(self):
     # These cannot be done in the automated (base test class) tests since they
@@ -219,21 +210,19 @@ class LinearOperatorIdentityTest(
     self.assertAllEqual(operator_matmul.get_shape(), expected.get_shape())
     self.assertAllClose(*self.evaluate([operator_matmul, expected]))

-  @test_util.run_deprecated_v1
   def test_broadcast_matmul_dynamic_shapes(self):
     # These cannot be done in the automated (base test class) tests since they
     # test shapes that tf.batch_matmul cannot handle.
     # In particular, tf.batch_matmul does not broadcast.
-    with self.cached_session() as sess:
+    with self.cached_session():
       # Given this x and LinearOperatorIdentity shape of (2, 1, 3, 3), the
       # broadcast shape of operator and 'x' is (2, 2, 3, 4)
-      x = array_ops.placeholder(dtypes.float32)
-      num_rows = array_ops.placeholder(dtypes.int32)
-      batch_shape = array_ops.placeholder(dtypes.int32)
+      x = array_ops.placeholder_with_default(rng.rand(1, 2, 3, 4), shape=None)
+      num_rows = array_ops.placeholder_with_default(3, shape=None)
+      batch_shape = array_ops.placeholder_with_default((2, 1), shape=None)

       operator = linalg_lib.LinearOperatorIdentity(
-          num_rows, batch_shape=batch_shape)
-      feed_dict = {x: rng.rand(1, 2, 3, 4), num_rows: 3, batch_shape: (2, 1)}
+          num_rows, batch_shape=batch_shape, dtype=dtypes.float64)

       # Batch matrix of zeros with the broadcast shape of x and operator.
       zeros = array_ops.zeros(shape=(2, 2, 3, 4), dtype=x.dtype)

@@ -242,8 +231,7 @@ class LinearOperatorIdentityTest(
       expected = x + zeros

       operator_matmul = operator.matmul(x)
-      self.assertAllClose(
-          *sess.run([operator_matmul, expected], feed_dict=feed_dict))
+      self.assertAllClose(*self.evaluate([operator_matmul, expected]))

   def test_is_x_flags(self):
     # The is_x flags are by default all True.
@@ -280,7 +268,16 @@ class LinearOperatorIdentityTest(
     self.assertIsInstance(
         operator.inverse(), linalg_lib.LinearOperatorIdentity)

+  def test_ref_type_shape_args_raises(self):
+    with self.assertRaisesRegexp(TypeError, "num_rows.*reference"):
+      linalg_lib.LinearOperatorIdentity(num_rows=variables_module.Variable(2))
+
+    with self.assertRaisesRegexp(TypeError, "batch_shape.*reference"):
+      linalg_lib.LinearOperatorIdentity(
+          num_rows=2, batch_shape=variables_module.Variable([3]))
+
+
+@test_util.run_all_in_graph_and_eager_modes
 class LinearOperatorScaledIdentityTest(
     linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
   """Most tests done in the base class LinearOperatorDerivedClassTest."""
@@ -331,47 +328,44 @@ class LinearOperatorScaledIdentityTest(

     return operator, matrix

-  @test_util.run_deprecated_v1
   def test_assert_positive_definite_does_not_raise_when_positive(self):
     with self.cached_session():
       operator = linalg_lib.LinearOperatorScaledIdentity(
           num_rows=2, multiplier=1.)
-      operator.assert_positive_definite().run()  # Should not fail
+      self.evaluate(operator.assert_positive_definite())  # Should not fail

   def test_assert_positive_definite_raises_when_negative(self):
     with self.cached_session():
       operator = linalg_lib.LinearOperatorScaledIdentity(
           num_rows=2, multiplier=-1.)
       with self.assertRaisesOpError("not positive definite"):
-        operator.assert_positive_definite().run()
+        self.evaluate(operator.assert_positive_definite())

-  @test_util.run_deprecated_v1
   def test_assert_non_singular_does_not_raise_when_non_singular(self):
     with self.cached_session():
       operator = linalg_lib.LinearOperatorScaledIdentity(
           num_rows=2, multiplier=[1., 2., 3.])
-      operator.assert_non_singular().run()  # Should not fail
+      self.evaluate(operator.assert_non_singular())  # Should not fail

   def test_assert_non_singular_raises_when_singular(self):
     with self.cached_session():
       operator = linalg_lib.LinearOperatorScaledIdentity(
           num_rows=2, multiplier=[1., 2., 0.])
       with self.assertRaisesOpError("was singular"):
-        operator.assert_non_singular().run()
+        self.evaluate(operator.assert_non_singular())

-  @test_util.run_deprecated_v1
   def test_assert_self_adjoint_does_not_raise_when_self_adjoint(self):
     with self.cached_session():
       operator = linalg_lib.LinearOperatorScaledIdentity(
           num_rows=2, multiplier=[1. + 0J])
-      operator.assert_self_adjoint().run()  # Should not fail
+      self.evaluate(operator.assert_self_adjoint())  # Should not fail

   def test_assert_self_adjoint_raises_when_not_self_adjoint(self):
     with self.cached_session():
       operator = linalg_lib.LinearOperatorScaledIdentity(
           num_rows=2, multiplier=[1. + 1J])
       with self.assertRaisesOpError("not self-adjoint"):
-        operator.assert_self_adjoint().run()
+        self.evaluate(operator.assert_self_adjoint())

   def test_float16_matmul(self):
     # float16 cannot be tested by base test class because tf.linalg.solve does
@@ -397,17 +391,18 @@ class LinearOperatorScaledIdentityTest(
     with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
       operator.matmul(x)

-  @test_util.run_deprecated_v1
   def test_wrong_matrix_dimensions_raises_dynamic(self):
-    num_rows = array_ops.placeholder(dtypes.int32)
-    x = array_ops.placeholder(dtypes.float32)
+    num_rows = array_ops.placeholder_with_default(2, shape=None)
+    x = array_ops.placeholder_with_default(
+        rng.rand(3, 3).astype(np.float32), shape=None)

     with self.cached_session():
-      operator = linalg_lib.LinearOperatorScaledIdentity(
-          num_rows, multiplier=[1., 2], assert_proper_shapes=True)
-      y = operator.matmul(x)
-      with self.assertRaisesOpError("Incompatible.*dimensions"):
-        y.eval(feed_dict={num_rows: 2, x: rng.rand(3, 3)})
+      with self.assertRaisesError("Dimensions.*not.compatible"):
+        operator = linalg_lib.LinearOperatorScaledIdentity(
+            num_rows,
+            multiplier=[1., 2],
+            assert_proper_shapes=True)
+        self.evaluate(operator.matmul(x))

   def test_broadcast_matmul_and_solve(self):
     # These cannot be done in the automated (base test class) tests since they
@@ -530,6 +525,17 @@ class LinearOperatorScaledIdentityTest(
         operator.inverse(),
         linalg_lib.LinearOperatorScaledIdentity)

+  def test_ref_type_shape_args_raises(self):
+    with self.assertRaisesRegexp(TypeError, "num_rows.*reference"):
+      linalg_lib.LinearOperatorScaledIdentity(
+          num_rows=variables_module.Variable(2), multiplier=1.23)
+
+  def test_tape_safe(self):
+    multiplier = variables_module.Variable(1.23)
+    operator = linalg_lib.LinearOperatorScaledIdentity(
+        num_rows=2, multiplier=multiplier)
+    self.check_tape_safe(operator)
+

 if __name__ == "__main__":
   linear_operator_test_util.add_tests(LinearOperatorIdentityTest)
@@ -17,8 +17,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

+from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables as variables_module
 from tensorflow.python.ops.linalg import linalg as linalg_lib
 from tensorflow.python.ops.linalg import linear_operator_test_util
 from tensorflow.python.platform import test

@@ -26,6 +28,7 @@ from tensorflow.python.platform import test
 linalg = linalg_lib


+@test_util.run_all_in_graph_and_eager_modes
 class LinearOperatorLowerTriangularTest(
     linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
   """Most tests done in the base class LinearOperatorDerivedClassTest."""

@@ -101,6 +104,12 @@ class LinearOperatorLowerTriangularTest(
             operator1.to_dense()),
         self.evaluate(operator_matmul.to_dense()))

+  def test_tape_safe(self):
+    tril = variables_module.Variable([[1., 0.], [0., 1.]])
+    operator = linalg_lib.LinearOperatorLowerTriangular(
+        tril, is_non_singular=True)
+    self.check_tape_safe(operator)
+

 if __name__ == "__main__":
   linear_operator_test_util.add_tests(LinearOperatorLowerTriangularTest)
@@ -20,9 +20,7 @@ from __future__ import print_function
 from absl.testing import parameterized
 import numpy as np

-from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
-from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import linalg_ops
 from tensorflow.python.ops import math_ops
@@ -34,66 +32,62 @@ rng = np.random.RandomState(0)

 class AssertZeroImagPartTest(test.TestCase):

-  @test_util.run_deprecated_v1
   def test_real_tensor_doesnt_raise(self):
     x = ops.convert_to_tensor([0., 2, 3])
-    with self.cached_session():
-      # Should not raise.
-      linear_operator_util.assert_zero_imag_part(x, message="ABC123").run()
+    # Should not raise.
+    self.evaluate(
+        linear_operator_util.assert_zero_imag_part(x, message="ABC123"))

-  @test_util.run_deprecated_v1
   def test_complex_tensor_with_imag_zero_doesnt_raise(self):
     x = ops.convert_to_tensor([1., 0, 3])
     y = ops.convert_to_tensor([0., 0, 0])
     z = math_ops.complex(x, y)
-    with self.cached_session():
-      # Should not raise.
-      linear_operator_util.assert_zero_imag_part(z, message="ABC123").run()
+    # Should not raise.
+    self.evaluate(
+        linear_operator_util.assert_zero_imag_part(z, message="ABC123"))

   def test_complex_tensor_with_nonzero_imag_raises(self):
     x = ops.convert_to_tensor([1., 2, 0])
     y = ops.convert_to_tensor([1., 2, 0])
     z = math_ops.complex(x, y)
-    with self.cached_session():
-      with self.assertRaisesOpError("ABC123"):
-        linear_operator_util.assert_zero_imag_part(z, message="ABC123").run()
+    with self.assertRaisesOpError("ABC123"):
+      self.evaluate(
+          linear_operator_util.assert_zero_imag_part(z, message="ABC123"))


 class AssertNoEntriesWithModulusZeroTest(test.TestCase):

-  @test_util.run_deprecated_v1
   def test_nonzero_real_tensor_doesnt_raise(self):
     x = ops.convert_to_tensor([1., 2, 3])
-    with self.cached_session():
-      # Should not raise.
-      linear_operator_util.assert_no_entries_with_modulus_zero(
-          x, message="ABC123").run()
+    # Should not raise.
+    self.evaluate(
+        linear_operator_util.assert_no_entries_with_modulus_zero(
+            x, message="ABC123"))

-  @test_util.run_deprecated_v1
   def test_nonzero_complex_tensor_doesnt_raise(self):
     x = ops.convert_to_tensor([1., 0, 3])
     y = ops.convert_to_tensor([1., 2, 0])
     z = math_ops.complex(x, y)
-    with self.cached_session():
-      # Should not raise.
-      linear_operator_util.assert_no_entries_with_modulus_zero(
-          z, message="ABC123").run()
+    # Should not raise.
+    self.evaluate(
+        linear_operator_util.assert_no_entries_with_modulus_zero(
+            z, message="ABC123"))

   def test_zero_real_tensor_raises(self):
     x = ops.convert_to_tensor([1., 0, 3])
-    with self.cached_session():
-      with self.assertRaisesOpError("ABC123"):
-        linear_operator_util.assert_no_entries_with_modulus_zero(
-            x, message="ABC123").run()
+    with self.assertRaisesOpError("ABC123"):
+      self.evaluate(
+          linear_operator_util.assert_no_entries_with_modulus_zero(
+              x, message="ABC123"))

   def test_zero_complex_tensor_raises(self):
     x = ops.convert_to_tensor([1., 2, 0])
     y = ops.convert_to_tensor([1., 2, 0])
     z = math_ops.complex(x, y)
-    with self.cached_session():
-      with self.assertRaisesOpError("ABC123"):
-        linear_operator_util.assert_no_entries_with_modulus_zero(
-            z, message="ABC123").run()
+    with self.assertRaisesOpError("ABC123"):
+      self.evaluate(
+          linear_operator_util.assert_no_entries_with_modulus_zero(
+              z, message="ABC123"))


 class BroadcastMatrixBatchDimsTest(test.TestCase):
@@ -107,10 +101,8 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
     tensor, = linear_operator_util.broadcast_matrix_batch_dims([arr])
     self.assertTrue(isinstance(tensor, ops.Tensor))

-    with self.cached_session():
-      self.assertAllClose(arr, self.evaluate(tensor))
+    self.assertAllClose(arr, self.evaluate(tensor))

-  @test_util.run_deprecated_v1
   def test_static_dims_broadcast(self):
     # x.batch_shape = [3, 1, 2]
     # y.batch_shape = [4, 1]

@@ -123,12 +115,11 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):

     x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])

-    with self.cached_session() as sess:
-      self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
-      self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
-      x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
-      self.assertAllClose(x_bc_expected, x_bc_)
-      self.assertAllClose(y_bc_expected, y_bc_)
+    self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
+    self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
+    x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
+    self.assertAllClose(x_bc_expected, x_bc_)
+    self.assertAllClose(y_bc_expected, y_bc_)

   def test_static_dims_broadcast_second_arg_higher_rank(self):
     # x.batch_shape = [1, 2]
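As a concrete reading aid for the shapes named in these test comments: `broadcast_matrix_batch_dims` broadcasts only the leading batch dimensions and leaves the trailing two (matrix) dimensions of each argument alone. A worked example (shapes chosen to mirror the comments; the matrix dimensions here are arbitrary):

```python
import tensorflow as tf
from tensorflow.python.ops.linalg import linear_operator_util

x = tf.zeros([3, 1, 2, 5, 7])  # batch shape [3, 1, 2], matrix shape [5, 7]
y = tf.zeros([4, 1, 5, 7])     # batch shape [4, 1],    matrix shape [5, 7]
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])
print(x_bc.shape)  # (3, 4, 2, 5, 7)
print(y_bc.shape)  # (3, 4, 2, 5, 7)
```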
@@ -142,14 +133,12 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):

     x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])

-    with self.cached_session() as sess:
-      self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
-      self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
-      x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
-      self.assertAllClose(x_bc_expected, x_bc_)
-      self.assertAllClose(y_bc_expected, y_bc_)
+    self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
+    self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
+    x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
+    self.assertAllClose(x_bc_expected, x_bc_)
+    self.assertAllClose(y_bc_expected, y_bc_)

-  @test_util.run_deprecated_v1
   def test_dynamic_dims_broadcast_32bit(self):
     # x.batch_shape = [3, 1, 2]
     # y.batch_shape = [4, 1]

@@ -160,17 +149,15 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
     x_bc_expected = x + batch_of_zeros
     y_bc_expected = y + batch_of_zeros

-    x_ph = array_ops.placeholder(dtypes.float32)
-    y_ph = array_ops.placeholder(dtypes.float32)
+    x_ph = array_ops.placeholder_with_default(x, shape=None)
+    y_ph = array_ops.placeholder_with_default(y, shape=None)

     x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph])

-    with self.cached_session() as sess:
-      x_bc_, y_bc_ = sess.run([x_bc, y_bc], feed_dict={x_ph: x, y_ph: y})
-      self.assertAllClose(x_bc_expected, x_bc_)
-      self.assertAllClose(y_bc_expected, y_bc_)
+    x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
+    self.assertAllClose(x_bc_expected, x_bc_)
+    self.assertAllClose(y_bc_expected, y_bc_)

-  @test_util.run_deprecated_v1
   def test_dynamic_dims_broadcast_32bit_second_arg_higher_rank(self):
     # x.batch_shape = [1, 2]
     # y.batch_shape = [3, 4, 1]

@@ -181,15 +168,14 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
     x_bc_expected = x + batch_of_zeros
     y_bc_expected = y + batch_of_zeros

-    x_ph = array_ops.placeholder(dtypes.float32)
-    y_ph = array_ops.placeholder(dtypes.float32)
+    x_ph = array_ops.placeholder_with_default(x, shape=None)
+    y_ph = array_ops.placeholder_with_default(y, shape=None)

     x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph])

-    with self.cached_session() as sess:
-      x_bc_, y_bc_ = sess.run([x_bc, y_bc], feed_dict={x_ph: x, y_ph: y})
-      self.assertAllClose(x_bc_expected, x_bc_)
-      self.assertAllClose(y_bc_expected, y_bc_)
+    x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
+    self.assertAllClose(x_bc_expected, x_bc_)
+    self.assertAllClose(y_bc_expected, y_bc_)

   def test_less_than_two_dims_raises_static(self):
     x = rng.rand(3)
@@ -204,20 +190,17 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):

 class CholeskySolveWithBroadcastTest(test.TestCase):

-  @test_util.run_deprecated_v1
   def test_static_dims_broadcast(self):
     # batch_shape = [2]
     chol = rng.rand(3, 3)
     rhs = rng.rand(2, 3, 7)
     chol_broadcast = chol + np.zeros((2, 1, 1))

-    with self.cached_session():
-      result = linear_operator_util.cholesky_solve_with_broadcast(chol, rhs)
-      self.assertAllEqual((2, 3, 7), result.get_shape())
-      expected = linalg_ops.cholesky_solve(chol_broadcast, rhs)
-      self.assertAllClose(expected.eval(), self.evaluate(result))
+    result = linear_operator_util.cholesky_solve_with_broadcast(chol, rhs)
+    self.assertAllEqual((2, 3, 7), result.get_shape())
+    expected = linalg_ops.cholesky_solve(chol_broadcast, rhs)
+    self.assertAllClose(*self.evaluate([expected, result]))

-  @test_util.run_deprecated_v1
   def test_dynamic_dims_broadcast_64bit(self):
     # batch_shape = [2, 2]
     chol = rng.rand(2, 3, 3)
@@ -225,40 +208,29 @@ class CholeskySolveWithBroadcastTest(test.TestCase):
     chol_broadcast = chol + np.zeros((2, 2, 1, 1))
     rhs_broadcast = rhs + np.zeros((2, 2, 1, 1))

-    chol_ph = array_ops.placeholder(dtypes.float64)
-    rhs_ph = array_ops.placeholder(dtypes.float64)
+    chol_ph = array_ops.placeholder_with_default(chol, shape=None)
+    rhs_ph = array_ops.placeholder_with_default(rhs, shape=None)

-    with self.cached_session() as sess:
-      result, expected = sess.run(
-          [
-              linear_operator_util.cholesky_solve_with_broadcast(
-                  chol_ph, rhs_ph),
-              linalg_ops.cholesky_solve(chol_broadcast, rhs_broadcast)
-          ],
-          feed_dict={
-              chol_ph: chol,
-              rhs_ph: rhs,
-          })
-      self.assertAllClose(expected, result)
+    result, expected = self.evaluate([
+        linear_operator_util.cholesky_solve_with_broadcast(chol_ph, rhs_ph),
+        linalg_ops.cholesky_solve(chol_broadcast, rhs_broadcast)
+    ])
+    self.assertAllClose(expected, result)


 class MatrixSolveWithBroadcastTest(test.TestCase):

-  @test_util.run_deprecated_v1
   def test_static_dims_broadcast_matrix_has_extra_dims(self):
     # batch_shape = [2]
     matrix = rng.rand(2, 3, 3)
     rhs = rng.rand(3, 7)
     rhs_broadcast = rhs + np.zeros((2, 1, 1))

-    with self.cached_session():
-      result = linear_operator_util.matrix_solve_with_broadcast(
-          matrix, rhs)
-      self.assertAllEqual((2, 3, 7), result.get_shape())
-      expected = linalg_ops.matrix_solve(matrix, rhs_broadcast)
-      self.assertAllClose(expected.eval(), self.evaluate(result))
+    result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
+    self.assertAllEqual((2, 3, 7), result.get_shape())
+    expected = linalg_ops.matrix_solve(matrix, rhs_broadcast)
+    self.assertAllClose(*self.evaluate([expected, result]))

-  @test_util.run_deprecated_v1
   def test_static_dims_broadcast_rhs_has_extra_dims(self):
     # Since the second arg has extra dims, and the domain dim of the first arg
     # is larger than the number of linear equations, code will "flip" the extra
@@ -271,13 +243,11 @@ class MatrixSolveWithBroadcastTest(test.TestCase):
     rhs = rng.rand(2, 3, 2)
     matrix_broadcast = matrix + np.zeros((2, 1, 1))

-    with self.cached_session():
-      result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
-      self.assertAllEqual((2, 3, 2), result.get_shape())
-      expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
-      self.assertAllClose(expected.eval(), self.evaluate(result))
+    result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
+    self.assertAllEqual((2, 3, 2), result.get_shape())
+    expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
+    self.assertAllClose(*self.evaluate([expected, result]))

-  @test_util.run_deprecated_v1
   def test_static_dims_broadcast_rhs_has_extra_dims_dynamic(self):
     # Since the second arg has extra dims, and the domain dim of the first arg
     # is larger than the number of linear equations, code will "flip" the extra

@@ -290,22 +260,14 @@ class MatrixSolveWithBroadcastTest(test.TestCase):
     rhs = rng.rand(2, 3, 2)
     matrix_broadcast = matrix + np.zeros((2, 1, 1))

-    matrix_ph = array_ops.placeholder(dtypes.float64, shape=[None, None])
-    rhs_ph = array_ops.placeholder(dtypes.float64, shape=[None, None, None])
+    matrix_ph = array_ops.placeholder_with_default(matrix, shape=[None, None])
+    rhs_ph = array_ops.placeholder_with_default(rhs, shape=[None, None, None])

-    with self.cached_session():
-      result = linear_operator_util.matrix_solve_with_broadcast(matrix_ph,
-                                                                rhs_ph)
-      self.assertAllEqual(3, result.shape.ndims)
-      expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
-      self.assertAllClose(
-          self.evaluate(expected),
-          result.eval(feed_dict={
-              matrix_ph: matrix,
-              rhs_ph: rhs
-          }))
+    result = linear_operator_util.matrix_solve_with_broadcast(matrix_ph, rhs_ph)
+    self.assertAllEqual(3, result.shape.ndims)
+    expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
+    self.assertAllClose(*self.evaluate([expected, result]))

-  @test_util.run_deprecated_v1
   def test_static_dims_broadcast_rhs_has_extra_dims_and_adjoint(self):
     # Since the second arg has extra dims, and the domain dim of the first arg
     # is larger than the number of linear equations, code will "flip" the extra
@@ -318,14 +280,12 @@ class MatrixSolveWithBroadcastTest(test.TestCase):
     rhs = rng.rand(2, 3, 2)
     matrix_broadcast = matrix + np.zeros((2, 1, 1))

-    with self.cached_session():
-      result = linear_operator_util.matrix_solve_with_broadcast(
-          matrix, rhs, adjoint=True)
-      self.assertAllEqual((2, 3, 2), result.get_shape())
-      expected = linalg_ops.matrix_solve(matrix_broadcast, rhs, adjoint=True)
-      self.assertAllClose(expected.eval(), self.evaluate(result))
+    result = linear_operator_util.matrix_solve_with_broadcast(
+        matrix, rhs, adjoint=True)
+    self.assertAllEqual((2, 3, 2), result.get_shape())
+    expected = linalg_ops.matrix_solve(matrix_broadcast, rhs, adjoint=True)
+    self.assertAllClose(*self.evaluate([expected, result]))

-  @test_util.run_deprecated_v1
   def test_dynamic_dims_broadcast_64bit(self):
     # batch_shape = [2, 2]
     matrix = rng.rand(2, 3, 3)

@@ -333,40 +293,30 @@ class MatrixSolveWithBroadcastTest(test.TestCase):
     matrix_broadcast = matrix + np.zeros((2, 2, 1, 1))
     rhs_broadcast = rhs + np.zeros((2, 2, 1, 1))

-    matrix_ph = array_ops.placeholder(dtypes.float64)
-    rhs_ph = array_ops.placeholder(dtypes.float64)
+    matrix_ph = array_ops.placeholder_with_default(matrix, shape=None)
+    rhs_ph = array_ops.placeholder_with_default(rhs, shape=None)

-    with self.cached_session() as sess:
-      result, expected = sess.run(
-          [
-              linear_operator_util.matrix_solve_with_broadcast(
-                  matrix_ph, rhs_ph),
-              linalg_ops.matrix_solve(matrix_broadcast, rhs_broadcast)
-          ],
-          feed_dict={
-              matrix_ph: matrix,
-              rhs_ph: rhs,
-          })
-      self.assertAllClose(expected, result)
+    result, expected = self.evaluate([
+        linear_operator_util.matrix_solve_with_broadcast(matrix_ph, rhs_ph),
+        linalg_ops.matrix_solve(matrix_broadcast, rhs_broadcast)
+    ])
+    self.assertAllClose(expected, result)


 class MatrixTriangularSolveWithBroadcastTest(test.TestCase):

-  @test_util.run_deprecated_v1
   def test_static_dims_broadcast_matrix_has_extra_dims(self):
     # batch_shape = [2]
     matrix = rng.rand(2, 3, 3)
     rhs = rng.rand(3, 7)
     rhs_broadcast = rhs + np.zeros((2, 1, 1))

-    with self.cached_session():
-      result = linear_operator_util.matrix_triangular_solve_with_broadcast(
-          matrix, rhs)
-      self.assertAllEqual((2, 3, 7), result.get_shape())
-      expected = linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
-      self.assertAllClose(expected.eval(), self.evaluate(result))
+    result = linear_operator_util.matrix_triangular_solve_with_broadcast(
+        matrix, rhs)
+    self.assertAllEqual((2, 3, 7), result.get_shape())
+    expected = linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
+    self.assertAllClose(*self.evaluate([expected, result]))

-  @test_util.run_deprecated_v1
   def test_static_dims_broadcast_rhs_has_extra_dims(self):
     # Since the second arg has extra dims, and the domain dim of the first arg
     # is larger than the number of linear equations, code will "flip" the extra
@@ -379,14 +329,12 @@ class MatrixTriangularSolveWithBroadcastTest(test.TestCase):
     rhs = rng.rand(2, 3, 2)
     matrix_broadcast = matrix + np.zeros((2, 1, 1))

-    with self.cached_session():
-      result = linear_operator_util.matrix_triangular_solve_with_broadcast(
-          matrix, rhs)
-      self.assertAllEqual((2, 3, 2), result.get_shape())
-      expected = linalg_ops.matrix_triangular_solve(matrix_broadcast, rhs)
-      self.assertAllClose(expected.eval(), self.evaluate(result))
+    result = linear_operator_util.matrix_triangular_solve_with_broadcast(
+        matrix, rhs)
+    self.assertAllEqual((2, 3, 2), result.get_shape())
+    expected = linalg_ops.matrix_triangular_solve(matrix_broadcast, rhs)
+    self.assertAllClose(*self.evaluate([expected, result]))

-  @test_util.run_deprecated_v1
   def test_static_dims_broadcast_rhs_has_extra_dims_and_adjoint(self):
     # Since the second arg has extra dims, and the domain dim of the first arg
     # is larger than the number of linear equations, code will "flip" the extra

@@ -399,36 +347,28 @@ class MatrixTriangularSolveWithBroadcastTest(test.TestCase):
     rhs = rng.rand(2, 3, 2)
     matrix_broadcast = matrix + np.zeros((2, 1, 1))

-    with self.cached_session():
-      result = linear_operator_util.matrix_triangular_solve_with_broadcast(
-          matrix, rhs, adjoint=True)
-      self.assertAllEqual((2, 3, 2), result.get_shape())
-      expected = linalg_ops.matrix_triangular_solve(
-          matrix_broadcast, rhs, adjoint=True)
-      self.assertAllClose(expected.eval(), self.evaluate(result))
+    result = linear_operator_util.matrix_triangular_solve_with_broadcast(
+        matrix, rhs, adjoint=True)
+    self.assertAllEqual((2, 3, 2), result.get_shape())
+    expected = linalg_ops.matrix_triangular_solve(
+        matrix_broadcast, rhs, adjoint=True)
+    self.assertAllClose(*self.evaluate([expected, result]))

-  @test_util.run_deprecated_v1
   def test_dynamic_dims_broadcast_64bit(self):
     # batch_shape = [2]
     matrix = rng.rand(2, 3, 3)
     rhs = rng.rand(3, 7)
     rhs_broadcast = rhs + np.zeros((2, 1, 1))

-    matrix_ph = array_ops.placeholder(dtypes.float64)
-    rhs_ph = array_ops.placeholder(dtypes.float64)
+    matrix_ph = array_ops.placeholder_with_default(matrix, shape=None)
+    rhs_ph = array_ops.placeholder_with_default(rhs, shape=None)

-    with self.cached_session() as sess:
-      result, expected = sess.run(
-          [
-              linear_operator_util.matrix_triangular_solve_with_broadcast(
-                  matrix_ph, rhs_ph),
-              linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
-          ],
-          feed_dict={
-              matrix_ph: matrix,
-              rhs_ph: rhs,
-          })
-      self.assertAllClose(expected, result)
+    result, expected = self.evaluate([
+        linear_operator_util.matrix_triangular_solve_with_broadcast(
+            matrix_ph, rhs_ph),
+        linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
+    ])
+    self.assertAllClose(expected, result)


 class DomainDimensionStubOperator(object):
@@ -442,22 +382,21 @@ class DomainDimensionStubOperator(object):

 class AssertCompatibleMatrixDimensionsTest(test.TestCase):

-  @test_util.run_deprecated_v1
   def test_compatible_dimensions_do_not_raise(self):
-    with self.cached_session():
-      x = ops.convert_to_tensor(rng.rand(2, 3, 4))
-      operator = DomainDimensionStubOperator(3)
-      # Should not raise
-      linear_operator_util.assert_compatible_matrix_dimensions(
-          operator, x).run()  # pyformat: disable
+    x = ops.convert_to_tensor(rng.rand(2, 3, 4))
+    operator = DomainDimensionStubOperator(3)
+    # Should not raise
+    self.evaluate(
+        linear_operator_util.assert_compatible_matrix_dimensions(operator, x))

   def test_incompatible_dimensions_raise(self):
-    with self.cached_session():
-      x = ops.convert_to_tensor(rng.rand(2, 4, 4))
-      operator = DomainDimensionStubOperator(3)
-      with self.assertRaisesOpError("Incompatible matrix dimensions"):
-        linear_operator_util.assert_compatible_matrix_dimensions(
-            operator, x).run()  # pyformat: disable
+    x = ops.convert_to_tensor(rng.rand(2, 4, 4))
+    operator = DomainDimensionStubOperator(3)
+    # pylint: disable=g-error-prone-assert-raises
+    with self.assertRaisesOpError("Dimensions are not compatible"):
+      self.evaluate(
+          linear_operator_util.assert_compatible_matrix_dimensions(operator, x))
+    # pylint: enable=g-error-prone-assert-raises


 class DummyOperatorWithHint(object):
@@ -22,6 +22,7 @@ import numpy as np
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables as variables_module
 from tensorflow.python.ops.linalg import linalg as linalg_lib
 from tensorflow.python.ops.linalg import linear_operator_test_util
 from tensorflow.python.platform import test

@@ -30,6 +31,7 @@ from tensorflow.python.platform import test
 rng = np.random.RandomState(2016)


+@test_util.run_all_in_graph_and_eager_modes
 class LinearOperatorZerosTest(
     linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
   """Most tests done in the base class LinearOperatorDerivedClassTest."""

@@ -75,11 +77,10 @@ class LinearOperatorZerosTest(
     operator = linalg_lib.LinearOperatorZeros(num_rows=2)
     operator.assert_non_singular()

-  @test_util.run_deprecated_v1
   def test_assert_self_adjoint(self):
     with self.cached_session():
       operator = linalg_lib.LinearOperatorZeros(num_rows=2)
-      operator.assert_self_adjoint().run()  # Should not fail
+      self.evaluate(operator.assert_self_adjoint())  # Should not fail

   def test_non_scalar_num_rows_raises_static(self):
     with self.assertRaisesRegexp(ValueError, "must be a 0-D Tensor"):
@@ -111,46 +112,37 @@ class LinearOperatorZerosTest(
     with self.assertRaisesRegexp(ValueError, "must be non-negative"):
       linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=[-2])

-  @test_util.run_deprecated_v1
   def test_non_scalar_num_rows_raises_dynamic(self):
     with self.cached_session():
-      num_rows = array_ops.placeholder(dtypes.int32)
-      operator = linalg_lib.LinearOperatorZeros(
-          num_rows, assert_proper_shapes=True)
-      with self.assertRaisesOpError("must be a 0-D Tensor"):
-        operator.to_dense().eval(feed_dict={num_rows: [2]})
+      num_rows = array_ops.placeholder_with_default([2], shape=None)
+      with self.assertRaisesError("must be a 0-D Tensor"):
+        operator = linalg_lib.LinearOperatorZeros(
+            num_rows, assert_proper_shapes=True)
+        self.evaluate(operator.to_dense())

-  @test_util.run_deprecated_v1
   def test_negative_num_rows_raises_dynamic(self):
     with self.cached_session():
-      n = array_ops.placeholder(dtypes.int32)
-      operator = linalg_lib.LinearOperatorZeros(
-          num_rows=n, assert_proper_shapes=True)
-      with self.assertRaisesOpError("must be non-negative"):
-        operator.to_dense().eval(feed_dict={n: -2})
-
-      operator = linalg_lib.LinearOperatorZeros(
-          num_rows=2, num_columns=n, assert_proper_shapes=True)
-      with self.assertRaisesOpError("must be non-negative"):
-        operator.to_dense().eval(feed_dict={n: -2})
+      n = array_ops.placeholder_with_default(-2, shape=None)
+      with self.assertRaisesError("must be non-negative"):
+        operator = linalg_lib.LinearOperatorZeros(
+            num_rows=n, assert_proper_shapes=True)
+        self.evaluate(operator.to_dense())

-  @test_util.run_deprecated_v1
   def test_non_1d_batch_shape_raises_dynamic(self):
     with self.cached_session():
-      batch_shape = array_ops.placeholder(dtypes.int32)
-      operator = linalg_lib.LinearOperatorZeros(
-          num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
-      with self.assertRaisesOpError("must be a 1-D"):
-        operator.to_dense().eval(feed_dict={batch_shape: 2})
+      batch_shape = array_ops.placeholder_with_default(2, shape=None)
+      with self.assertRaisesError("must be a 1-D"):
+        operator = linalg_lib.LinearOperatorZeros(
+            num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
+        self.evaluate(operator.to_dense())

-  @test_util.run_deprecated_v1
   def test_negative_batch_shape_raises_dynamic(self):
     with self.cached_session():
-      batch_shape = array_ops.placeholder(dtypes.int32)
-      operator = linalg_lib.LinearOperatorZeros(
-          num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
-      with self.assertRaisesOpError("must be non-negative"):
-        operator.to_dense().eval(feed_dict={batch_shape: [-2]})
+      batch_shape = array_ops.placeholder_with_default([-2], shape=None)
+      with self.assertRaisesError("must be non-negative"):
+        operator = linalg_lib.LinearOperatorZeros(
+            num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
+        self.evaluate(operator.to_dense())

   def test_wrong_matrix_dimensions_raises_static(self):
     operator = linalg_lib.LinearOperatorZeros(num_rows=2)
@@ -158,17 +150,15 @@ class LinearOperatorZerosTest(
     with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
       operator.matmul(x)

-  @test_util.run_deprecated_v1
   def test_wrong_matrix_dimensions_raises_dynamic(self):
-    num_rows = array_ops.placeholder(dtypes.int32)
-    x = array_ops.placeholder(dtypes.float32)
+    num_rows = array_ops.placeholder_with_default(2, shape=None)
+    x = array_ops.placeholder_with_default(rng.rand(3, 3), shape=None)

     with self.cached_session():
-      operator = linalg_lib.LinearOperatorZeros(
-          num_rows, assert_proper_shapes=True)
-      y = operator.matmul(x)
-      with self.assertRaisesOpError("Incompatible.*dimensions"):
-        y.eval(feed_dict={num_rows: 2, x: rng.rand(3, 3)})
+      with self.assertRaisesError("Dimensions.*not.compatible"):
+        operator = linalg_lib.LinearOperatorZeros(
+            num_rows, assert_proper_shapes=True, dtype=dtypes.float64)
+        self.evaluate(operator.matmul(x))

   def test_is_x_flags(self):
     # The is_x flags are by default all True.

@@ -188,7 +178,20 @@ class LinearOperatorZerosTest(
         operator2.matmul(operator1),
         linalg_lib.LinearOperatorZeros))

+  def test_ref_type_shape_args_raises(self):
+    with self.assertRaisesRegexp(TypeError, "num_rows.cannot.be.reference"):
+      linalg_lib.LinearOperatorZeros(num_rows=variables_module.Variable(2))
+
+    with self.assertRaisesRegexp(TypeError, "num_columns.cannot.be.reference"):
+      linalg_lib.LinearOperatorZeros(
+          num_rows=2, num_columns=variables_module.Variable(3))
+
+    with self.assertRaisesRegexp(TypeError, "batch_shape.cannot.be.reference"):
+      linalg_lib.LinearOperatorZeros(
+          num_rows=2, batch_shape=variables_module.Variable([2]))
+
+
+@test_util.run_all_in_graph_and_eager_modes
 class LinearOperatorZerosNotSquareTest(
     linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
@@ -25,6 +25,7 @@ from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.linalg import linalg_impl as linalg
 from tensorflow.python.ops.linalg import linear_operator
+from tensorflow.python.ops.linalg import linear_operator_util
 from tensorflow.python.util.tf_export import tf_export

 __all__ = ["LinearOperatorHouseholder",]

@@ -123,7 +124,7 @@ class LinearOperatorHouseholder(linear_operator.LinearOperator):
     """

     with ops.name_scope(name, values=[reflection_axis]):
-      self._reflection_axis = ops.convert_to_tensor(
+      self._reflection_axis = linear_operator_util.convert_nonref_to_tensor(
           reflection_axis, name="reflection_axis")
       self._check_reflection_axis(self._reflection_axis)

@@ -194,9 +195,10 @@ class LinearOperatorHouseholder(linear_operator.LinearOperator):

     # Note that because this is a reflection, it lies in O(n) (for real vector
    # spaces) or U(n) (for complex vector spaces), and thus is its own adjoint.
+    reflection_axis = ops.convert_to_tensor(self.reflection_axis)
     x = linalg.adjoint(x) if adjoint_arg else x
-    normalized_axis = self.reflection_axis / linalg.norm(
-        self.reflection_axis, axis=-1, keepdims=True)
+    normalized_axis = reflection_axis / linalg.norm(
+        reflection_axis, axis=-1, keepdims=True)
     mat = normalized_axis[..., array_ops.newaxis]
     x_dot_normalized_v = math_ops.matmul(mat, x, adjoint_a=True)
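The two changes above are the heart of the tape-safety fix for this operator: the constructor keeps a reference-type input (such as a `tf.Variable`) as-is via `convert_nonref_to_tensor`, and `_matmul` converts it to a `Tensor` at call time, so the variable read happens under whatever `GradientTape` is active. A simplified sketch of the helper's behavior (an assumption about its shape, not a copy of the implementation):

```python
# Hypothetical simplification of linear_operator_util.convert_nonref_to_tensor.
from tensorflow.python.framework import ops
from tensorflow.python.ops import variables as variables_module

def convert_nonref_to_tensor_sketch(value, name=None):
  if isinstance(value, variables_module.Variable):
    return value  # Deferred: each later read is recorded by GradientTape.
  return ops.convert_to_tensor(value, name=name)
```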
@@ -250,6 +250,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
         negative.
       ValueError: If any of the following is not `True`:
         `{is_self_adjoint, is_non_singular, is_positive_definite}`.
+      TypeError: If `num_rows` or `batch_shape` is ref-type (e.g. Variable).
     """
     dtype = dtype or dtypes.float32
     self._assert_proper_shapes = assert_proper_shapes

@@ -273,6 +274,9 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
         is_square=is_square,
         name=name)

+    linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
+    linear_operator_util.assert_not_ref_type(batch_shape, "batch_shape")
+
     self._num_rows = linear_operator_util.shape_tensor(
         num_rows, name="num_rows")
     self._num_rows_static = tensor_util.constant_value(self._num_rows)
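Shape-defining arguments are the one place where variables are rejected outright: a `num_rows` or `batch_shape` that could change out from under the operator cannot be made tape safe, so the constructor now fails fast. A hypothetical one-liner equivalent of the new guard (the real helper's exact message may differ; the tests only match `"num_rows.*reference"` and the like):

```python
# Hypothetical sketch of linear_operator_util.assert_not_ref_type.
from tensorflow.python.ops import variables as variables_module

def assert_not_ref_type_sketch(x, arg_name):
  if isinstance(x, variables_module.Variable):
    raise TypeError("Argument %s cannot be reference type (e.g. Variable). "
                    "Found: %s" % (arg_name, type(x)))
```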
@@ -589,7 +593,8 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
     self._assert_proper_shapes = assert_proper_shapes

     with ops.name_scope(name, values=[multiplier, num_rows]):
-      self._multiplier = ops.convert_to_tensor(multiplier, name="multiplier")
+      self._multiplier = linear_operator_util.convert_nonref_to_tensor(
+          multiplier, name="multiplier")

       # Check and auto-set hints.
       if not self._multiplier.dtype.is_complex:

@@ -601,20 +606,16 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
       if not is_square:
         raise ValueError("A ScaledIdentity operator is always square.")

+      linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
+
       super(LinearOperatorScaledIdentity, self).__init__(
-          dtype=self._multiplier.dtype,
+          dtype=self._multiplier.dtype.base_dtype,
           is_non_singular=is_non_singular,
           is_self_adjoint=is_self_adjoint,
           is_positive_definite=is_positive_definite,
           is_square=is_square,
           name=name)

-      # Shape [B1,...Bb, 1, 1]
-      self._multiplier_matrix = array_ops.expand_dims(
-          array_ops.expand_dims(self.multiplier, -1), -1)
-      self._multiplier_matrix_conj = math_ops.conj(self._multiplier_matrix)
-      self._abs_multiplier = math_ops.abs(self.multiplier)
-
       self._num_rows = linear_operator_util.shape_tensor(
           num_rows, name="num_rows")
       self._num_rows_static = tensor_util.constant_value(self._num_rows)

@@ -652,34 +653,34 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
         imag_multiplier,
         message="LinearOperator was not self-adjoint")

+  def _make_multiplier_matrix(self, conjugate=False):
+    # Shape [B1,...Bb, 1, 1]
+    multiplier_matrix = array_ops.expand_dims(
+        array_ops.expand_dims(self.multiplier, -1), -1)
+    if conjugate:
+      multiplier_matrix = math_ops.conj(multiplier_matrix)
+    return multiplier_matrix
+
   def _matmul(self, x, adjoint=False, adjoint_arg=False):
     x = linalg.adjoint(x) if adjoint_arg else x
-    if adjoint:
-      matrix = self._multiplier_matrix_conj
-    else:
-      matrix = self._multiplier_matrix
     if self._assert_proper_shapes:
       aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x)
       x = control_flow_ops.with_dependencies([aps], x)
-    return x * matrix
+    return x * self._make_multiplier_matrix(conjugate=adjoint)

   def _determinant(self):
     return self.multiplier**self._num_rows_cast_to_dtype

   def _log_abs_determinant(self):
     return self._num_rows_cast_to_real_dtype * math_ops.log(
-        self._abs_multiplier)
+        math_ops.abs(self.multiplier))

   def _solve(self, rhs, adjoint=False, adjoint_arg=False):
     rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
-    if adjoint:
-      matrix = self._multiplier_matrix_conj
-    else:
-      matrix = self._multiplier_matrix
     if self._assert_proper_shapes:
       aps = linear_operator_util.assert_compatible_matrix_dimensions(self, rhs)
       rhs = control_flow_ops.with_dependencies([aps], rhs)
-    return rhs / matrix
+    return rhs / self._make_multiplier_matrix(conjugate=adjoint)

   def _trace(self):
     # Get Tensor of all ones of same shape as self.batch_shape.
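The refactor above replaces tensors precomputed in `__init__` (`_multiplier_matrix`, its conjugate, and `_abs_multiplier`) with `_make_multiplier_matrix`, which rebuilds them on every call. The difference is exactly what the tape sees, as this eager-mode illustration (not from the commit) shows:

```python
import tensorflow as tf

multiplier = tf.Variable(1.23)
# Computed once, outside any tape: the variable read is already over.
precomputed = tf.expand_dims(tf.expand_dims(multiplier, -1), -1)

with tf.GradientTape() as tape:
  y_stale = precomputed * 2.0  # constant snapshot: no gradient path
  y_fresh = tf.expand_dims(tf.expand_dims(multiplier, -1), -1) * 2.0

print(tape.gradient(y_stale, multiplier))  # None
print(tape.gradient(y_fresh, multiplier))  # tf.Tensor(2.0, ...)
```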
@ -18,7 +18,6 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -145,10 +144,9 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
|
||||
is_square = True
|
||||
|
||||
with ops.name_scope(name, values=[tril]):
|
||||
self._tril = ops.convert_to_tensor(tril, name="tril")
|
||||
self._tril = linear_operator_util.convert_nonref_to_tensor(tril,
|
||||
name="tril")
|
||||
self._check_tril(self._tril)
|
||||
self._tril = array_ops.matrix_band_part(tril, -1, 0)
|
||||
self._diag = array_ops.matrix_diag_part(self._tril)
|
||||
|
||||
super(LinearOperatorLowerTriangular, self).__init__(
|
||||
dtype=self._tril.dtype,
|
||||
@ -161,24 +159,20 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):

  def _check_tril(self, tril):
    """Static check of the `tril` argument."""
    allowed_dtypes = [
        dtypes.float16,
        dtypes.float32,
        dtypes.float64,
        dtypes.complex64,
        dtypes.complex128,
    ]
    dtype = tril.dtype
    if dtype not in allowed_dtypes:
      raise TypeError(
          "Argument tril must have dtype in %s. Found: %s"
          % (allowed_dtypes, dtype))

    if tril.get_shape().ndims is not None and tril.get_shape().ndims < 2:
      raise ValueError(
          "Argument tril must have at least 2 dimensions. Found: %s"
          % tril)

  def _get_tril(self):
    """Gets the `tril` kwarg, with upper part zero-d out."""
    return array_ops.matrix_band_part(self._tril, -1, 0)

  def _get_diag(self):
    """Gets the diagonal part of `tril` kwarg."""
    return array_ops.matrix_diag_part(self._tril)

  def _shape(self):
    return self._tril.get_shape()

@ -187,27 +181,24 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):

  def _assert_non_singular(self):
    return linear_operator_util.assert_no_entries_with_modulus_zero(
        self._diag,
        self._get_diag(),
        message="Singular operator: Diagonal contained zero values.")

  def _matmul(self, x, adjoint=False, adjoint_arg=False):
    return math_ops.matmul(
        self._tril, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)
        self._get_tril(), x, adjoint_a=adjoint, adjoint_b=adjoint_arg)

  def _determinant(self):
    return math_ops.reduce_prod(self._diag, axis=[-1])
    return math_ops.reduce_prod(self._get_diag(), axis=[-1])

  def _log_abs_determinant(self):
    return math_ops.reduce_sum(
        math_ops.log(math_ops.abs(self._diag)), axis=[-1])
        math_ops.log(math_ops.abs(self._get_diag())), axis=[-1])

  def _solve(self, rhs, adjoint=False, adjoint_arg=False):
    rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
    return linear_operator_util.matrix_triangular_solve_with_broadcast(
        self._tril, rhs, lower=True, adjoint=adjoint)
        self._get_tril(), rhs, lower=True, adjoint=adjoint)

  def _to_dense(self):
    return self._tril

  def _add_to_tensor(self, x):
    return self._tril + x
    return self._get_tril()
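With `_get_tril` and `_get_diag`, the `matrix_band_part` and `matrix_diag_part` ops now run inside each method call instead of once in `__init__`, so gradients flow back to a variable `tril`. A small example against the public API, assuming TF 2.x eager execution:

import tensorflow as tf

tril = tf.Variable([[1., 0.], [2., 3.]])
operator = tf.linalg.LinearOperatorLowerTriangular(tril)

with tf.GradientTape() as tape:
  # determinant() goes through _get_diag(), so matrix_diag_part is taped.
  det = operator.determinant()  # 1. * 3. == 3.

print(tape.gradient(det, tril))  # [[3., 0.], [0., 1.]], not None.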
@ -24,6 +24,7 @@ import numpy as np
import six

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
@ -51,6 +52,15 @@ class OperatorShapesInfo(object):
    self.__dict__.update(kwargs)


class CheckTapeSafeSkipOptions(object):

  # Skip checking this particular method.
  DETERMINANT = "determinant"
  DIAG_PART = "diag_part"
  LOG_ABS_DETERMINANT = "log_abs_determinant"
  TRACE = "trace"


@six.add_metaclass(abc.ABCMeta)  # pylint: disable=no-init
class LinearOperatorDerivedClassTest(test.TestCase):
  """Tests for derived classes.
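`CheckTapeSafeSkipOptions` lets a derived-class test opt out of individual gradient checks, for example when an operator's determinant is a constant and therefore has a `None` gradient by construction. A hypothetical usage sketch (`MyOperator` is a placeholder, not a class from this change):

  def test_tape_safe(self):
    matrix = variables_module.Variable([[2.]])
    operator = MyOperator(matrix)  # Placeholder operator under test.
    self.check_tape_safe(
        operator,
        skip_options=[
            # Skip the checks whose outputs do not depend on the variable.
            CheckTapeSafeSkipOptions.DETERMINANT,
            CheckTapeSafeSkipOptions.LOG_ABS_DETERMINANT,
        ])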
@ -174,18 +184,35 @@ class LinearOperatorDerivedClassTest(test.TestCase):
    # To skip "test_foo", add "foo" to this list.
    return []

  def check_tape_safe(self, operator):
    """Check gradients are not None w.r.t. Variables.
  def assertRaisesError(self, msg):
    """assertRaisesRegexp or OpError, depending on context.executing_eagerly."""
    if context.executing_eagerly():
      return self.assertRaisesRegexp(Exception, msg)
    return self.assertRaisesOpError(msg)

  def check_tape_safe(self, operator, skip_options=None):
    """Check gradients are not None w.r.t. operator.variables.

    Meant to be called from the derived class.

    This ensures grads are not None w.r.t. every variable in
    operator.variables. If more fine-grained testing is needed, a custom test
    should be written.

    Args:
      operator: LinearOperator. Exact checks done will depend on hints.
      skip_options: Optional list of CheckTapeSafeSkipOptions.
        Makes this test skip particular checks.
    """
    skip_options = skip_options or []

    if not operator.variables:
      raise AssertionError("`operator.variables` was empty")

    def _assert_not_none(iterable):
      for item in iterable:
        self.assertIsNotNone(item)

    # Tape tests that can be run on every operator below.
    with backprop.GradientTape() as tape:
      _assert_not_none(tape.gradient(operator.to_dense(), operator.variables))

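`assertRaisesError` absorbs a real difference between the two execution modes: in graph mode a failed check surfaces as an `OpError` when the assert op is evaluated, while in eager mode it is raised immediately, and the concrete exception type is not worth pinning down. A sketch of the intended call pattern (`bad_x` is a hypothetical mis-shaped argument):

    with self.assertRaisesError("Dimensions are not compatible"):
      # Graph mode: fails when the assert op runs under self.evaluate.
      # Eager mode: raises directly from the matmul call.
      self.evaluate(operator.matmul(bad_x))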
@ -193,23 +220,30 @@ class LinearOperatorDerivedClassTest(test.TestCase):
      _assert_not_none(
          tape.gradient(operator.adjoint().to_dense(), operator.variables))

    x = array_ops.ones(shape=operator.H.shape_tensor()[:-1])
    x = math_ops.cast(
        array_ops.ones(shape=operator.H.shape_tensor()[:-1]), operator.dtype)

    with backprop.GradientTape() as tape:
      _assert_not_none(tape.gradient(operator.matvec(x), operator.variables))

    # Tests for square, but possibly non-singular operators below.
    if not operator.is_square:
      return

    with backprop.GradientTape() as tape:
      _assert_not_none(
          tape.gradient(operator.determinant(), operator.variables))
    for option in [
        CheckTapeSafeSkipOptions.DETERMINANT,
        CheckTapeSafeSkipOptions.LOG_ABS_DETERMINANT,
        CheckTapeSafeSkipOptions.DIAG_PART,
        CheckTapeSafeSkipOptions.TRACE,
    ]:
      with backprop.GradientTape() as tape:
        if option not in skip_options:
          _assert_not_none(
              tape.gradient(getattr(operator, option)(), operator.variables))

    with backprop.GradientTape() as tape:
      _assert_not_none(tape.gradient(operator.diag_part(), operator.variables))

    with backprop.GradientTape() as tape:
      _assert_not_none(tape.gradient(operator.trace(), operator.variables))
    # Tests for non-singular operators below.
    if operator.is_non_singular is False:  # pylint: disable=g-bool-id-comparison
      return

    with backprop.GradientTape() as tape:
      _assert_not_none(
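The loop works because each skip option's string value doubles as the name of the corresponding `LinearOperator` method, so the same tape check can be dispatched generically:

    option = CheckTapeSafeSkipOptions.DETERMINANT  # == "determinant"
    getattr(operator, option)()  # Same as operator.determinant().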
@ -218,6 +252,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
    with backprop.GradientTape() as tape:
      _assert_not_none(tape.gradient(operator.solvevec(x), operator.variables))

    # Tests for SPD operators below.
    if not (operator.is_self_adjoint and operator.is_positive_definite):
      return

@ -157,6 +157,12 @@ def is_ref(x):
       hasattr(x, "shape")))


def assert_not_ref_type(x, arg_name):
  if is_ref(x):
    raise TypeError(
        "Argument %s cannot be reference type. Found: %s" % (arg_name, type(x)))


################################################################################
# Asserts.
################################################################################
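`assert_not_ref_type` guards arguments that get baked into an operator's static shape and therefore must not change after construction. A short sketch of what it rejects, relying on `is_ref` treating `tf.Variable` as a reference type:

import tensorflow as tf
from tensorflow.python.ops.linalg import linear_operator_util

num_rows = tf.Variable(2)
# Raises TypeError: a shape-defining argument may not be a reference type,
# since mutating it later would silently invalidate the operator.
linear_operator_util.assert_not_ref_type(num_rows, "num_rows")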
@ -223,7 +229,9 @@ def assert_compatible_matrix_dimensions(operator, x):
  assert_same_dd = check_ops.assert_equal(
      array_ops.shape(x)[-2],
      operator.domain_dimension_tensor(),
      message=("Incompatible matrix dimensions. "
      # This error message made to look similar to error raised by static check
      # in the base class.
      message=("Dimensions are not compatible. "
               "shape[-2] of argument to be the same as this operator"))

  return assert_same_dd
@ -196,6 +196,10 @@ class LinearOperatorZeros(linear_operator.LinearOperator):
          is_square=is_square,
          name=name)

      linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
      linear_operator_util.assert_not_ref_type(num_columns, "num_columns")
      linear_operator_util.assert_not_ref_type(batch_shape, "batch_shape")

      self._num_rows = linear_operator_util.shape_tensor(
          num_rows, name="num_rows")
      self._num_rows_static = tensor_util.constant_value(self._num_rows)
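For `LinearOperatorZeros` the shape arguments define the operator itself, so variables are rejected up front rather than silently snapshotted. A sketch, assuming a TF build that includes this change:

import tensorflow as tf

op = tf.linalg.LinearOperatorZeros(num_rows=2)  # Fine: plain int.

# TypeError from assert_not_ref_type: num_rows cannot be a reference type.
tf.linalg.LinearOperatorZeros(num_rows=tf.Variable(2))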