Make non-meta linear operators (other than Circulant/Toeplitz) tape safe.

PiperOrigin-RevId: 259453506
Ian Langmore 2019-07-22 19:05:06 -07:00 committed by TensorFlower Gardener
parent df6ba21e45
commit 95bcd434d0
13 changed files with 388 additions and 350 deletions
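For orientation, a minimal sketch of what "tape safe" means here, written against the public TF 2.x API rather than the internal modules in the diffs below (tf.linalg.LinearOperatorDiag is just a convenient stand-in, not one of the operators touched by this commit): an operator built from a tf.Variable should expose that variable and yield non-None gradients through its methods.

    import tensorflow as tf

    v = tf.Variable([1., 2., 3.])
    operator = tf.linalg.LinearOperatorDiag(v, is_non_singular=True)

    with tf.GradientTape() as tape:
      loss = operator.log_abs_determinant()

    # Operators that called convert_to_tensor on their arguments at
    # construction time dropped the link to the variable, and this
    # gradient came back None.
    print(tape.gradient(loss, operator.variables))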

View File

@ -382,7 +382,7 @@ class WishartCholeskyTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, "cannot be less than"):
distributions.WishartCholesky(
df=2, scale=chol_scale, validate_args=False)
with self.assertRaisesRegexp(TypeError, "Argument tril must have dtype"):
with self.assertRaisesRegexp(TypeError, "."):
distributions.WishartCholesky(
df=4.,
scale=np.asarray(

View File

@ -20,7 +20,9 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_block_diag as block_diag
from tensorflow.python.ops.linalg import linear_operator_lower_triangular as lower_triangular
@ -56,6 +58,7 @@ def _block_diag_dense(expected_shape, blocks):
return array_ops.concat(rows, axis=-2)
@test_util.run_all_in_graph_and_eager_modes
class SquareLinearOperatorBlockDiagTest(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
@ -209,6 +212,26 @@ class SquareLinearOperatorBlockDiagTest(
block_diag.LinearOperatorBlockDiag)
self.assertEqual(2, len(inverse.operators))
def test_tape_safe(self):
matrix = variables_module.Variable([[1., 0.], [0., 1.]])
operator = block_diag.LinearOperatorBlockDiag(
[
linalg.LinearOperatorFullMatrix(
matrix,
is_self_adjoint=True,
is_positive_definite=True,
),
linalg.LinearOperatorFullMatrix(
matrix,
is_self_adjoint=True,
is_positive_definite=True,
),
],
is_self_adjoint=True,
is_positive_definite=True,
)
self.check_tape_safe(operator)
def test_is_non_singular_auto_set(self):
# Matrix with two positive eigenvalues, 11 and 8.
# The matrix values do not affect auto-setting of the flags.

View File

@ -17,17 +17,21 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_householder as householder
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.platform import test
linalg = linalg_lib
CheckTapeSafeSkipOptions = linear_operator_test_util.CheckTapeSafeSkipOptions
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorHouseholderTest(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
@ -87,6 +91,19 @@ class LinearOperatorHouseholderTest(
self.assertIsInstance(
operator.inverse(), householder.LinearOperatorHouseholder)
def test_tape_safe(self):
reflection_axis = variables_module.Variable([1., 3., 5., 8.])
operator = householder.LinearOperatorHouseholder(reflection_axis)
self.check_tape_safe(
operator,
skip_options=[
# Determinant hard-coded as 1.
CheckTapeSafeSkipOptions.DETERMINANT,
CheckTapeSafeSkipOptions.LOG_ABS_DETERMINANT,
# Trace hard-coded.
CheckTapeSafeSkipOptions.TRACE,
])
if __name__ == "__main__":
linear_operator_test_util.add_tests(LinearOperatorHouseholderTest)

View File

@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.platform import test
@ -33,6 +34,7 @@ from tensorflow.python.platform import test
rng = np.random.RandomState(2016)
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorIdentityTest(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
@ -61,23 +63,20 @@ class LinearOperatorIdentityTest(
return operator, mat
@test_util.run_deprecated_v1
def test_assert_positive_definite(self):
with self.cached_session():
operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
operator.assert_positive_definite().run() # Should not fail
self.evaluate(operator.assert_positive_definite()) # Should not fail
@test_util.run_deprecated_v1
def test_assert_non_singular(self):
with self.cached_session():
operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
operator.assert_non_singular().run() # Should not fail
self.evaluate(operator.assert_non_singular()) # Should not fail
@test_util.run_deprecated_v1
def test_assert_self_adjoint(self):
with self.cached_session():
operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
operator.assert_self_adjoint().run() # Should not fail
self.evaluate(operator.assert_self_adjoint()) # Should not fail
def test_float16_matmul(self):
# float16 cannot be tested by base test class because tf.linalg.solve does
@ -113,41 +112,38 @@ class LinearOperatorIdentityTest(
with self.assertRaisesRegexp(ValueError, "must be non-negative"):
linalg_lib.LinearOperatorIdentity(num_rows=2, batch_shape=[-2])
@test_util.run_deprecated_v1
def test_non_scalar_num_rows_raises_dynamic(self):
with self.cached_session():
num_rows = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorIdentity(
num_rows, assert_proper_shapes=True)
with self.assertRaisesOpError("must be a 0-D Tensor"):
operator.to_dense().eval(feed_dict={num_rows: [2]})
num_rows = array_ops.placeholder_with_default([2], shape=None)
with self.assertRaisesError("must be a 0-D Tensor"):
operator = linalg_lib.LinearOperatorIdentity(
num_rows, assert_proper_shapes=True)
self.evaluate(operator.to_dense())
@test_util.run_deprecated_v1
def test_negative_num_rows_raises_dynamic(self):
with self.cached_session():
num_rows = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorIdentity(
num_rows, assert_proper_shapes=True)
with self.assertRaisesOpError("must be non-negative"):
operator.to_dense().eval(feed_dict={num_rows: -2})
num_rows = array_ops.placeholder_with_default(-2, shape=None)
with self.assertRaisesError("must be non-negative"):
operator = linalg_lib.LinearOperatorIdentity(
num_rows, assert_proper_shapes=True)
self.evaluate(operator.to_dense())
@test_util.run_deprecated_v1
def test_non_1d_batch_shape_raises_dynamic(self):
with self.cached_session():
batch_shape = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorIdentity(
num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
with self.assertRaisesOpError("must be a 1-D"):
operator.to_dense().eval(feed_dict={batch_shape: 2})
batch_shape = array_ops.placeholder_with_default(2, shape=None)
with self.assertRaisesError("must be a 1-D"):
operator = linalg_lib.LinearOperatorIdentity(
num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
self.evaluate(operator.to_dense())
@test_util.run_deprecated_v1
def test_negative_batch_shape_raises_dynamic(self):
with self.cached_session():
batch_shape = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorIdentity(
num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
with self.assertRaisesOpError("must be non-negative"):
operator.to_dense().eval(feed_dict={batch_shape: [-2]})
batch_shape = array_ops.placeholder_with_default([-2], shape=None)
with self.assertRaisesError("must be non-negative"):
operator = linalg_lib.LinearOperatorIdentity(
num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
self.evaluate(operator.to_dense())
def test_wrong_matrix_dimensions_raises_static(self):
operator = linalg_lib.LinearOperatorIdentity(num_rows=2)
@ -155,17 +151,16 @@ class LinearOperatorIdentityTest(
with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
operator.matmul(x)
@test_util.run_deprecated_v1
def test_wrong_matrix_dimensions_raises_dynamic(self):
num_rows = array_ops.placeholder(dtypes.int32)
x = array_ops.placeholder(dtypes.float32)
num_rows = array_ops.placeholder_with_default(2, shape=None)
x = array_ops.placeholder_with_default(
rng.rand(3, 3).astype(np.float32), shape=None)
with self.cached_session():
operator = linalg_lib.LinearOperatorIdentity(
num_rows, assert_proper_shapes=True)
y = operator.matmul(x)
with self.assertRaisesOpError("Incompatible.*dimensions"):
y.eval(feed_dict={num_rows: 2, x: rng.rand(3, 3)})
with self.assertRaisesError("Dimensions.*not.compatible"):
operator = linalg_lib.LinearOperatorIdentity(
num_rows, assert_proper_shapes=True)
self.evaluate(operator.matmul(x))
def test_default_batch_shape_broadcasts_with_everything_static(self):
# These cannot be done in the automated (base test class) tests since they
@ -181,22 +176,18 @@ class LinearOperatorIdentityTest(
self.assertAllEqual(operator_matmul.get_shape(), expected.get_shape())
self.assertAllClose(*self.evaluate([operator_matmul, expected]))
@test_util.run_deprecated_v1
def test_default_batch_shape_broadcasts_with_everything_dynamic(self):
# These cannot be done in the automated (base test class) tests since they
# test shapes that tf.batch_matmul cannot handle.
# In particular, tf.batch_matmul does not broadcast.
with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32)
with self.cached_session():
x = array_ops.placeholder_with_default(rng.randn(1, 2, 3, 4), shape=None)
operator = linalg_lib.LinearOperatorIdentity(num_rows=3, dtype=x.dtype)
operator_matmul = operator.matmul(x)
expected = x
feed_dict = {x: rng.randn(1, 2, 3, 4)}
self.assertAllClose(
*sess.run([operator_matmul, expected], feed_dict=feed_dict))
self.assertAllClose(*self.evaluate([operator_matmul, expected]))
def test_broadcast_matmul_static_shapes(self):
# These cannot be done in the automated (base test class) tests since they
@ -219,21 +210,19 @@ class LinearOperatorIdentityTest(
self.assertAllEqual(operator_matmul.get_shape(), expected.get_shape())
self.assertAllClose(*self.evaluate([operator_matmul, expected]))
@test_util.run_deprecated_v1
def test_broadcast_matmul_dynamic_shapes(self):
# These cannot be done in the automated (base test class) tests since they
# test shapes that tf.batch_matmul cannot handle.
# In particular, tf.batch_matmul does not broadcast.
with self.cached_session() as sess:
with self.cached_session():
# Given this x and LinearOperatorIdentity shape of (2, 1, 3, 3), the
# broadcast shape of operator and 'x' is (2, 2, 3, 4)
x = array_ops.placeholder(dtypes.float32)
num_rows = array_ops.placeholder(dtypes.int32)
batch_shape = array_ops.placeholder(dtypes.int32)
x = array_ops.placeholder_with_default(rng.rand(1, 2, 3, 4), shape=None)
num_rows = array_ops.placeholder_with_default(3, shape=None)
batch_shape = array_ops.placeholder_with_default((2, 1), shape=None)
operator = linalg_lib.LinearOperatorIdentity(
num_rows, batch_shape=batch_shape)
feed_dict = {x: rng.rand(1, 2, 3, 4), num_rows: 3, batch_shape: (2, 1)}
num_rows, batch_shape=batch_shape, dtype=dtypes.float64)
# Batch matrix of zeros with the broadcast shape of x and operator.
zeros = array_ops.zeros(shape=(2, 2, 3, 4), dtype=x.dtype)
@ -242,8 +231,7 @@ class LinearOperatorIdentityTest(
expected = x + zeros
operator_matmul = operator.matmul(x)
self.assertAllClose(
*sess.run([operator_matmul, expected], feed_dict=feed_dict))
self.assertAllClose(*self.evaluate([operator_matmul, expected]))
def test_is_x_flags(self):
# The is_x flags are by default all True.
@ -280,7 +268,16 @@ class LinearOperatorIdentityTest(
self.assertIsInstance(
operator.inverse(), linalg_lib.LinearOperatorIdentity)
def test_ref_type_shape_args_raises(self):
with self.assertRaisesRegexp(TypeError, "num_rows.*reference"):
linalg_lib.LinearOperatorIdentity(num_rows=variables_module.Variable(2))
with self.assertRaisesRegexp(TypeError, "batch_shape.*reference"):
linalg_lib.LinearOperatorIdentity(
num_rows=2, batch_shape=variables_module.Variable([3]))
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorScaledIdentityTest(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
@ -331,47 +328,44 @@ class LinearOperatorScaledIdentityTest(
return operator, matrix
@test_util.run_deprecated_v1
def test_assert_positive_definite_does_not_raise_when_positive(self):
with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=1.)
operator.assert_positive_definite().run() # Should not fail
self.evaluate(operator.assert_positive_definite()) # Should not fail
def test_assert_positive_definite_raises_when_negative(self):
with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=-1.)
with self.assertRaisesOpError("not positive definite"):
operator.assert_positive_definite().run()
self.evaluate(operator.assert_positive_definite())
@test_util.run_deprecated_v1
def test_assert_non_singular_does_not_raise_when_non_singular(self):
with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=[1., 2., 3.])
operator.assert_non_singular().run() # Should not fail
self.evaluate(operator.assert_non_singular()) # Should not fail
def test_assert_non_singular_raises_when_singular(self):
with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=[1., 2., 0.])
with self.assertRaisesOpError("was singular"):
operator.assert_non_singular().run()
self.evaluate(operator.assert_non_singular())
@test_util.run_deprecated_v1
def test_assert_self_adjoint_does_not_raise_when_self_adjoint(self):
with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=[1. + 0J])
operator.assert_self_adjoint().run() # Should not fail
self.evaluate(operator.assert_self_adjoint()) # Should not fail
def test_assert_self_adjoint_raises_when_not_self_adjoint(self):
with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=[1. + 1J])
with self.assertRaisesOpError("not self-adjoint"):
operator.assert_self_adjoint().run()
self.evaluate(operator.assert_self_adjoint())
def test_float16_matmul(self):
# float16 cannot be tested by base test class because tf.linalg.solve does
@ -397,17 +391,18 @@ class LinearOperatorScaledIdentityTest(
with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
operator.matmul(x)
@test_util.run_deprecated_v1
def test_wrong_matrix_dimensions_raises_dynamic(self):
num_rows = array_ops.placeholder(dtypes.int32)
x = array_ops.placeholder(dtypes.float32)
num_rows = array_ops.placeholder_with_default(2, shape=None)
x = array_ops.placeholder_with_default(
rng.rand(3, 3).astype(np.float32), shape=None)
with self.cached_session():
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows, multiplier=[1., 2], assert_proper_shapes=True)
y = operator.matmul(x)
with self.assertRaisesOpError("Incompatible.*dimensions"):
y.eval(feed_dict={num_rows: 2, x: rng.rand(3, 3)})
with self.assertRaisesError("Dimensions.*not.compatible"):
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows,
multiplier=[1., 2],
assert_proper_shapes=True)
self.evaluate(operator.matmul(x))
def test_broadcast_matmul_and_solve(self):
# These cannot be done in the automated (base test class) tests since they
@ -530,6 +525,17 @@ class LinearOperatorScaledIdentityTest(
operator.inverse(),
linalg_lib.LinearOperatorScaledIdentity)
def test_ref_type_shape_args_raises(self):
with self.assertRaisesRegexp(TypeError, "num_rows.*reference"):
linalg_lib.LinearOperatorScaledIdentity(
num_rows=variables_module.Variable(2), multiplier=1.23)
def test_tape_safe(self):
multiplier = variables_module.Variable(1.23)
operator = linalg_lib.LinearOperatorScaledIdentity(
num_rows=2, multiplier=multiplier)
self.check_tape_safe(operator)
if __name__ == "__main__":
linear_operator_test_util.add_tests(LinearOperatorIdentityTest)

View File

@ -17,8 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.platform import test
@ -26,6 +28,7 @@ from tensorflow.python.platform import test
linalg = linalg_lib
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorLowerTriangularTest(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
@ -101,6 +104,12 @@ class LinearOperatorLowerTriangularTest(
operator1.to_dense()),
self.evaluate(operator_matmul.to_dense()))
def test_tape_safe(self):
tril = variables_module.Variable([[1., 0.], [0., 1.]])
operator = linalg_lib.LinearOperatorLowerTriangular(
tril, is_non_singular=True)
self.check_tape_safe(operator)
if __name__ == "__main__":
linear_operator_test_util.add_tests(LinearOperatorLowerTriangularTest)

View File

@ -20,9 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
@ -34,66 +32,62 @@ rng = np.random.RandomState(0)
class AssertZeroImagPartTest(test.TestCase):
@test_util.run_deprecated_v1
def test_real_tensor_doesnt_raise(self):
x = ops.convert_to_tensor([0., 2, 3])
with self.cached_session():
# Should not raise.
linear_operator_util.assert_zero_imag_part(x, message="ABC123").run()
# Should not raise.
self.evaluate(
linear_operator_util.assert_zero_imag_part(x, message="ABC123"))
@test_util.run_deprecated_v1
def test_complex_tensor_with_imag_zero_doesnt_raise(self):
x = ops.convert_to_tensor([1., 0, 3])
y = ops.convert_to_tensor([0., 0, 0])
z = math_ops.complex(x, y)
with self.cached_session():
# Should not raise.
linear_operator_util.assert_zero_imag_part(z, message="ABC123").run()
# Should not raise.
self.evaluate(
linear_operator_util.assert_zero_imag_part(z, message="ABC123"))
def test_complex_tensor_with_nonzero_imag_raises(self):
x = ops.convert_to_tensor([1., 2, 0])
y = ops.convert_to_tensor([1., 2, 0])
z = math_ops.complex(x, y)
with self.cached_session():
with self.assertRaisesOpError("ABC123"):
linear_operator_util.assert_zero_imag_part(z, message="ABC123").run()
with self.assertRaisesOpError("ABC123"):
self.evaluate(
linear_operator_util.assert_zero_imag_part(z, message="ABC123"))
class AssertNoEntriesWithModulusZeroTest(test.TestCase):
@test_util.run_deprecated_v1
def test_nonzero_real_tensor_doesnt_raise(self):
x = ops.convert_to_tensor([1., 2, 3])
with self.cached_session():
# Should not raise.
linear_operator_util.assert_no_entries_with_modulus_zero(
x, message="ABC123").run()
# Should not raise.
self.evaluate(
linear_operator_util.assert_no_entries_with_modulus_zero(
x, message="ABC123"))
@test_util.run_deprecated_v1
def test_nonzero_complex_tensor_doesnt_raise(self):
x = ops.convert_to_tensor([1., 0, 3])
y = ops.convert_to_tensor([1., 2, 0])
z = math_ops.complex(x, y)
with self.cached_session():
# Should not raise.
linear_operator_util.assert_no_entries_with_modulus_zero(
z, message="ABC123").run()
# Should not raise.
self.evaluate(
linear_operator_util.assert_no_entries_with_modulus_zero(
z, message="ABC123"))
def test_zero_real_tensor_raises(self):
x = ops.convert_to_tensor([1., 0, 3])
with self.cached_session():
with self.assertRaisesOpError("ABC123"):
linear_operator_util.assert_no_entries_with_modulus_zero(
x, message="ABC123").run()
with self.assertRaisesOpError("ABC123"):
self.evaluate(
linear_operator_util.assert_no_entries_with_modulus_zero(
x, message="ABC123"))
def test_zero_complex_tensor_raises(self):
x = ops.convert_to_tensor([1., 2, 0])
y = ops.convert_to_tensor([1., 2, 0])
z = math_ops.complex(x, y)
with self.cached_session():
with self.assertRaisesOpError("ABC123"):
linear_operator_util.assert_no_entries_with_modulus_zero(
z, message="ABC123").run()
with self.assertRaisesOpError("ABC123"):
self.evaluate(
linear_operator_util.assert_no_entries_with_modulus_zero(
z, message="ABC123"))
class BroadcastMatrixBatchDimsTest(test.TestCase):
@ -107,10 +101,8 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
tensor, = linear_operator_util.broadcast_matrix_batch_dims([arr])
self.assertTrue(isinstance(tensor, ops.Tensor))
with self.cached_session():
self.assertAllClose(arr, self.evaluate(tensor))
self.assertAllClose(arr, self.evaluate(tensor))
@test_util.run_deprecated_v1
def test_static_dims_broadcast(self):
# x.batch_shape = [3, 1, 2]
# y.batch_shape = [4, 1]
@ -123,12 +115,11 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])
with self.cached_session() as sess:
self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
self.assertAllClose(x_bc_expected, x_bc_)
self.assertAllClose(y_bc_expected, y_bc_)
self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
self.assertAllClose(x_bc_expected, x_bc_)
self.assertAllClose(y_bc_expected, y_bc_)
def test_static_dims_broadcast_second_arg_higher_rank(self):
# x.batch_shape = [1, 2]
@ -142,14 +133,12 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x, y])
with self.cached_session() as sess:
self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
self.assertAllClose(x_bc_expected, x_bc_)
self.assertAllClose(y_bc_expected, y_bc_)
self.assertAllEqual(x_bc_expected.shape, x_bc.get_shape())
self.assertAllEqual(y_bc_expected.shape, y_bc.get_shape())
x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
self.assertAllClose(x_bc_expected, x_bc_)
self.assertAllClose(y_bc_expected, y_bc_)
@test_util.run_deprecated_v1
def test_dynamic_dims_broadcast_32bit(self):
# x.batch_shape = [3, 1, 2]
# y.batch_shape = [4, 1]
@ -160,17 +149,15 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
x_bc_expected = x + batch_of_zeros
y_bc_expected = y + batch_of_zeros
x_ph = array_ops.placeholder(dtypes.float32)
y_ph = array_ops.placeholder(dtypes.float32)
x_ph = array_ops.placeholder_with_default(x, shape=None)
y_ph = array_ops.placeholder_with_default(y, shape=None)
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph])
with self.cached_session() as sess:
x_bc_, y_bc_ = sess.run([x_bc, y_bc], feed_dict={x_ph: x, y_ph: y})
self.assertAllClose(x_bc_expected, x_bc_)
self.assertAllClose(y_bc_expected, y_bc_)
x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
self.assertAllClose(x_bc_expected, x_bc_)
self.assertAllClose(y_bc_expected, y_bc_)
@test_util.run_deprecated_v1
def test_dynamic_dims_broadcast_32bit_second_arg_higher_rank(self):
# x.batch_shape = [1, 2]
# y.batch_shape = [3, 4, 1]
@ -181,15 +168,14 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
x_bc_expected = x + batch_of_zeros
y_bc_expected = y + batch_of_zeros
x_ph = array_ops.placeholder(dtypes.float32)
y_ph = array_ops.placeholder(dtypes.float32)
x_ph = array_ops.placeholder_with_default(x, shape=None)
y_ph = array_ops.placeholder_with_default(y, shape=None)
x_bc, y_bc = linear_operator_util.broadcast_matrix_batch_dims([x_ph, y_ph])
with self.cached_session() as sess:
x_bc_, y_bc_ = sess.run([x_bc, y_bc], feed_dict={x_ph: x, y_ph: y})
self.assertAllClose(x_bc_expected, x_bc_)
self.assertAllClose(y_bc_expected, y_bc_)
x_bc_, y_bc_ = self.evaluate([x_bc, y_bc])
self.assertAllClose(x_bc_expected, x_bc_)
self.assertAllClose(y_bc_expected, y_bc_)
def test_less_than_two_dims_raises_static(self):
x = rng.rand(3)
@ -204,20 +190,17 @@ class BroadcastMatrixBatchDimsTest(test.TestCase):
class CholeskySolveWithBroadcastTest(test.TestCase):
@test_util.run_deprecated_v1
def test_static_dims_broadcast(self):
# batch_shape = [2]
chol = rng.rand(3, 3)
rhs = rng.rand(2, 3, 7)
chol_broadcast = chol + np.zeros((2, 1, 1))
with self.cached_session():
result = linear_operator_util.cholesky_solve_with_broadcast(chol, rhs)
self.assertAllEqual((2, 3, 7), result.get_shape())
expected = linalg_ops.cholesky_solve(chol_broadcast, rhs)
self.assertAllClose(expected.eval(), self.evaluate(result))
result = linear_operator_util.cholesky_solve_with_broadcast(chol, rhs)
self.assertAllEqual((2, 3, 7), result.get_shape())
expected = linalg_ops.cholesky_solve(chol_broadcast, rhs)
self.assertAllClose(*self.evaluate([expected, result]))
@test_util.run_deprecated_v1
def test_dynamic_dims_broadcast_64bit(self):
# batch_shape = [2, 2]
chol = rng.rand(2, 3, 3)
@ -225,40 +208,29 @@ class CholeskySolveWithBroadcastTest(test.TestCase):
chol_broadcast = chol + np.zeros((2, 2, 1, 1))
rhs_broadcast = rhs + np.zeros((2, 2, 1, 1))
chol_ph = array_ops.placeholder(dtypes.float64)
rhs_ph = array_ops.placeholder(dtypes.float64)
chol_ph = array_ops.placeholder_with_default(chol, shape=None)
rhs_ph = array_ops.placeholder_with_default(rhs, shape=None)
with self.cached_session() as sess:
result, expected = sess.run(
[
linear_operator_util.cholesky_solve_with_broadcast(
chol_ph, rhs_ph),
linalg_ops.cholesky_solve(chol_broadcast, rhs_broadcast)
],
feed_dict={
chol_ph: chol,
rhs_ph: rhs,
})
self.assertAllClose(expected, result)
result, expected = self.evaluate([
linear_operator_util.cholesky_solve_with_broadcast(chol_ph, rhs_ph),
linalg_ops.cholesky_solve(chol_broadcast, rhs_broadcast)
])
self.assertAllClose(expected, result)
class MatrixSolveWithBroadcastTest(test.TestCase):
@test_util.run_deprecated_v1
def test_static_dims_broadcast_matrix_has_extra_dims(self):
# batch_shape = [2]
matrix = rng.rand(2, 3, 3)
rhs = rng.rand(3, 7)
rhs_broadcast = rhs + np.zeros((2, 1, 1))
with self.cached_session():
result = linear_operator_util.matrix_solve_with_broadcast(
matrix, rhs)
self.assertAllEqual((2, 3, 7), result.get_shape())
expected = linalg_ops.matrix_solve(matrix, rhs_broadcast)
self.assertAllClose(expected.eval(), self.evaluate(result))
result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
self.assertAllEqual((2, 3, 7), result.get_shape())
expected = linalg_ops.matrix_solve(matrix, rhs_broadcast)
self.assertAllClose(*self.evaluate([expected, result]))
@test_util.run_deprecated_v1
def test_static_dims_broadcast_rhs_has_extra_dims(self):
# Since the second arg has extra dims, and the domain dim of the first arg
# is larger than the number of linear equations, code will "flip" the extra
@ -271,13 +243,11 @@ class MatrixSolveWithBroadcastTest(test.TestCase):
rhs = rng.rand(2, 3, 2)
matrix_broadcast = matrix + np.zeros((2, 1, 1))
with self.cached_session():
result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
self.assertAllEqual((2, 3, 2), result.get_shape())
expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
self.assertAllClose(expected.eval(), self.evaluate(result))
result = linear_operator_util.matrix_solve_with_broadcast(matrix, rhs)
self.assertAllEqual((2, 3, 2), result.get_shape())
expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
self.assertAllClose(*self.evaluate([expected, result]))
@test_util.run_deprecated_v1
def test_static_dims_broadcast_rhs_has_extra_dims_dynamic(self):
# Since the second arg has extra dims, and the domain dim of the first arg
# is larger than the number of linear equations, code will "flip" the extra
@ -290,22 +260,14 @@ class MatrixSolveWithBroadcastTest(test.TestCase):
rhs = rng.rand(2, 3, 2)
matrix_broadcast = matrix + np.zeros((2, 1, 1))
matrix_ph = array_ops.placeholder(dtypes.float64, shape=[None, None])
rhs_ph = array_ops.placeholder(dtypes.float64, shape=[None, None, None])
matrix_ph = array_ops.placeholder_with_default(matrix, shape=[None, None])
rhs_ph = array_ops.placeholder_with_default(rhs, shape=[None, None, None])
with self.cached_session():
result = linear_operator_util.matrix_solve_with_broadcast(matrix_ph,
rhs_ph)
self.assertAllEqual(3, result.shape.ndims)
expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
self.assertAllClose(
self.evaluate(expected),
result.eval(feed_dict={
matrix_ph: matrix,
rhs_ph: rhs
}))
result = linear_operator_util.matrix_solve_with_broadcast(matrix_ph, rhs_ph)
self.assertAllEqual(3, result.shape.ndims)
expected = linalg_ops.matrix_solve(matrix_broadcast, rhs)
self.assertAllClose(*self.evaluate([expected, result]))
@test_util.run_deprecated_v1
def test_static_dims_broadcast_rhs_has_extra_dims_and_adjoint(self):
# Since the second arg has extra dims, and the domain dim of the first arg
# is larger than the number of linear equations, code will "flip" the extra
@ -318,14 +280,12 @@ class MatrixSolveWithBroadcastTest(test.TestCase):
rhs = rng.rand(2, 3, 2)
matrix_broadcast = matrix + np.zeros((2, 1, 1))
with self.cached_session():
result = linear_operator_util.matrix_solve_with_broadcast(
matrix, rhs, adjoint=True)
self.assertAllEqual((2, 3, 2), result.get_shape())
expected = linalg_ops.matrix_solve(matrix_broadcast, rhs, adjoint=True)
self.assertAllClose(expected.eval(), self.evaluate(result))
result = linear_operator_util.matrix_solve_with_broadcast(
matrix, rhs, adjoint=True)
self.assertAllEqual((2, 3, 2), result.get_shape())
expected = linalg_ops.matrix_solve(matrix_broadcast, rhs, adjoint=True)
self.assertAllClose(*self.evaluate([expected, result]))
@test_util.run_deprecated_v1
def test_dynamic_dims_broadcast_64bit(self):
# batch_shape = [2, 2]
matrix = rng.rand(2, 3, 3)
@ -333,40 +293,30 @@ class MatrixSolveWithBroadcastTest(test.TestCase):
matrix_broadcast = matrix + np.zeros((2, 2, 1, 1))
rhs_broadcast = rhs + np.zeros((2, 2, 1, 1))
matrix_ph = array_ops.placeholder(dtypes.float64)
rhs_ph = array_ops.placeholder(dtypes.float64)
matrix_ph = array_ops.placeholder_with_default(matrix, shape=None)
rhs_ph = array_ops.placeholder_with_default(rhs, shape=None)
with self.cached_session() as sess:
result, expected = sess.run(
[
linear_operator_util.matrix_solve_with_broadcast(
matrix_ph, rhs_ph),
linalg_ops.matrix_solve(matrix_broadcast, rhs_broadcast)
],
feed_dict={
matrix_ph: matrix,
rhs_ph: rhs,
})
self.assertAllClose(expected, result)
result, expected = self.evaluate([
linear_operator_util.matrix_solve_with_broadcast(matrix_ph, rhs_ph),
linalg_ops.matrix_solve(matrix_broadcast, rhs_broadcast)
])
self.assertAllClose(expected, result)
class MatrixTriangularSolveWithBroadcastTest(test.TestCase):
@test_util.run_deprecated_v1
def test_static_dims_broadcast_matrix_has_extra_dims(self):
# batch_shape = [2]
matrix = rng.rand(2, 3, 3)
rhs = rng.rand(3, 7)
rhs_broadcast = rhs + np.zeros((2, 1, 1))
with self.cached_session():
result = linear_operator_util.matrix_triangular_solve_with_broadcast(
matrix, rhs)
self.assertAllEqual((2, 3, 7), result.get_shape())
expected = linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
self.assertAllClose(expected.eval(), self.evaluate(result))
result = linear_operator_util.matrix_triangular_solve_with_broadcast(
matrix, rhs)
self.assertAllEqual((2, 3, 7), result.get_shape())
expected = linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
self.assertAllClose(*self.evaluate([expected, result]))
@test_util.run_deprecated_v1
def test_static_dims_broadcast_rhs_has_extra_dims(self):
# Since the second arg has extra dims, and the domain dim of the first arg
# is larger than the number of linear equations, code will "flip" the extra
@ -379,14 +329,12 @@ class MatrixTriangularSolveWithBroadcastTest(test.TestCase):
rhs = rng.rand(2, 3, 2)
matrix_broadcast = matrix + np.zeros((2, 1, 1))
with self.cached_session():
result = linear_operator_util.matrix_triangular_solve_with_broadcast(
matrix, rhs)
self.assertAllEqual((2, 3, 2), result.get_shape())
expected = linalg_ops.matrix_triangular_solve(matrix_broadcast, rhs)
self.assertAllClose(expected.eval(), self.evaluate(result))
result = linear_operator_util.matrix_triangular_solve_with_broadcast(
matrix, rhs)
self.assertAllEqual((2, 3, 2), result.get_shape())
expected = linalg_ops.matrix_triangular_solve(matrix_broadcast, rhs)
self.assertAllClose(*self.evaluate([expected, result]))
@test_util.run_deprecated_v1
def test_static_dims_broadcast_rhs_has_extra_dims_and_adjoint(self):
# Since the second arg has extra dims, and the domain dim of the first arg
# is larger than the number of linear equations, code will "flip" the extra
@ -399,36 +347,28 @@ class MatrixTriangularSolveWithBroadcastTest(test.TestCase):
rhs = rng.rand(2, 3, 2)
matrix_broadcast = matrix + np.zeros((2, 1, 1))
with self.cached_session():
result = linear_operator_util.matrix_triangular_solve_with_broadcast(
matrix, rhs, adjoint=True)
self.assertAllEqual((2, 3, 2), result.get_shape())
expected = linalg_ops.matrix_triangular_solve(
matrix_broadcast, rhs, adjoint=True)
self.assertAllClose(expected.eval(), self.evaluate(result))
result = linear_operator_util.matrix_triangular_solve_with_broadcast(
matrix, rhs, adjoint=True)
self.assertAllEqual((2, 3, 2), result.get_shape())
expected = linalg_ops.matrix_triangular_solve(
matrix_broadcast, rhs, adjoint=True)
self.assertAllClose(*self.evaluate([expected, result]))
@test_util.run_deprecated_v1
def test_dynamic_dims_broadcast_64bit(self):
# batch_shape = [2]
matrix = rng.rand(2, 3, 3)
rhs = rng.rand(3, 7)
rhs_broadcast = rhs + np.zeros((2, 1, 1))
matrix_ph = array_ops.placeholder(dtypes.float64)
rhs_ph = array_ops.placeholder(dtypes.float64)
matrix_ph = array_ops.placeholder_with_default(matrix, shape=None)
rhs_ph = array_ops.placeholder_with_default(rhs, shape=None)
with self.cached_session() as sess:
result, expected = sess.run(
[
linear_operator_util.matrix_triangular_solve_with_broadcast(
matrix_ph, rhs_ph),
linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
],
feed_dict={
matrix_ph: matrix,
rhs_ph: rhs,
})
self.assertAllClose(expected, result)
result, expected = self.evaluate([
linear_operator_util.matrix_triangular_solve_with_broadcast(
matrix_ph, rhs_ph),
linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
])
self.assertAllClose(expected, result)
class DomainDimensionStubOperator(object):
@ -442,22 +382,21 @@ class DomainDimensionStubOperator(object):
class AssertCompatibleMatrixDimensionsTest(test.TestCase):
@test_util.run_deprecated_v1
def test_compatible_dimensions_do_not_raise(self):
with self.cached_session():
x = ops.convert_to_tensor(rng.rand(2, 3, 4))
operator = DomainDimensionStubOperator(3)
# Should not raise
linear_operator_util.assert_compatible_matrix_dimensions(
operator, x).run() # pyformat: disable
x = ops.convert_to_tensor(rng.rand(2, 3, 4))
operator = DomainDimensionStubOperator(3)
# Should not raise
self.evaluate(
linear_operator_util.assert_compatible_matrix_dimensions(operator, x))
def test_incompatible_dimensions_raise(self):
with self.cached_session():
x = ops.convert_to_tensor(rng.rand(2, 4, 4))
operator = DomainDimensionStubOperator(3)
with self.assertRaisesOpError("Incompatible matrix dimensions"):
linear_operator_util.assert_compatible_matrix_dimensions(
operator, x).run() # pyformat: disable
x = ops.convert_to_tensor(rng.rand(2, 4, 4))
operator = DomainDimensionStubOperator(3)
# pylint: disable=g-error-prone-assert-raises
with self.assertRaisesOpError("Dimensions are not compatible"):
self.evaluate(
linear_operator_util.assert_compatible_matrix_dimensions(operator, x))
# pylint: enable=g-error-prone-assert-raises
class DummyOperatorWithHint(object):

View File

@ -22,6 +22,7 @@ import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_test_util
from tensorflow.python.platform import test
@ -30,6 +31,7 @@ from tensorflow.python.platform import test
rng = np.random.RandomState(2016)
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorZerosTest(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
@ -75,11 +77,10 @@ class LinearOperatorZerosTest(
operator = linalg_lib.LinearOperatorZeros(num_rows=2)
operator.assert_non_singular()
@test_util.run_deprecated_v1
def test_assert_self_adjoint(self):
with self.cached_session():
operator = linalg_lib.LinearOperatorZeros(num_rows=2)
operator.assert_self_adjoint().run() # Should not fail
self.evaluate(operator.assert_self_adjoint()) # Should not fail
def test_non_scalar_num_rows_raises_static(self):
with self.assertRaisesRegexp(ValueError, "must be a 0-D Tensor"):
@ -111,46 +112,37 @@ class LinearOperatorZerosTest(
with self.assertRaisesRegexp(ValueError, "must be non-negative"):
linalg_lib.LinearOperatorZeros(num_rows=2, batch_shape=[-2])
@test_util.run_deprecated_v1
def test_non_scalar_num_rows_raises_dynamic(self):
with self.cached_session():
num_rows = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorZeros(
num_rows, assert_proper_shapes=True)
with self.assertRaisesOpError("must be a 0-D Tensor"):
operator.to_dense().eval(feed_dict={num_rows: [2]})
num_rows = array_ops.placeholder_with_default([2], shape=None)
with self.assertRaisesError("must be a 0-D Tensor"):
operator = linalg_lib.LinearOperatorZeros(
num_rows, assert_proper_shapes=True)
self.evaluate(operator.to_dense())
@test_util.run_deprecated_v1
def test_negative_num_rows_raises_dynamic(self):
with self.cached_session():
n = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorZeros(
num_rows=n, assert_proper_shapes=True)
with self.assertRaisesOpError("must be non-negative"):
operator.to_dense().eval(feed_dict={n: -2})
n = array_ops.placeholder_with_default(-2, shape=None)
with self.assertRaisesError("must be non-negative"):
operator = linalg_lib.LinearOperatorZeros(
num_rows=n, assert_proper_shapes=True)
self.evaluate(operator.to_dense())
operator = linalg_lib.LinearOperatorZeros(
num_rows=2, num_columns=n, assert_proper_shapes=True)
with self.assertRaisesOpError("must be non-negative"):
operator.to_dense().eval(feed_dict={n: -2})
@test_util.run_deprecated_v1
def test_non_1d_batch_shape_raises_dynamic(self):
with self.cached_session():
batch_shape = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorZeros(
num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
with self.assertRaisesOpError("must be a 1-D"):
operator.to_dense().eval(feed_dict={batch_shape: 2})
batch_shape = array_ops.placeholder_with_default(2, shape=None)
with self.assertRaisesError("must be a 1-D"):
operator = linalg_lib.LinearOperatorZeros(
num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
self.evaluate(operator.to_dense())
@test_util.run_deprecated_v1
def test_negative_batch_shape_raises_dynamic(self):
with self.cached_session():
batch_shape = array_ops.placeholder(dtypes.int32)
operator = linalg_lib.LinearOperatorZeros(
num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
with self.assertRaisesOpError("must be non-negative"):
operator.to_dense().eval(feed_dict={batch_shape: [-2]})
batch_shape = array_ops.placeholder_with_default([-2], shape=None)
with self.assertRaisesError("must be non-negative"):
operator = linalg_lib.LinearOperatorZeros(
num_rows=2, batch_shape=batch_shape, assert_proper_shapes=True)
self.evaluate(operator.to_dense())
def test_wrong_matrix_dimensions_raises_static(self):
operator = linalg_lib.LinearOperatorZeros(num_rows=2)
@ -158,17 +150,15 @@ class LinearOperatorZerosTest(
with self.assertRaisesRegexp(ValueError, "Dimensions.*not compatible"):
operator.matmul(x)
@test_util.run_deprecated_v1
def test_wrong_matrix_dimensions_raises_dynamic(self):
num_rows = array_ops.placeholder(dtypes.int32)
x = array_ops.placeholder(dtypes.float32)
num_rows = array_ops.placeholder_with_default(2, shape=None)
x = array_ops.placeholder_with_default(rng.rand(3, 3), shape=None)
with self.cached_session():
operator = linalg_lib.LinearOperatorZeros(
num_rows, assert_proper_shapes=True)
y = operator.matmul(x)
with self.assertRaisesOpError("Incompatible.*dimensions"):
y.eval(feed_dict={num_rows: 2, x: rng.rand(3, 3)})
with self.assertRaisesError("Dimensions.*not.compatible"):
operator = linalg_lib.LinearOperatorZeros(
num_rows, assert_proper_shapes=True, dtype=dtypes.float64)
self.evaluate(operator.matmul(x))
def test_is_x_flags(self):
# The is_x flags are by default all True.
@ -188,7 +178,20 @@ class LinearOperatorZerosTest(
operator2.matmul(operator1),
linalg_lib.LinearOperatorZeros))
def test_ref_type_shape_args_raises(self):
with self.assertRaisesRegexp(TypeError, "num_rows.cannot.be.reference"):
linalg_lib.LinearOperatorZeros(num_rows=variables_module.Variable(2))
with self.assertRaisesRegexp(TypeError, "num_columns.cannot.be.reference"):
linalg_lib.LinearOperatorZeros(
num_rows=2, num_columns=variables_module.Variable(3))
with self.assertRaisesRegexp(TypeError, "batch_shape.cannot.be.reference"):
linalg_lib.LinearOperatorZeros(
num_rows=2, batch_shape=variables_module.Variable([2]))
@test_util.run_all_in_graph_and_eager_modes
class LinearOperatorZerosNotSquareTest(
linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):

View File

@ -25,6 +25,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linalg_impl as linalg
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_util
from tensorflow.python.util.tf_export import tf_export
__all__ = ["LinearOperatorHouseholder",]
@ -123,7 +124,7 @@ class LinearOperatorHouseholder(linear_operator.LinearOperator):
"""
with ops.name_scope(name, values=[reflection_axis]):
self._reflection_axis = ops.convert_to_tensor(
self._reflection_axis = linear_operator_util.convert_nonref_to_tensor(
reflection_axis, name="reflection_axis")
self._check_reflection_axis(self._reflection_axis)
@ -194,9 +195,10 @@ class LinearOperatorHouseholder(linear_operator.LinearOperator):
# Note that because this is a reflection, it lies in O(n) (for real vector
# spaces) or U(n) (for complex vector spaces), and thus is its own adjoint.
reflection_axis = ops.convert_to_tensor(self.reflection_axis)
x = linalg.adjoint(x) if adjoint_arg else x
normalized_axis = self.reflection_axis / linalg.norm(
self.reflection_axis, axis=-1, keepdims=True)
normalized_axis = reflection_axis / linalg.norm(
reflection_axis, axis=-1, keepdims=True)
mat = normalized_axis[..., array_ops.newaxis]
x_dot_normalized_v = math_ops.matmul(mat, x, adjoint_a=True)
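The hunk above is the heart of the change: keep the possibly-variable reflection_axis as stored state, and re-run convert_to_tensor inside _matmul so the variable read happens under any active tape. A hedged standalone sketch of the same idea (the reflect helper is hypothetical, not code from this file):

    import tensorflow as tf

    v = tf.Variable([1., 3., 5., 8.])

    def reflect(x):
      # Re-read the variable at call time, as _matmul now does, so the
      # read is recorded by GradientTape.
      axis = tf.convert_to_tensor(v)
      axis = axis / tf.linalg.norm(axis, axis=-1, keepdims=True)
      # Householder reflection: x - 2 * axis * (axis^T x).
      return x - 2. * axis * tf.reduce_sum(axis * x, axis=-1, keepdims=True)

    with tf.GradientTape() as tape:
      y = reflect(tf.ones([4]))
    print(tape.gradient(y, v))  # Not None.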

View File

@ -250,6 +250,7 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
negative.
ValueError: If any of the following is not `True`:
`{is_self_adjoint, is_non_singular, is_positive_definite}`.
TypeError: If `num_rows` or `batch_shape` is ref-type (e.g. Variable).
"""
dtype = dtype or dtypes.float32
self._assert_proper_shapes = assert_proper_shapes
@ -273,6 +274,9 @@ class LinearOperatorIdentity(BaseLinearOperatorIdentity):
is_square=is_square,
name=name)
linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
linear_operator_util.assert_not_ref_type(batch_shape, "batch_shape")
self._num_rows = linear_operator_util.shape_tensor(
num_rows, name="num_rows")
self._num_rows_static = tensor_util.constant_value(self._num_rows)
@ -589,7 +593,8 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
self._assert_proper_shapes = assert_proper_shapes
with ops.name_scope(name, values=[multiplier, num_rows]):
self._multiplier = ops.convert_to_tensor(multiplier, name="multiplier")
self._multiplier = linear_operator_util.convert_nonref_to_tensor(
multiplier, name="multiplier")
# Check and auto-set hints.
if not self._multiplier.dtype.is_complex:
@ -601,20 +606,16 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
if not is_square:
raise ValueError("A ScaledIdentity operator is always square.")
linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
super(LinearOperatorScaledIdentity, self).__init__(
dtype=self._multiplier.dtype,
dtype=self._multiplier.dtype.base_dtype,
is_non_singular=is_non_singular,
is_self_adjoint=is_self_adjoint,
is_positive_definite=is_positive_definite,
is_square=is_square,
name=name)
# Shape [B1,...Bb, 1, 1]
self._multiplier_matrix = array_ops.expand_dims(
array_ops.expand_dims(self.multiplier, -1), -1)
self._multiplier_matrix_conj = math_ops.conj(self._multiplier_matrix)
self._abs_multiplier = math_ops.abs(self.multiplier)
self._num_rows = linear_operator_util.shape_tensor(
num_rows, name="num_rows")
self._num_rows_static = tensor_util.constant_value(self._num_rows)
@ -652,34 +653,34 @@ class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity):
imag_multiplier,
message="LinearOperator was not self-adjoint")
def _make_multiplier_matrix(self, conjugate=False):
# Shape [B1,...Bb, 1, 1]
multiplier_matrix = array_ops.expand_dims(
array_ops.expand_dims(self.multiplier, -1), -1)
if conjugate:
multiplier_matrix = math_ops.conj(multiplier_matrix)
return multiplier_matrix
def _matmul(self, x, adjoint=False, adjoint_arg=False):
x = linalg.adjoint(x) if adjoint_arg else x
if adjoint:
matrix = self._multiplier_matrix_conj
else:
matrix = self._multiplier_matrix
if self._assert_proper_shapes:
aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x)
x = control_flow_ops.with_dependencies([aps], x)
return x * matrix
return x * self._make_multiplier_matrix(conjugate=adjoint)
def _determinant(self):
return self.multiplier**self._num_rows_cast_to_dtype
def _log_abs_determinant(self):
return self._num_rows_cast_to_real_dtype * math_ops.log(
self._abs_multiplier)
math_ops.abs(self.multiplier))
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
if adjoint:
matrix = self._multiplier_matrix_conj
else:
matrix = self._multiplier_matrix
if self._assert_proper_shapes:
aps = linear_operator_util.assert_compatible_matrix_dimensions(self, rhs)
rhs = control_flow_ops.with_dependencies([aps], rhs)
return rhs / matrix
return rhs / self._make_multiplier_matrix(conjugate=adjoint)
def _trace(self):
# Get Tensor of all ones of same shape as self.batch_shape.
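The removed __init__ caching and the new _make_multiplier_matrix show the same deferral idea: the [..., 1, 1]-shaped multiplier matrix is now rebuilt on each matmul/solve instead of being frozen at construction, so a Variable multiplier stays differentiable. A small standalone sketch (hypothetical helper with the same shape logic):

    import tensorflow as tf

    multiplier = tf.Variable(1.23)

    def make_multiplier_matrix(conjugate=False):
      # Shape [B1,...,Bb, 1, 1], rebuilt per call rather than cached in
      # __init__, so each use reads the variable under the active tape.
      m = tf.expand_dims(
          tf.expand_dims(tf.convert_to_tensor(multiplier), -1), -1)
      return tf.math.conj(m) if conjugate else m

    x = tf.ones([2, 3])
    with tf.GradientTape() as tape:
      y = x * make_multiplier_matrix()
    print(tape.gradient(y, multiplier))  # 6.0 (one per element of x), not None.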

View File

@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@ -145,10 +144,9 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
is_square = True
with ops.name_scope(name, values=[tril]):
self._tril = ops.convert_to_tensor(tril, name="tril")
self._tril = linear_operator_util.convert_nonref_to_tensor(tril,
name="tril")
self._check_tril(self._tril)
self._tril = array_ops.matrix_band_part(tril, -1, 0)
self._diag = array_ops.matrix_diag_part(self._tril)
super(LinearOperatorLowerTriangular, self).__init__(
dtype=self._tril.dtype,
@ -161,24 +159,20 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
def _check_tril(self, tril):
"""Static check of the `tril` argument."""
allowed_dtypes = [
dtypes.float16,
dtypes.float32,
dtypes.float64,
dtypes.complex64,
dtypes.complex128,
]
dtype = tril.dtype
if dtype not in allowed_dtypes:
raise TypeError(
"Argument tril must have dtype in %s. Found: %s"
% (allowed_dtypes, dtype))
if tril.get_shape().ndims is not None and tril.get_shape().ndims < 2:
raise ValueError(
"Argument tril must have at least 2 dimensions. Found: %s"
% tril)
def _get_tril(self):
"""Gets the `tril` kwarg, with upper part zero-d out."""
return array_ops.matrix_band_part(self._tril, -1, 0)
def _get_diag(self):
"""Gets the diagonal part of `tril` kwarg."""
return array_ops.matrix_diag_part(self._tril)
def _shape(self):
return self._tril.get_shape()
@ -187,27 +181,24 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
def _assert_non_singular(self):
return linear_operator_util.assert_no_entries_with_modulus_zero(
self._diag,
self._get_diag(),
message="Singular operator: Diagonal contained zero values.")
def _matmul(self, x, adjoint=False, adjoint_arg=False):
return math_ops.matmul(
self._tril, x, adjoint_a=adjoint, adjoint_b=adjoint_arg)
self._get_tril(), x, adjoint_a=adjoint, adjoint_b=adjoint_arg)
def _determinant(self):
return math_ops.reduce_prod(self._diag, axis=[-1])
return math_ops.reduce_prod(self._get_diag(), axis=[-1])
def _log_abs_determinant(self):
return math_ops.reduce_sum(
math_ops.log(math_ops.abs(self._diag)), axis=[-1])
math_ops.log(math_ops.abs(self._get_diag())), axis=[-1])
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
return linear_operator_util.matrix_triangular_solve_with_broadcast(
self._tril, rhs, lower=True, adjoint=adjoint)
self._get_tril(), rhs, lower=True, adjoint=adjoint)
def _to_dense(self):
return self._tril
def _add_to_tensor(self, x):
return self._tril + x
return self._get_tril()
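_get_tril() and _get_diag() apply matrix_band_part/matrix_diag_part lazily for the same reason: the band part of a Variable computed once in __init__ would be a constant snapshot invisible to any later tape. Roughly (standalone sketch with hypothetical helpers):

    import tensorflow as tf

    tril = tf.Variable([[1., 0.], [2., 3.]])

    def get_tril():
      # Zero out the upper triangle on every use, mirroring _get_tril().
      return tf.linalg.band_part(tril, -1, 0)

    with tf.GradientTape() as tape:
      log_abs_det = tf.reduce_sum(
          tf.math.log(tf.abs(tf.linalg.diag_part(get_tril()))), axis=-1)
    print(tape.gradient(log_abs_det, tril))  # Nonzero on the diagonal.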

View File

@ -24,6 +24,7 @@ import numpy as np
import six
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
@ -51,6 +52,15 @@ class OperatorShapesInfo(object):
self.__dict__.update(kwargs)
class CheckTapeSafeSkipOptions(object):
# Skip checking this particular method.
DETERMINANT = "determinant"
DIAG_PART = "diag_part"
LOG_ABS_DETERMINANT = "log_abs_determinant"
TRACE = "trace"
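These options exist because some operators implement a method as a hard-coded constant (e.g. LinearOperatorHouseholder's determinant and trace, skipped in the test above), and the gradient of a constant with respect to any variable is None by design, which would otherwise fail check_tape_safe. A minimal illustration:

    import tensorflow as tf

    v = tf.Variable(2.)
    with tf.GradientTape() as tape:
      det = tf.constant(1.)  # e.g. a determinant hard-coded as a constant
    # None: the output does not depend on v, so this check must be skipped.
    print(tape.gradient(det, v))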
@six.add_metaclass(abc.ABCMeta) # pylint: disable=no-init
class LinearOperatorDerivedClassTest(test.TestCase):
"""Tests for derived classes.
@ -174,18 +184,35 @@ class LinearOperatorDerivedClassTest(test.TestCase):
# To skip "test_foo", add "foo" to this list.
return []
def check_tape_safe(self, operator):
"""Check gradients are not None w.r.t. Variables.
def assertRaisesError(self, msg):
"""assertRaisesRegexp or OpError, depending on context.executing_eagerly."""
if context.executing_eagerly():
return self.assertRaisesRegexp(Exception, msg)
return self.assertRaisesOpError(msg)
def check_tape_safe(self, operator, skip_options=None):
"""Check gradients are not None w.r.t. operator.variables.
Meant to be called from the derived class.
This ensures grads are not None w.r.t. every variable in operator.variables.
If more fine-grained testing is needed, a custom test should be written.
Args:
operator: LinearOperator. Exact checks done will depend on hints.
skip_options: Optional list of CheckTapeSafeSkipOptions.
Makes this test skip particular checks.
"""
skip_options = skip_options or []
if not operator.variables:
raise AssertionError("`operator.variables` was empty")
def _assert_not_none(iterable):
for item in iterable:
self.assertIsNotNone(item)
# Tape tests that can be run on every operator below.
with backprop.GradientTape() as tape:
_assert_not_none(tape.gradient(operator.to_dense(), operator.variables))
@ -193,23 +220,30 @@ class LinearOperatorDerivedClassTest(test.TestCase):
_assert_not_none(
tape.gradient(operator.adjoint().to_dense(), operator.variables))
x = array_ops.ones(shape=operator.H.shape_tensor()[:-1])
x = math_ops.cast(
array_ops.ones(shape=operator.H.shape_tensor()[:-1]), operator.dtype)
with backprop.GradientTape() as tape:
_assert_not_none(tape.gradient(operator.matvec(x), operator.variables))
# Tests for square, but possibly non-singular operators below.
if not operator.is_square:
return
with backprop.GradientTape() as tape:
_assert_not_none(
tape.gradient(operator.determinant(), operator.variables))
for option in [
CheckTapeSafeSkipOptions.DETERMINANT,
CheckTapeSafeSkipOptions.LOG_ABS_DETERMINANT,
CheckTapeSafeSkipOptions.DIAG_PART,
CheckTapeSafeSkipOptions.TRACE,
]:
with backprop.GradientTape() as tape:
if option not in skip_options:
_assert_not_none(
tape.gradient(getattr(operator, option)(), operator.variables))
with backprop.GradientTape() as tape:
_assert_not_none(tape.gradient(operator.diag_part(), operator.variables))
with backprop.GradientTape() as tape:
_assert_not_none(tape.gradient(operator.trace(), operator.variables))
# Tests for non-singular operators below.
if operator.is_non_singular is False: # pylint: disable=g-bool-id-comparison
return
with backprop.GradientTape() as tape:
_assert_not_none(
@ -218,6 +252,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
with backprop.GradientTape() as tape:
_assert_not_none(tape.gradient(operator.solvevec(x), operator.variables))
# Tests for SPD operators below.
if not (operator.is_self_adjoint and operator.is_positive_definite):
return
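Put together, check_tape_safe boils down to a loop like the following, shown as a hedged standalone sketch against the public API (assuming a build in which LinearOperatorFullMatrix is already tape safe, so operator.variables contains the backing Variable):

    import tensorflow as tf

    matrix = tf.Variable([[2., 1.], [1., 2.]])
    operator = tf.linalg.LinearOperatorFullMatrix(
        matrix, is_self_adjoint=True, is_positive_definite=True)

    for method in ("to_dense", "determinant", "log_abs_determinant",
                   "diag_part", "trace"):
      with tf.GradientTape() as tape:
        value = getattr(operator, method)()
      grads = tape.gradient(value, operator.variables)
      # Every gradient should be present for a tape-safe operator.
      assert all(g is not None for g in grads), method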

View File

@ -157,6 +157,12 @@ def is_ref(x):
hasattr(x, "shape")))
def assert_not_ref_type(x, arg_name):
if is_ref(x):
raise TypeError(
"Argument %s cannot be reference type. Found: %s" % (arg_name, type(x)))
################################################################################
# Asserts.
################################################################################
@ -223,7 +229,9 @@ def assert_compatible_matrix_dimensions(operator, x):
assert_same_dd = check_ops.assert_equal(
array_ops.shape(x)[-2],
operator.domain_dimension_tensor(),
message=("Incompatible matrix dimensions. "
# This error message made to look similar to error raised by static check
# in the base class.
message=("Dimensions are not compatible. "
"shape[-2] of argument to be the same as this operator"))
return assert_same_dd
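Both utility changes concern construction-time checks: assert_not_ref_type rejects Variables in shape-defining arguments (their value could change after the static shape was inferred), and the dynamic dimension-check message now matches the static one so tests can use a single regexp in either mode. Expected behavior, as a hedged sketch via the public API:

    import tensorflow as tf

    # Shape-defining args must not be reference types:
    try:
      tf.linalg.LinearOperatorIdentity(num_rows=tf.Variable(2))
    except TypeError as e:
      print(e)  # "Argument num_rows cannot be reference type. Found: ..."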

View File

@ -196,6 +196,10 @@ class LinearOperatorZeros(linear_operator.LinearOperator):
is_square=is_square,
name=name)
linear_operator_util.assert_not_ref_type(num_rows, "num_rows")
linear_operator_util.assert_not_ref_type(num_columns, "num_columns")
linear_operator_util.assert_not_ref_type(batch_shape, "batch_shape")
self._num_rows = linear_operator_util.shape_tensor(
num_rows, name="num_rows")
self._num_rows_static = tensor_util.constant_value(self._num_rows)