Allow blockwise linear operators (tf.linalg.LinearOperatorBlockDiag and tf.linalg.LinearOperatorBlockLowerTriangular) to operate on/emit lists of tensors corresponding to blocks.
PiperOrigin-RevId: 300212672 Change-Id: I2801decb6ac1c63f319cf70c85a2b2e67a55c998
This commit is contained in:
parent
6be0f1fe0b
commit
d34ee8ee57
@ -81,7 +81,7 @@ cuda_py_test(
|
||||
name = "linear_operator_block_diag_test",
|
||||
size = "medium",
|
||||
srcs = ["linear_operator_block_diag_test.py"],
|
||||
shard_count = 6,
|
||||
shard_count = 8,
|
||||
tags = [
|
||||
"noasan",
|
||||
"optonly",
|
||||
@ -103,7 +103,7 @@ cuda_py_test(
|
||||
name = "linear_operator_block_lower_triangular_test",
|
||||
size = "medium",
|
||||
srcs = ["linear_operator_block_lower_triangular_test.py"],
|
||||
shard_count = 6,
|
||||
shard_count = 8,
|
||||
tags = [
|
||||
"noasan",
|
||||
"optonly",
|
||||
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -82,6 +83,10 @@ class SquareLinearOperatorBlockDiagTest(
|
||||
shape_info((2, 1, 5, 5), blocks=[(2, 1, 2, 2), (1, 3, 3)]),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def use_blockwise_arg():
|
||||
return True
|
||||
|
||||
def operator_and_matrix(
|
||||
self, shape_info, dtype, use_placeholder,
|
||||
ensure_self_adjoint_and_pd=False):
|
||||
@ -275,6 +280,20 @@ class SquareLinearOperatorBlockDiagTest(
|
||||
with self.assertRaisesRegexp(ValueError, "non-empty"):
|
||||
block_diag.LinearOperatorBlockDiag([])
|
||||
|
||||
def test_incompatible_input_blocks_raises(self):
|
||||
matrix_1 = array_ops.placeholder_with_default(rng.rand(4, 4), shape=None)
|
||||
matrix_2 = array_ops.placeholder_with_default(rng.rand(3, 3), shape=None)
|
||||
operators = [
|
||||
linalg.LinearOperatorFullMatrix(matrix_1, is_square=True),
|
||||
linalg.LinearOperatorFullMatrix(matrix_2, is_square=True)
|
||||
]
|
||||
operator = block_diag.LinearOperatorBlockDiag(operators)
|
||||
x = np.random.rand(2, 4, 5).tolist()
|
||||
msg = ("dimension does not match" if context.executing_eagerly()
|
||||
else "input structure is ambiguous")
|
||||
with self.assertRaisesRegexp(ValueError, msg):
|
||||
operator.matmul(x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
linear_operator_test_util.add_tests(SquareLinearOperatorBlockDiagTest)
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.eager import backprop
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import test_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
@ -68,6 +69,10 @@ class SquareLinearOperatorBlockLowerTriangularTest(
|
||||
self._rtol[dtypes.complex64] = 1e-5
|
||||
super(SquareLinearOperatorBlockLowerTriangularTest, self).setUp()
|
||||
|
||||
@staticmethod
|
||||
def use_blockwise_arg():
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def skip_these_tests():
|
||||
# Skipping since `LinearOperatorBlockLowerTriangular` is in general not
|
||||
@ -267,6 +272,23 @@ class SquareLinearOperatorBlockLowerTriangularTest(
|
||||
with self.assertRaisesRegexp(ValueError, "must be equal"):
|
||||
block_lower_triangular.LinearOperatorBlockLowerTriangular(operators)
|
||||
|
||||
def test_incompatible_input_blocks_raises(self):
|
||||
matrix_1 = array_ops.placeholder_with_default(rng.rand(4, 4), shape=None)
|
||||
matrix_2 = array_ops.placeholder_with_default(rng.rand(3, 4), shape=None)
|
||||
matrix_3 = array_ops.placeholder_with_default(rng.rand(3, 3), shape=None)
|
||||
operators = [
|
||||
[linalg.LinearOperatorFullMatrix(matrix_1, is_square=True)],
|
||||
[linalg.LinearOperatorFullMatrix(matrix_2),
|
||||
linalg.LinearOperatorFullMatrix(matrix_3, is_square=True)]
|
||||
]
|
||||
operator = block_lower_triangular.LinearOperatorBlockLowerTriangular(
|
||||
operators)
|
||||
x = np.random.rand(2, 4, 5).tolist()
|
||||
msg = ("dimension does not match" if context.executing_eagerly()
|
||||
else "input structure is ambiguous")
|
||||
with self.assertRaisesRegexp(ValueError, msg):
|
||||
operator.matmul(x)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
linear_operator_test_util.add_tests(
|
||||
|
@ -21,6 +21,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
@ -344,5 +345,88 @@ class UseOperatorOrProvidedHintUnlessContradictingTest(test.TestCase,
|
||||
message="my error message")
|
||||
|
||||
|
||||
class BlockwiseTest(test.TestCase, parameterized.TestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("split_dim_1", [3, 3, 4], -1),
|
||||
("split_dim_2", [2, 5], -2),
|
||||
)
|
||||
def test_blockwise_input(self, op_dimension_values, split_dim):
|
||||
|
||||
op_dimensions = [
|
||||
tensor_shape.Dimension(v) for v in op_dimension_values]
|
||||
unknown_op_dimensions = [
|
||||
tensor_shape.Dimension(None) for _ in op_dimension_values]
|
||||
|
||||
batch_shape = [2, 1]
|
||||
arg_dim = 5
|
||||
if split_dim == -1:
|
||||
blockwise_arrays = [np.zeros(batch_shape + [arg_dim, d])
|
||||
for d in op_dimension_values]
|
||||
else:
|
||||
blockwise_arrays = [np.zeros(batch_shape + [d, arg_dim])
|
||||
for d in op_dimension_values]
|
||||
|
||||
blockwise_list = [block.tolist() for block in blockwise_arrays]
|
||||
blockwise_tensors = [ops.convert_to_tensor(block)
|
||||
for block in blockwise_arrays]
|
||||
blockwise_placeholders = [
|
||||
array_ops.placeholder_with_default(block, shape=None)
|
||||
for block in blockwise_arrays]
|
||||
|
||||
# Iterables of non-nested structures are always interpreted as blockwise.
|
||||
# The list of lists is interpreted as blockwise as well, regardless of
|
||||
# whether the operator dimensions are known, since the sizes of its elements
|
||||
# along `split_dim` are non-identical.
|
||||
for op_dims in [op_dimensions, unknown_op_dimensions]:
|
||||
for blockwise_inputs in [
|
||||
blockwise_arrays, blockwise_list,
|
||||
blockwise_tensors, blockwise_placeholders]:
|
||||
self.assertTrue(linear_operator_util.arg_is_blockwise(
|
||||
op_dims, blockwise_inputs, split_dim))
|
||||
|
||||
def test_non_blockwise_input(self):
|
||||
x = np.zeros((2, 3, 4, 6))
|
||||
x_tensor = ops.convert_to_tensor(x)
|
||||
x_placeholder = array_ops.placeholder_with_default(x, shape=None)
|
||||
x_list = x.tolist()
|
||||
|
||||
# For known and matching operator dimensions, interpret all as non-blockwise
|
||||
op_dimension_values = [2, 1, 3]
|
||||
op_dimensions = [tensor_shape.Dimension(d) for d in op_dimension_values]
|
||||
for inputs in [x, x_tensor, x_placeholder, x_list]:
|
||||
self.assertFalse(linear_operator_util.arg_is_blockwise(
|
||||
op_dimensions, inputs, -1))
|
||||
|
||||
# The input is still interpreted as non-blockwise for unknown operator
|
||||
# dimensions (`x_list` has an outermost dimension that does not matcn the
|
||||
# number of blocks, and the other inputs are not iterables).
|
||||
unknown_op_dimensions = [
|
||||
tensor_shape.Dimension(None) for _ in op_dimension_values]
|
||||
for inputs in [x, x_tensor, x_placeholder, x_list]:
|
||||
self.assertFalse(linear_operator_util.arg_is_blockwise(
|
||||
unknown_op_dimensions, inputs, -1))
|
||||
|
||||
def test_ambiguous_input_raises(self):
|
||||
x = np.zeros((3, 4, 2)).tolist()
|
||||
op_dimensions = [tensor_shape.Dimension(None) for _ in range(3)]
|
||||
|
||||
# Since the leftmost dimension of `x` is equal to the number of blocks, and
|
||||
# the operators have unknown dimension, the input is ambiguous.
|
||||
with self.assertRaisesRegexp(ValueError, "structure is ambiguous"):
|
||||
linear_operator_util.arg_is_blockwise(op_dimensions, x, -2)
|
||||
|
||||
def test_mismatched_input_raises(self):
|
||||
x = np.zeros((2, 3, 4, 6)).tolist()
|
||||
op_dimension_values = [4, 3]
|
||||
op_dimensions = [tensor_shape.Dimension(v) for v in op_dimension_values]
|
||||
|
||||
# The dimensions of the two operator-blocks sum to 7. `x` is a
|
||||
# two-element list; if interpreted blockwise, its corresponding dimensions
|
||||
# sum to 12 (=6*2). If not interpreted blockwise, its corresponding
|
||||
# dimension is 6. This is a mismatch.
|
||||
with self.assertRaisesRegexp(ValueError, "dimension does not match"):
|
||||
linear_operator_util.arg_is_blockwise(op_dimensions, x, -1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -664,7 +664,7 @@ class LinearOperator(module.Module):
|
||||
"""Transform [batch] vector `x` with left multiplication: `x --> Ax`.
|
||||
|
||||
```python
|
||||
# Make an operator acting like batch matric A. Assume A.shape = [..., M, N]
|
||||
# Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
|
||||
operator = LinearOperator(...)
|
||||
|
||||
X = ... # shape [..., N], batch vector
|
||||
|
@ -26,6 +26,7 @@ from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import check_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops.linalg import linear_operator
|
||||
from tensorflow.python.ops.linalg import linear_operator_algebra
|
||||
from tensorflow.python.ops.linalg import linear_operator_util
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -46,7 +47,6 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
|
||||
the [batch] square matrix formed by having each matrix `Aj` on the main
|
||||
diagonal.
|
||||
|
||||
|
||||
Each `opj` is required to represent a square matrix, and hence will have
|
||||
shape `batch_shape_j + [M_j, M_j]`.
|
||||
|
||||
@ -58,6 +58,12 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
|
||||
methods may fail due to lack of broadcasting ability in the defining
|
||||
operators' methods.
|
||||
|
||||
Arguments to `matmul`, `matvec`, `solve`, and `solvevec` may either be single
|
||||
`Tensor`s or lists of `Tensor`s that are interpreted as blocks. The `j`th
|
||||
element of a blockwise list of `Tensor`s must have dimensions that match
|
||||
`opj` for the given method. If a list of blocks is input, then a list of
|
||||
blocks is returned as well.
|
||||
|
||||
```python
|
||||
# Create a 4 x 4 linear operator combined of two 2 x 2 operators.
|
||||
operator_1 = LinearOperatorFullMatrix([[1., 2.], [3., 4.]])
|
||||
@ -97,6 +103,11 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
|
||||
x = tf.random.normal(shape=[2, 3, 9])
|
||||
operator_99.matmul(x)
|
||||
==> Shape [2, 3, 9] Tensor
|
||||
|
||||
# Create a blockwise list of vectors.
|
||||
x = [tf.random.normal(shape=[2, 3, 4]), tf.random.normal(shape=[2, 3, 5])]
|
||||
operator_99.matmul(x)
|
||||
==> [Shape [2, 3, 4] Tensor, Shape [2, 3, 5] Tensor]
|
||||
```
|
||||
|
||||
#### Performance
|
||||
@ -160,6 +171,10 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
|
||||
"Expected a non-empty list of operators. Found: %s" % operators)
|
||||
self._operators = operators
|
||||
|
||||
# Define diagonal operators, for functions that are shared across blockwise
|
||||
# `LinearOperator` types.
|
||||
self._diagonal_operators = operators
|
||||
|
||||
# Validate dtype.
|
||||
dtype = operators[0].dtype
|
||||
for operator in operators:
|
||||
@ -218,14 +233,22 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
|
||||
def operators(self):
|
||||
return self._operators
|
||||
|
||||
def _block_range_dimensions(self):
|
||||
return [op.range_dimension for op in self._diagonal_operators]
|
||||
|
||||
def _block_domain_dimensions(self):
|
||||
return [op.domain_dimension for op in self._diagonal_operators]
|
||||
|
||||
def _block_range_dimension_tensors(self):
|
||||
return [op.range_dimension_tensor() for op in self._diagonal_operators]
|
||||
|
||||
def _block_domain_dimension_tensors(self):
|
||||
return [op.domain_dimension_tensor() for op in self._diagonal_operators]
|
||||
|
||||
def _shape(self):
|
||||
# Get final matrix shape.
|
||||
domain_dimension = self.operators[0].domain_dimension
|
||||
range_dimension = self.operators[0].range_dimension
|
||||
for operator in self.operators[1:]:
|
||||
domain_dimension += operator.domain_dimension
|
||||
range_dimension += operator.range_dimension
|
||||
|
||||
domain_dimension = sum(self._block_domain_dimensions())
|
||||
range_dimension = sum(self._block_range_dimensions())
|
||||
matrix_shape = tensor_shape.TensorShape([domain_dimension, range_dimension])
|
||||
|
||||
# Get broadcast batch shape.
|
||||
@ -243,12 +266,8 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
|
||||
return ops.convert_to_tensor(
|
||||
self.shape.as_list(), dtype=dtypes.int32, name="shape")
|
||||
|
||||
domain_dimension = self.operators[0].domain_dimension_tensor()
|
||||
range_dimension = self.operators[0].range_dimension_tensor()
|
||||
for operator in self.operators[1:]:
|
||||
domain_dimension += operator.domain_dimension_tensor()
|
||||
range_dimension += operator.range_dimension_tensor()
|
||||
|
||||
domain_dimension = sum(self._block_domain_dimension_tensors())
|
||||
range_dimension = sum(self._block_range_dimension_tensors())
|
||||
matrix_shape = array_ops.stack([domain_dimension, range_dimension])
|
||||
|
||||
# Dummy Tensor of zeros. Will never be materialized.
|
||||
@ -259,19 +278,149 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
|
||||
|
||||
return array_ops.concat((batch_shape, matrix_shape), 0)
|
||||
|
||||
def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
|
||||
"""Transform [batch] matrix `x` with left multiplication: `x --> Ax`.
|
||||
|
||||
```python
|
||||
# Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
|
||||
operator = LinearOperator(...)
|
||||
operator.shape = [..., M, N]
|
||||
|
||||
X = ... # shape [..., N, R], batch matrix, R > 0.
|
||||
|
||||
Y = operator.matmul(X)
|
||||
Y.shape
|
||||
==> [..., M, R]
|
||||
|
||||
Y[..., :, r] = sum_j A[..., :, j] X[j, r]
|
||||
```
|
||||
|
||||
Args:
|
||||
x: `LinearOperator`, `Tensor` with compatible shape and same `dtype` as
|
||||
`self`, or a blockwise iterable of `LinearOperator`s or `Tensor`s. See
|
||||
class docstring for definition of shape compatibility.
|
||||
adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`.
|
||||
adjoint_arg: Python `bool`. If `True`, compute `A x^H` where `x^H` is
|
||||
the hermitian transpose (transposition and complex conjugation).
|
||||
name: A name for this `Op`.
|
||||
|
||||
Returns:
|
||||
A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype`
|
||||
as `self`, or if `x` is blockwise, a list of `Tensor`s with shapes that
|
||||
concatenate to `[..., M, R]`.
|
||||
"""
|
||||
if isinstance(x, linear_operator.LinearOperator):
|
||||
left_operator = self.adjoint() if adjoint else self
|
||||
right_operator = x.adjoint() if adjoint_arg else x
|
||||
|
||||
if (right_operator.range_dimension is not None and
|
||||
left_operator.domain_dimension is not None and
|
||||
right_operator.range_dimension != left_operator.domain_dimension):
|
||||
raise ValueError(
|
||||
"Operators are incompatible. Expected `x` to have dimension"
|
||||
" {} but got {}.".format(
|
||||
left_operator.domain_dimension, right_operator.range_dimension))
|
||||
with self._name_scope(name):
|
||||
return linear_operator_algebra.matmul(left_operator, right_operator)
|
||||
|
||||
with self._name_scope(name):
|
||||
arg_dim = -1 if adjoint_arg else -2
|
||||
block_dimensions = (self._block_range_dimensions() if adjoint
|
||||
else self._block_domain_dimensions())
|
||||
if linear_operator_util.arg_is_blockwise(block_dimensions, x, arg_dim):
|
||||
for i, block in enumerate(x):
|
||||
if not isinstance(block, linear_operator.LinearOperator):
|
||||
block = ops.convert_to_tensor(block)
|
||||
self._check_input_dtype(block)
|
||||
block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
|
||||
x[i] = block
|
||||
else:
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
self._check_input_dtype(x)
|
||||
op_dimension = (self.range_dimension if adjoint
|
||||
else self.domain_dimension)
|
||||
op_dimension.assert_is_compatible_with(x.shape[arg_dim])
|
||||
return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
|
||||
|
||||
def _matmul(self, x, adjoint=False, adjoint_arg=False):
|
||||
split_dim = -1 if adjoint_arg else -2
|
||||
# Split input by rows normally, and otherwise columns.
|
||||
split_x = self._split_input_into_blocks(x, axis=split_dim)
|
||||
arg_dim = -1 if adjoint_arg else -2
|
||||
block_dimensions = (self._block_range_dimensions() if adjoint
|
||||
else self._block_domain_dimensions())
|
||||
blockwise_arg = linear_operator_util.arg_is_blockwise(
|
||||
block_dimensions, x, arg_dim)
|
||||
if blockwise_arg:
|
||||
split_x = x
|
||||
else:
|
||||
split_dim = -1 if adjoint_arg else -2
|
||||
# Split input by rows normally, and otherwise columns.
|
||||
split_x = linear_operator_util.split_arg_into_blocks(
|
||||
self._block_domain_dimensions(),
|
||||
self._block_domain_dimension_tensors,
|
||||
x, axis=split_dim)
|
||||
|
||||
result_list = []
|
||||
for index, operator in enumerate(self.operators):
|
||||
result_list += [operator.matmul(
|
||||
split_x[index], adjoint=adjoint, adjoint_arg=adjoint_arg)]
|
||||
|
||||
if blockwise_arg:
|
||||
return result_list
|
||||
|
||||
result_list = linear_operator_util.broadcast_matrix_batch_dims(
|
||||
result_list)
|
||||
return array_ops.concat(result_list, axis=-2)
|
||||
|
||||
def matvec(self, x, adjoint=False, name="matvec"):
|
||||
"""Transform [batch] vector `x` with left multiplication: `x --> Ax`.
|
||||
|
||||
```python
|
||||
# Make an operator acting like batch matric A. Assume A.shape = [..., M, N]
|
||||
operator = LinearOperator(...)
|
||||
|
||||
X = ... # shape [..., N], batch vector
|
||||
|
||||
Y = operator.matvec(X)
|
||||
Y.shape
|
||||
==> [..., M]
|
||||
|
||||
Y[..., :] = sum_j A[..., :, j] X[..., j]
|
||||
```
|
||||
|
||||
Args:
|
||||
x: `Tensor` with compatible shape and same `dtype` as `self`, or an
|
||||
iterable of `Tensor`s (for blockwise operators). `Tensor`s are treated
|
||||
a [batch] vectors, meaning for every set of leading dimensions, the last
|
||||
dimension defines a vector.
|
||||
See class docstring for definition of compatibility.
|
||||
adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`.
|
||||
name: A name for this `Op`.
|
||||
|
||||
Returns:
|
||||
A `Tensor` with shape `[..., M]` and same `dtype` as `self`.
|
||||
"""
|
||||
with self._name_scope(name):
|
||||
block_dimensions = (self._block_range_dimensions() if adjoint
|
||||
else self._block_domain_dimensions())
|
||||
if linear_operator_util.arg_is_blockwise(block_dimensions, x, -1):
|
||||
for i, block in enumerate(x):
|
||||
if not isinstance(block, linear_operator.LinearOperator):
|
||||
block = ops.convert_to_tensor(block)
|
||||
self._check_input_dtype(block)
|
||||
block_dimensions[i].assert_is_compatible_with(block.shape[-1])
|
||||
x[i] = block
|
||||
x_mat = [block[..., array_ops.newaxis] for block in x]
|
||||
y_mat = self.matmul(x_mat, adjoint=adjoint)
|
||||
return [array_ops.squeeze(y, axis=-1) for y in y_mat]
|
||||
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
self._check_input_dtype(x)
|
||||
op_dimension = (self.range_dimension if adjoint
|
||||
else self.domain_dimension)
|
||||
op_dimension.assert_is_compatible_with(x.shape[-1])
|
||||
x_mat = x[..., array_ops.newaxis]
|
||||
y_mat = self.matmul(x_mat, adjoint=adjoint)
|
||||
return array_ops.squeeze(y_mat, axis=-1)
|
||||
|
||||
def _determinant(self):
|
||||
result = self.operators[0].determinant()
|
||||
for operator in self.operators[1:]:
|
||||
@ -284,19 +433,172 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
|
||||
result += operator.log_abs_determinant()
|
||||
return result
|
||||
|
||||
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
|
||||
split_dim = -1 if adjoint_arg else -2
|
||||
# Split input by rows normally, and otherwise columns.
|
||||
split_rhs = self._split_input_into_blocks(rhs, axis=split_dim)
|
||||
def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
|
||||
"""Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.
|
||||
|
||||
solution_list = []
|
||||
for index, operator in enumerate(self.operators):
|
||||
solution_list += [operator.solve(
|
||||
split_rhs[index], adjoint=adjoint, adjoint_arg=adjoint_arg)]
|
||||
The returned `Tensor` will be close to an exact solution if `A` is well
|
||||
conditioned. Otherwise closeness will vary. See class docstring for details.
|
||||
|
||||
solution_list = linear_operator_util.broadcast_matrix_batch_dims(
|
||||
solution_list)
|
||||
return array_ops.concat(solution_list, axis=-2)
|
||||
Examples:
|
||||
|
||||
```python
|
||||
# Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
|
||||
operator = LinearOperator(...)
|
||||
operator.shape = [..., M, N]
|
||||
|
||||
# Solve R > 0 linear systems for every member of the batch.
|
||||
RHS = ... # shape [..., M, R]
|
||||
|
||||
X = operator.solve(RHS)
|
||||
# X[..., :, r] is the solution to the r'th linear system
|
||||
# sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]
|
||||
|
||||
operator.matmul(X)
|
||||
==> RHS
|
||||
```
|
||||
|
||||
Args:
|
||||
rhs: `Tensor` with same `dtype` as this operator and compatible shape,
|
||||
or a list of `Tensor`s (for blockwise operators). `Tensor`s are treated
|
||||
like a [batch] matrices meaning for every set of leading dimensions, the
|
||||
last two dimensions defines a matrix.
|
||||
See class docstring for definition of compatibility.
|
||||
adjoint: Python `bool`. If `True`, solve the system involving the adjoint
|
||||
of this `LinearOperator`: `A^H X = rhs`.
|
||||
adjoint_arg: Python `bool`. If `True`, solve `A X = rhs^H` where `rhs^H`
|
||||
is the hermitian transpose (transposition and complex conjugation).
|
||||
name: A name scope to use for ops added by this method.
|
||||
|
||||
Returns:
|
||||
`Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If `self.is_non_singular` or `is_square` is False.
|
||||
"""
|
||||
if self.is_non_singular is False:
|
||||
raise NotImplementedError(
|
||||
"Exact solve not implemented for an operator that is expected to "
|
||||
"be singular.")
|
||||
if self.is_square is False:
|
||||
raise NotImplementedError(
|
||||
"Exact solve not implemented for an operator that is expected to "
|
||||
"not be square.")
|
||||
if isinstance(rhs, linear_operator.LinearOperator):
|
||||
left_operator = self.adjoint() if adjoint else self
|
||||
right_operator = rhs.adjoint() if adjoint_arg else rhs
|
||||
|
||||
if (right_operator.range_dimension is not None and
|
||||
left_operator.domain_dimension is not None and
|
||||
right_operator.range_dimension != left_operator.domain_dimension):
|
||||
raise ValueError(
|
||||
"Operators are incompatible. Expected `rhs` to have dimension"
|
||||
" {} but got {}.".format(
|
||||
left_operator.domain_dimension, right_operator.range_dimension))
|
||||
with self._name_scope(name):
|
||||
return linear_operator_algebra.solve(left_operator, right_operator)
|
||||
|
||||
with self._name_scope(name):
|
||||
block_dimensions = (self._block_domain_dimensions() if adjoint
|
||||
else self._block_range_dimensions())
|
||||
arg_dim = -1 if adjoint_arg else -2
|
||||
blockwise_arg = linear_operator_util.arg_is_blockwise(
|
||||
block_dimensions, rhs, arg_dim)
|
||||
|
||||
if blockwise_arg:
|
||||
split_rhs = rhs
|
||||
for i, block in enumerate(split_rhs):
|
||||
if not isinstance(block, linear_operator.LinearOperator):
|
||||
block = ops.convert_to_tensor(block)
|
||||
self._check_input_dtype(block)
|
||||
block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
|
||||
split_rhs[i] = block
|
||||
else:
|
||||
rhs = ops.convert_to_tensor(rhs, name="rhs")
|
||||
self._check_input_dtype(rhs)
|
||||
op_dimension = (self.domain_dimension if adjoint
|
||||
else self.range_dimension)
|
||||
op_dimension.assert_is_compatible_with(rhs.shape[arg_dim])
|
||||
split_dim = -1 if adjoint_arg else -2
|
||||
# Split input by rows normally, and otherwise columns.
|
||||
split_rhs = linear_operator_util.split_arg_into_blocks(
|
||||
self._block_domain_dimensions(),
|
||||
self._block_domain_dimension_tensors,
|
||||
rhs, axis=split_dim)
|
||||
|
||||
solution_list = []
|
||||
for index, operator in enumerate(self.operators):
|
||||
solution_list += [operator.solve(
|
||||
split_rhs[index], adjoint=adjoint, adjoint_arg=adjoint_arg)]
|
||||
|
||||
if blockwise_arg:
|
||||
return solution_list
|
||||
|
||||
solution_list = linear_operator_util.broadcast_matrix_batch_dims(
|
||||
solution_list)
|
||||
return array_ops.concat(solution_list, axis=-2)
|
||||
|
||||
def solvevec(self, rhs, adjoint=False, name="solve"):
|
||||
"""Solve single equation with best effort: `A X = rhs`.
|
||||
|
||||
The returned `Tensor` will be close to an exact solution if `A` is well
|
||||
conditioned. Otherwise closeness will vary. See class docstring for details.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
# Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
|
||||
operator = LinearOperator(...)
|
||||
operator.shape = [..., M, N]
|
||||
|
||||
# Solve one linear system for every member of the batch.
|
||||
RHS = ... # shape [..., M]
|
||||
|
||||
X = operator.solvevec(RHS)
|
||||
# X is the solution to the linear system
|
||||
# sum_j A[..., :, j] X[..., j] = RHS[..., :]
|
||||
|
||||
operator.matvec(X)
|
||||
==> RHS
|
||||
```
|
||||
|
||||
Args:
|
||||
rhs: `Tensor` with same `dtype` as this operator, or list of `Tensor`s
|
||||
(for blockwise operators). `Tensor`s are treated as [batch] vectors,
|
||||
meaning for every set of leading dimensions, the last dimension defines
|
||||
a vector. See class docstring for definition of compatibility regarding
|
||||
batch dimensions.
|
||||
adjoint: Python `bool`. If `True`, solve the system involving the adjoint
|
||||
of this `LinearOperator`: `A^H X = rhs`.
|
||||
name: A name scope to use for ops added by this method.
|
||||
|
||||
Returns:
|
||||
`Tensor` with shape `[...,N]` and same `dtype` as `rhs`.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If `self.is_non_singular` or `is_square` is False.
|
||||
"""
|
||||
with self._name_scope(name):
|
||||
block_dimensions = (self._block_domain_dimensions() if adjoint
|
||||
else self._block_range_dimensions())
|
||||
if linear_operator_util.arg_is_blockwise(block_dimensions, rhs, -1):
|
||||
for i, block in enumerate(rhs):
|
||||
if not isinstance(block, linear_operator.LinearOperator):
|
||||
block = ops.convert_to_tensor(block)
|
||||
self._check_input_dtype(block)
|
||||
block_dimensions[i].assert_is_compatible_with(block.shape[-1])
|
||||
rhs[i] = block
|
||||
rhs_mat = [array_ops.expand_dims(block, axis=-1) for block in rhs]
|
||||
solution_mat = self.solve(rhs_mat, adjoint=adjoint)
|
||||
return [array_ops.squeeze(x, axis=-1) for x in solution_mat]
|
||||
|
||||
rhs = ops.convert_to_tensor(rhs, name="rhs")
|
||||
self._check_input_dtype(rhs)
|
||||
op_dimension = (self.domain_dimension if adjoint
|
||||
else self.range_dimension)
|
||||
op_dimension.assert_is_compatible_with(rhs.shape[-1])
|
||||
rhs_mat = array_ops.expand_dims(rhs, axis=-1)
|
||||
solution_mat = self.solve(rhs_mat, adjoint=adjoint)
|
||||
return array_ops.squeeze(solution_mat, axis=-1)
|
||||
|
||||
def _diag_part(self):
|
||||
diag_list = []
|
||||
@ -360,27 +662,3 @@ class LinearOperatorBlockDiag(linear_operator.LinearOperator):
|
||||
eig_list = linear_operator_util.broadcast_matrix_batch_dims(eig_list)
|
||||
eigs = array_ops.concat(eig_list, axis=-2)
|
||||
return array_ops.squeeze(eigs, axis=-1)
|
||||
|
||||
def _split_input_into_blocks(self, x, axis=-1):
|
||||
"""Split `x` into blocks matching `operators`'s `domain_dimension`.
|
||||
|
||||
Specifically, if we have a block diagonal matrix, with block sizes
|
||||
`[M_j, M_j] j = 1..J`, this method splits `x` on `axis` into `J`
|
||||
tensors, whose shape at `axis` is `M_j`.
|
||||
|
||||
Args:
|
||||
x: `Tensor`. `x` is split into `J` tensors.
|
||||
axis: Python `Integer` representing the axis to split `x` on.
|
||||
|
||||
Returns:
|
||||
A list of `Tensor`s.
|
||||
"""
|
||||
block_sizes = []
|
||||
if self.shape.is_fully_defined():
|
||||
for operator in self.operators:
|
||||
block_sizes += [operator.domain_dimension.value]
|
||||
else:
|
||||
for operator in self.operators:
|
||||
block_sizes += [operator.domain_dimension_tensor()]
|
||||
|
||||
return array_ops.split(x, block_sizes, axis=axis)
|
||||
|
@ -28,6 +28,7 @@ from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops.linalg import linalg_impl as linalg
|
||||
from tensorflow.python.ops.linalg import linear_operator
|
||||
from tensorflow.python.ops.linalg import linear_operator_algebra
|
||||
from tensorflow.python.ops.linalg import linear_operator_util
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
@ -147,6 +148,16 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
|
||||
>>> y.shape
|
||||
TensorShape([2, 3, 9])
|
||||
|
||||
Create a blockwise list of vectors and apply the operator to it. A blockwise
|
||||
list is returned.
|
||||
>>> x4 = tf.random.normal(shape=[2, 1, 4])
|
||||
>>> x5 = tf.random.normal(shape=[2, 3, 5])
|
||||
>>> y_blockwise = operator_99.matvec([x4, x5])
|
||||
>>> y_blockwise[0].shape
|
||||
TensorShape([2, 3, 4])
|
||||
>>> y_blockwise[1].shape
|
||||
TensorShape([2, 3, 5])
|
||||
|
||||
#### Performance
|
||||
|
||||
Suppose `operator` is a `LinearOperatorBlockLowerTriangular` consisting of `D`
|
||||
@ -230,6 +241,7 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
|
||||
raise ValueError(
|
||||
"Expected a non-empty list of operators. Found: {}".format(operators))
|
||||
self._operators = operators
|
||||
self._diagonal_operators = [row[-1] for row in operators]
|
||||
|
||||
dtype = operators[0][0].dtype
|
||||
self._validate_dtype(dtype)
|
||||
@ -287,13 +299,13 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
|
||||
|
||||
# pylint: disable=g-bool-id-comparison
|
||||
def _validate_non_singular(self, is_non_singular):
|
||||
if all(row[-1].is_non_singular for row in self.operators):
|
||||
if all(op.is_non_singular for op in self._diagonal_operators):
|
||||
if is_non_singular is False:
|
||||
raise ValueError(
|
||||
"A blockwise lower-triangular operator with non-singular operators "
|
||||
" on the main diagonal is always non-singular.")
|
||||
return True
|
||||
if any(row[-1].is_non_singular is False for row in self.operators):
|
||||
if any(op.is_non_singular is False for op in self._diagonal_operators):
|
||||
if is_non_singular is True:
|
||||
raise ValueError(
|
||||
"A blockwise lower-triangular operator with a singular operator on "
|
||||
@ -303,7 +315,7 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
|
||||
def _validate_square(self, is_square):
|
||||
if is_square is False:
|
||||
raise ValueError("`LinearOperatorBlockLowerTriangular` must be square.")
|
||||
if any(row[-1].is_square is False for row in self.operators):
|
||||
if any(op.is_square is False for op in self._diagonal_operators):
|
||||
raise ValueError(
|
||||
"Matrices on the diagonal (the final elements of each row-partition "
|
||||
"in the `operators` list) must be square.")
|
||||
@ -323,14 +335,22 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
|
||||
def operators(self):
|
||||
return self._operators
|
||||
|
||||
def _block_range_dimensions(self):
|
||||
return [op.range_dimension for op in self._diagonal_operators]
|
||||
|
||||
def _block_domain_dimensions(self):
|
||||
return [op.domain_dimension for op in self._diagonal_operators]
|
||||
|
||||
def _block_range_dimension_tensors(self):
|
||||
return [op.range_dimension_tensor() for op in self._diagonal_operators]
|
||||
|
||||
def _block_domain_dimension_tensors(self):
|
||||
return [op.domain_dimension_tensor() for op in self._diagonal_operators]
|
||||
|
||||
def _shape(self):
|
||||
# Get final matrix shape.
|
||||
domain_dimension = self.operators[0][0].domain_dimension
|
||||
range_dimension = self.operators[0][0].range_dimension
|
||||
for row in self.operators[1:]:
|
||||
domain_dimension += row[-1].domain_dimension
|
||||
range_dimension += row[-1].range_dimension
|
||||
|
||||
domain_dimension = sum(self._block_domain_dimensions())
|
||||
range_dimension = sum(self._block_range_dimensions())
|
||||
matrix_shape = tensor_shape.TensorShape([domain_dimension, range_dimension])
|
||||
|
||||
# Get broadcast batch shape.
|
||||
@ -349,13 +369,8 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
|
||||
return ops.convert_to_tensor(
|
||||
self.shape.as_list(), dtype=dtypes.int32, name="shape")
|
||||
|
||||
domain_dimension = self.operators[0][0].domain_dimension_tensor()
|
||||
range_dimension = self.operators[0][0].range_dimension_tensor()
|
||||
|
||||
for row in self.operators[1:]:
|
||||
domain_dimension += row[-1].domain_dimension_tensor()
|
||||
range_dimension += row[-1].range_dimension_tensor()
|
||||
|
||||
domain_dimension = sum(self._block_domain_dimension_tensors())
|
||||
range_dimension = sum(self._block_range_dimension_tensors())
|
||||
matrix_shape = array_ops.stack([domain_dimension, range_dimension])
|
||||
|
||||
batch_shape = self.operators[0][0].batch_shape_tensor()
|
||||
@ -366,17 +381,92 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
|
||||
|
||||
return array_ops.concat((batch_shape, matrix_shape), 0)
|
||||
|
||||
def matmul(self, x, adjoint=False, adjoint_arg=False, name="matmul"):
|
||||
"""Transform [batch] matrix `x` with left multiplication: `x --> Ax`.
|
||||
|
||||
```python
|
||||
# Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
|
||||
operator = LinearOperator(...)
|
||||
operator.shape = [..., M, N]
|
||||
|
||||
X = ... # shape [..., N, R], batch matrix, R > 0.
|
||||
|
||||
Y = operator.matmul(X)
|
||||
Y.shape
|
||||
==> [..., M, R]
|
||||
|
||||
Y[..., :, r] = sum_j A[..., :, j] X[j, r]
|
||||
```
|
||||
|
||||
Args:
|
||||
x: `LinearOperator`, `Tensor` with compatible shape and same `dtype` as
|
||||
`self`, or a blockwise iterable of `LinearOperator`s or `Tensor`s. See
|
||||
class docstring for definition of shape compatibility.
|
||||
adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`.
|
||||
adjoint_arg: Python `bool`. If `True`, compute `A x^H` where `x^H` is
|
||||
the hermitian transpose (transposition and complex conjugation).
|
||||
name: A name for this `Op`.
|
||||
|
||||
Returns:
|
||||
A `LinearOperator` or `Tensor` with shape `[..., M, R]` and same `dtype`
|
||||
as `self`, or if `x` is blockwise, a list of `Tensor`s with shapes that
|
||||
concatenate to `[..., M, R]`.
|
||||
"""
|
||||
if isinstance(x, linear_operator.LinearOperator):
|
||||
left_operator = self.adjoint() if adjoint else self
|
||||
right_operator = x.adjoint() if adjoint_arg else x
|
||||
|
||||
if (right_operator.range_dimension is not None and
|
||||
left_operator.domain_dimension is not None and
|
||||
right_operator.range_dimension != left_operator.domain_dimension):
|
||||
raise ValueError(
|
||||
"Operators are incompatible. Expected `x` to have dimension"
|
||||
" {} but got {}.".format(
|
||||
left_operator.domain_dimension, right_operator.range_dimension))
|
||||
with self._name_scope(name):
|
||||
return linear_operator_algebra.matmul(left_operator, right_operator)
|
||||
|
||||
with self._name_scope(name):
|
||||
arg_dim = -1 if adjoint_arg else -2
|
||||
block_dimensions = (self._block_range_dimensions() if adjoint
|
||||
else self._block_domain_dimensions())
|
||||
if linear_operator_util.arg_is_blockwise(block_dimensions, x, arg_dim):
|
||||
for i, block in enumerate(x):
|
||||
if not isinstance(block, linear_operator.LinearOperator):
|
||||
block = ops.convert_to_tensor(block)
|
||||
self._check_input_dtype(block)
|
||||
block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
|
||||
x[i] = block
|
||||
else:
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
self._check_input_dtype(x)
|
||||
op_dimension = (self.range_dimension if adjoint
|
||||
else self.domain_dimension)
|
||||
op_dimension.assert_is_compatible_with(x.shape[arg_dim])
|
||||
return self._matmul(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
|
||||
|
||||
def _matmul(self, x, adjoint=False, adjoint_arg=False):
|
||||
split_dim = -1 if adjoint_arg else -2
|
||||
# Split input by columns if adjoint_arg is True, else rows
|
||||
split_x = self._split_input_into_blocks(x, axis=split_dim)
|
||||
arg_dim = -1 if adjoint_arg else -2
|
||||
block_dimensions = (self._block_range_dimensions() if adjoint
|
||||
else self._block_domain_dimensions())
|
||||
blockwise_arg = linear_operator_util.arg_is_blockwise(
|
||||
block_dimensions, x, arg_dim)
|
||||
if blockwise_arg:
|
||||
split_x = x
|
||||
else:
|
||||
split_dim = -1 if adjoint_arg else -2
|
||||
# Split input by columns if adjoint_arg is True, else rows
|
||||
split_x = linear_operator_util.split_arg_into_blocks(
|
||||
self._block_domain_dimensions(),
|
||||
self._block_domain_dimension_tensors,
|
||||
x, axis=split_dim)
|
||||
|
||||
result_list = []
|
||||
# Iterate over row-partitions (i.e. column-partitions of the adjoint).
|
||||
if adjoint:
|
||||
for index in range(len(self.operators)):
|
||||
# Begin with the operator on the diagonal and apply it to the respective
|
||||
# `rhs` block.
|
||||
# Begin with the operator on the diagonal and apply it to the
|
||||
# respective `rhs` block.
|
||||
result = self.operators[index][index].matmul(
|
||||
split_x[index], adjoint=adjoint, adjoint_arg=adjoint_arg)
|
||||
|
||||
@ -400,8 +490,8 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
|
||||
result_list.append(result)
|
||||
else:
|
||||
for row in self.operators:
|
||||
# Begin with the left-most operator in the row-partition and apply it to
|
||||
# the first `rhs` block.
|
||||
# Begin with the left-most operator in the row-partition and apply it
|
||||
# to the first `rhs` block.
|
||||
result = row[0].matmul(
|
||||
split_x[0], adjoint=adjoint, adjoint_arg=adjoint_arg)
|
||||
# Iterate left to right over the operators in the remainder of the row
|
||||
@ -412,117 +502,329 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
|
||||
split_x[j + 1], adjoint=adjoint, adjoint_arg=adjoint_arg)
|
||||
result_list.append(result)
|
||||
|
||||
if blockwise_arg:
|
||||
return result_list
|
||||
|
||||
result_list = linear_operator_util.broadcast_matrix_batch_dims(
|
||||
result_list)
|
||||
return array_ops.concat(result_list, axis=-2)
|
||||
|
||||
def matvec(self, x, adjoint=False, name="matvec"):
|
||||
"""Transform [batch] vector `x` with left multiplication: `x --> Ax`.
|
||||
|
||||
```python
|
||||
# Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
|
||||
operator = LinearOperator(...)
|
||||
|
||||
X = ... # shape [..., N], batch vector
|
||||
|
||||
Y = operator.matvec(X)
|
||||
Y.shape
|
||||
==> [..., M]
|
||||
|
||||
Y[..., :] = sum_j A[..., :, j] X[..., j]
|
||||
```
|
||||
|
||||
Args:
|
||||
x: `Tensor` with compatible shape and same `dtype` as `self`, or an
|
||||
iterable of `Tensor`s. `Tensor`s are treated a [batch] vectors, meaning
|
||||
for every set of leading dimensions, the last dimension defines a
|
||||
vector.
|
||||
See class docstring for definition of compatibility.
|
||||
adjoint: Python `bool`. If `True`, left multiply by the adjoint: `A^H x`.
|
||||
name: A name for this `Op`.
|
||||
|
||||
Returns:
|
||||
A `Tensor` with shape `[..., M]` and same `dtype` as `self`.
|
||||
"""
|
||||
with self._name_scope(name):
|
||||
block_dimensions = (self._block_range_dimensions() if adjoint
|
||||
else self._block_domain_dimensions())
|
||||
if linear_operator_util.arg_is_blockwise(block_dimensions, x, -1):
|
||||
for i, block in enumerate(x):
|
||||
if not isinstance(block, linear_operator.LinearOperator):
|
||||
block = ops.convert_to_tensor(block)
|
||||
self._check_input_dtype(block)
|
||||
block_dimensions[i].assert_is_compatible_with(block.shape[-1])
|
||||
x[i] = block
|
||||
x_mat = [block[..., array_ops.newaxis] for block in x]
|
||||
y_mat = self.matmul(x_mat, adjoint=adjoint)
|
||||
return [array_ops.squeeze(y, axis=-1) for y in y_mat]
|
||||
|
||||
x = ops.convert_to_tensor(x, name="x")
|
||||
self._check_input_dtype(x)
|
||||
op_dimension = (self.range_dimension if adjoint
|
||||
else self.domain_dimension)
|
||||
op_dimension.assert_is_compatible_with(x.shape[-1])
|
||||
x_mat = x[..., array_ops.newaxis]
|
||||
y_mat = self.matmul(x_mat, adjoint=adjoint)
|
||||
return array_ops.squeeze(y_mat, axis=-1)
|
||||
|
||||
def _determinant(self):
|
||||
if all(row[-1].is_positive_definite for row in self.operators):
|
||||
if all(op.is_positive_definite for op in self._diagonal_operators):
|
||||
return math_ops.exp(self._log_abs_determinant())
|
||||
result = self.operators[0][0].determinant()
|
||||
for row in self.operators[1:]:
|
||||
result *= row[-1].determinant()
|
||||
result = self._diagonal_operators[0].determinant()
|
||||
for op in self._diagonal_operators[1:]:
|
||||
result *= op.determinant()
|
||||
return result
|
||||
|
||||
def _log_abs_determinant(self):
|
||||
result = self.operators[0][0].log_abs_determinant()
|
||||
for row in self.operators[1:]:
|
||||
result += row[-1].log_abs_determinant()
|
||||
result = self._diagonal_operators[0].log_abs_determinant()
|
||||
for op in self._diagonal_operators[1:]:
|
||||
result += op.log_abs_determinant()
|
||||
return result
|
||||
|
||||
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
|
||||
# Given the blockwise `n + 1`-by-`n + 1` linear operator:
|
||||
#
|
||||
# op = [[A_00 0 ... 0 ... 0],
|
||||
# [A_10 A_11 ... 0 ... 0],
|
||||
# ...
|
||||
# [A_k0 A_k1 ... A_kk ... 0],
|
||||
# ...
|
||||
# [A_n0 A_n1 ... A_nk ... A_nn]]
|
||||
#
|
||||
# we find `x = op.solve(y)` by observing that
|
||||
#
|
||||
# `y_k = A_k0.matmul(x_0) + A_k1.matmul(x_1) + ... + A_kk.matmul(x_k)`
|
||||
#
|
||||
# and therefore
|
||||
#
|
||||
# `x_k = A_kk.solve(y_k -
|
||||
# A_k0.matmul(x_0) - ... - A_k(k-1).matmul(x_(k-1)))`
|
||||
#
|
||||
# where `x_k` and `y_k` are the `k`th blocks obtained by decomposing `x`
|
||||
# and `y` along their appropriate axes.
|
||||
#
|
||||
# We first solve `x_0 = A_00.solve(y_0)`. Proceeding inductively, we solve
|
||||
# for `x_k`, `k = 1..n`, given `x_0..x_(k-1)`.
|
||||
#
|
||||
# The adjoint case is solved similarly, beginning with
|
||||
# `x_n = A_nn.solve(y_n, adjoint=True)` and proceeding backwards.
|
||||
rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
|
||||
split_rhs = self._split_input_into_blocks(rhs, axis=-2)
|
||||
def solve(self, rhs, adjoint=False, adjoint_arg=False, name="solve"):
|
||||
"""Solve (exact or approx) `R` (batch) systems of equations: `A X = rhs`.
|
||||
|
||||
solution_list = []
|
||||
if adjoint:
|
||||
# For an adjoint blockwise lower-triangular linear operator, the system
|
||||
# must be solved bottom to top. Iterate backwards over rows of the adjoint
|
||||
# (i.e. columns of the non-adjoint operator).
|
||||
for index in reversed(range(len(self.operators))):
|
||||
y = split_rhs[index]
|
||||
# Iterate top to bottom over the operators in the off-diagonal portion
|
||||
# of the column-partition (i.e. row-partition of the adjoint), apply
|
||||
# the operator to the respective block of the solution found in previous
|
||||
# iterations, and subtract the result from the `rhs` block. For example,
|
||||
# let `A`, `B`, and `D` be the linear operators in the top row-partition
|
||||
# of the adjoint of
|
||||
# `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])`,
|
||||
# and `x_1` and `x_2` be blocks of the solution found in previous
|
||||
# iterations of the outer loop. The following loop (when `index == 0`)
|
||||
# expresses
|
||||
# `Ax_0 + Bx_1 + Dx_2 = y_0` as `Ax_0 = y_0*`, where
|
||||
# `y_0* = y_0 - Bx_1 - Dx_2`.
|
||||
for j in reversed(range(index + 1, len(self.operators))):
|
||||
y -= self.operators[j][index].matmul(
|
||||
solution_list[len(self.operators) - 1 - j],
|
||||
adjoint=adjoint)
|
||||
# Continuing the example above, solve `Ax_0 = y_0*` for `x_0`.
|
||||
solution_list.append(
|
||||
self.operators[index][index].solve(y, adjoint=adjoint))
|
||||
solution_list.reverse()
|
||||
else:
|
||||
# Iterate top to bottom over the row-partitions.
|
||||
for row, y in zip(self.operators, split_rhs):
|
||||
# Iterate left to right over the operators in the off-diagonal portion
|
||||
# of the row-partition, apply the operator to the block of the solution
|
||||
# found in previous iterations, and subtract the result from the `rhs`
|
||||
# block. For example, let `D`, `E`, and `F` be the linear operators in
|
||||
# the bottom row-partition of
|
||||
# `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])` and
|
||||
# `x_0` and `x_1` be blocks of the solution found in previous iterations
|
||||
# of the outer loop. The following loop (when `index == 2`), expresses
|
||||
# `Dx_0 + Ex_1 + Fx_2 = y_2` as `Fx_2 = y_2*`, where
|
||||
# `y_2* = y_2 - D_x0 - Ex_1`.
|
||||
for i, operator in enumerate(row[:-1]):
|
||||
y -= operator.matmul(solution_list[i], adjoint=adjoint)
|
||||
# Continuing the example above, solve `Fx_2 = y_2*` for `x_2`.
|
||||
solution_list.append(row[-1].solve(y, adjoint=adjoint))
|
||||
The returned `Tensor` will be close to an exact solution if `A` is well
|
||||
conditioned. Otherwise closeness will vary. See class docstring for details.
|
||||
|
||||
solution_list = linear_operator_util.broadcast_matrix_batch_dims(
|
||||
solution_list)
|
||||
return array_ops.concat(solution_list, axis=-2)
|
||||
Given the blockwise `n + 1`-by-`n + 1` linear operator:
|
||||
|
||||
op = [[A_00 0 ... 0 ... 0],
|
||||
[A_10 A_11 ... 0 ... 0],
|
||||
...
|
||||
[A_k0 A_k1 ... A_kk ... 0],
|
||||
...
|
||||
[A_n0 A_n1 ... A_nk ... A_nn]]
|
||||
|
||||
we find `x = op.solve(y)` by observing that
|
||||
|
||||
`y_k = A_k0.matmul(x_0) + A_k1.matmul(x_1) + ... + A_kk.matmul(x_k)`
|
||||
|
||||
and therefore
|
||||
|
||||
`x_k = A_kk.solve(y_k -
|
||||
A_k0.matmul(x_0) - ... - A_k(k-1).matmul(x_(k-1)))`
|
||||
|
||||
where `x_k` and `y_k` are the `k`th blocks obtained by decomposing `x`
|
||||
and `y` along their appropriate axes.
|
||||
|
||||
We first solve `x_0 = A_00.solve(y_0)`. Proceeding inductively, we solve
|
||||
for `x_k`, `k = 1..n`, given `x_0..x_(k-1)`.
|
||||
|
||||
The adjoint case is solved similarly, beginning with
|
||||
`x_n = A_nn.solve(y_n, adjoint=True)` and proceeding backwards.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
# Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
|
||||
operator = LinearOperator(...)
|
||||
operator.shape = [..., M, N]
|
||||
|
||||
# Solve R > 0 linear systems for every member of the batch.
|
||||
RHS = ... # shape [..., M, R]
|
||||
|
||||
X = operator.solve(RHS)
|
||||
# X[..., :, r] is the solution to the r'th linear system
|
||||
# sum_j A[..., :, j] X[..., j, r] = RHS[..., :, r]
|
||||
|
||||
operator.matmul(X)
|
||||
==> RHS
|
||||
```
|
||||
|
||||
Args:
|
||||
rhs: `Tensor` with same `dtype` as this operator and compatible shape,
|
||||
or a list of `Tensor`s. `Tensor`s are treated like a [batch] matrices
|
||||
meaning for every set of leading dimensions, the last two dimensions
|
||||
defines a matrix.
|
||||
See class docstring for definition of compatibility.
|
||||
adjoint: Python `bool`. If `True`, solve the system involving the adjoint
|
||||
of this `LinearOperator`: `A^H X = rhs`.
|
||||
adjoint_arg: Python `bool`. If `True`, solve `A X = rhs^H` where `rhs^H`
|
||||
is the hermitian transpose (transposition and complex conjugation).
|
||||
name: A name scope to use for ops added by this method.
|
||||
|
||||
Returns:
|
||||
`Tensor` with shape `[...,N, R]` and same `dtype` as `rhs`.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If `self.is_non_singular` or `is_square` is False.
|
||||
"""
|
||||
if self.is_non_singular is False:
|
||||
raise NotImplementedError(
|
||||
"Exact solve not implemented for an operator that is expected to "
|
||||
"be singular.")
|
||||
if self.is_square is False:
|
||||
raise NotImplementedError(
|
||||
"Exact solve not implemented for an operator that is expected to "
|
||||
"not be square.")
|
||||
if isinstance(rhs, linear_operator.LinearOperator):
|
||||
left_operator = self.adjoint() if adjoint else self
|
||||
right_operator = rhs.adjoint() if adjoint_arg else rhs
|
||||
|
||||
if (right_operator.range_dimension is not None and
|
||||
left_operator.domain_dimension is not None and
|
||||
right_operator.range_dimension != left_operator.domain_dimension):
|
||||
raise ValueError(
|
||||
"Operators are incompatible. Expected `rhs` to have dimension"
|
||||
" {} but got {}.".format(
|
||||
left_operator.domain_dimension, right_operator.range_dimension))
|
||||
with self._name_scope(name):
|
||||
return linear_operator_algebra.solve(left_operator, right_operator)
|
||||
|
||||
with self._name_scope(name):
|
||||
block_dimensions = (self._block_domain_dimensions() if adjoint
|
||||
else self._block_range_dimensions())
|
||||
arg_dim = -1 if adjoint_arg else -2
|
||||
blockwise_arg = linear_operator_util.arg_is_blockwise(
|
||||
block_dimensions, rhs, arg_dim)
|
||||
if blockwise_arg:
|
||||
for i, block in enumerate(rhs):
|
||||
if not isinstance(block, linear_operator.LinearOperator):
|
||||
block = ops.convert_to_tensor(block)
|
||||
self._check_input_dtype(block)
|
||||
block_dimensions[i].assert_is_compatible_with(block.shape[arg_dim])
|
||||
rhs[i] = block
|
||||
if adjoint_arg:
|
||||
split_rhs = [linalg.adjoint(y) for y in rhs]
|
||||
else:
|
||||
split_rhs = rhs
|
||||
|
||||
else:
|
||||
rhs = ops.convert_to_tensor(rhs, name="rhs")
|
||||
self._check_input_dtype(rhs)
|
||||
op_dimension = (self.domain_dimension if adjoint
|
||||
else self.range_dimension)
|
||||
op_dimension.assert_is_compatible_with(rhs.shape[arg_dim])
|
||||
|
||||
rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
|
||||
split_rhs = linear_operator_util.split_arg_into_blocks(
|
||||
self._block_domain_dimensions(),
|
||||
self._block_domain_dimension_tensors,
|
||||
rhs, axis=-2)
|
||||
|
||||
solution_list = []
|
||||
if adjoint:
|
||||
# For an adjoint blockwise lower-triangular linear operator, the system
|
||||
# must be solved bottom to top. Iterate backwards over rows of the
|
||||
# adjoint (i.e. columns of the non-adjoint operator).
|
||||
for index in reversed(range(len(self.operators))):
|
||||
y = split_rhs[index]
|
||||
# Iterate top to bottom over the operators in the off-diagonal portion
|
||||
# of the column-partition (i.e. row-partition of the adjoint), apply
|
||||
# the operator to the respective block of the solution found in
|
||||
# previous iterations, and subtract the result from the `rhs` block.
|
||||
# For example,let `A`, `B`, and `D` be the linear operators in the top
|
||||
# row-partition of the adjoint of
|
||||
# `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])`,
|
||||
# and `x_1` and `x_2` be blocks of the solution found in previous
|
||||
# iterations of the outer loop. The following loop (when `index == 0`)
|
||||
# expresses
|
||||
# `Ax_0 + Bx_1 + Dx_2 = y_0` as `Ax_0 = y_0*`, where
|
||||
# `y_0* = y_0 - Bx_1 - Dx_2`.
|
||||
for j in reversed(range(index + 1, len(self.operators))):
|
||||
y -= self.operators[j][index].matmul(
|
||||
solution_list[len(self.operators) - 1 - j],
|
||||
adjoint=adjoint)
|
||||
# Continuing the example above, solve `Ax_0 = y_0*` for `x_0`.
|
||||
solution_list.append(
|
||||
self._diagonal_operators[index].solve(y, adjoint=adjoint))
|
||||
solution_list.reverse()
|
||||
else:
|
||||
# Iterate top to bottom over the row-partitions.
|
||||
for row, y in zip(self.operators, split_rhs):
|
||||
# Iterate left to right over the operators in the off-diagonal portion
|
||||
# of the row-partition, apply the operator to the block of the
|
||||
# solution found in previous iterations, and subtract the result from
|
||||
# the `rhs` block. For example, let `D`, `E`, and `F` be the linear
|
||||
# operators in the bottom row-partition of
|
||||
# `LinearOperatorBlockLowerTriangular([[A], [B, C], [D, E, F]])` and
|
||||
# `x_0` and `x_1` be blocks of the solution found in previous
|
||||
# iterations of the outer loop. The following loop
|
||||
# (when `index == 2`), expresses
|
||||
# `Dx_0 + Ex_1 + Fx_2 = y_2` as `Fx_2 = y_2*`, where
|
||||
# `y_2* = y_2 - D_x0 - Ex_1`.
|
||||
for i, operator in enumerate(row[:-1]):
|
||||
y -= operator.matmul(solution_list[i], adjoint=adjoint)
|
||||
# Continuing the example above, solve `Fx_2 = y_2*` for `x_2`.
|
||||
solution_list.append(row[-1].solve(y, adjoint=adjoint))
|
||||
|
||||
if blockwise_arg:
|
||||
return solution_list
|
||||
|
||||
solution_list = linear_operator_util.broadcast_matrix_batch_dims(
|
||||
solution_list)
|
||||
return array_ops.concat(solution_list, axis=-2)
|
||||
|
||||
def solvevec(self, rhs, adjoint=False, name="solve"):
|
||||
"""Solve single equation with best effort: `A X = rhs`.
|
||||
|
||||
The returned `Tensor` will be close to an exact solution if `A` is well
|
||||
conditioned. Otherwise closeness will vary. See class docstring for details.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
# Make an operator acting like batch matrix A. Assume A.shape = [..., M, N]
|
||||
operator = LinearOperator(...)
|
||||
operator.shape = [..., M, N]
|
||||
|
||||
# Solve one linear system for every member of the batch.
|
||||
RHS = ... # shape [..., M]
|
||||
|
||||
X = operator.solvevec(RHS)
|
||||
# X is the solution to the linear system
|
||||
# sum_j A[..., :, j] X[..., j] = RHS[..., :]
|
||||
|
||||
operator.matvec(X)
|
||||
==> RHS
|
||||
```
|
||||
|
||||
Args:
|
||||
rhs: `Tensor` with same `dtype` as this operator, or list of `Tensor`s
|
||||
(for blockwise operators). `Tensor`s are treated as [batch] vectors,
|
||||
meaning for every set of leading dimensions, the last dimension defines
|
||||
a vector. See class docstring for definition of compatibility regarding
|
||||
batch dimensions.
|
||||
adjoint: Python `bool`. If `True`, solve the system involving the adjoint
|
||||
of this `LinearOperator`: `A^H X = rhs`.
|
||||
name: A name scope to use for ops added by this method.
|
||||
|
||||
Returns:
|
||||
`Tensor` with shape `[...,N]` and same `dtype` as `rhs`.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If `self.is_non_singular` or `is_square` is False.
|
||||
"""
|
||||
with self._name_scope(name):
|
||||
block_dimensions = (self._block_domain_dimensions() if adjoint
|
||||
else self._block_range_dimensions())
|
||||
if linear_operator_util.arg_is_blockwise(block_dimensions, rhs, -1):
|
||||
for i, block in enumerate(rhs):
|
||||
if not isinstance(block, linear_operator.LinearOperator):
|
||||
block = ops.convert_to_tensor(block)
|
||||
self._check_input_dtype(block)
|
||||
block_dimensions[i].assert_is_compatible_with(block.shape[-1])
|
||||
rhs[i] = block
|
||||
rhs_mat = [array_ops.expand_dims(block, axis=-1) for block in rhs]
|
||||
solution_mat = self.solve(rhs_mat, adjoint=adjoint)
|
||||
return [array_ops.squeeze(x, axis=-1) for x in solution_mat]
|
||||
rhs = ops.convert_to_tensor(rhs, name="rhs")
|
||||
self._check_input_dtype(rhs)
|
||||
op_dimension = (self.domain_dimension if adjoint
|
||||
else self.range_dimension)
|
||||
op_dimension.assert_is_compatible_with(rhs.shape[-1])
|
||||
rhs_mat = array_ops.expand_dims(rhs, axis=-1)
|
||||
solution_mat = self.solve(rhs_mat, adjoint=adjoint)
|
||||
return array_ops.squeeze(solution_mat, axis=-1)
|
||||
|
||||
def _diag_part(self):
|
||||
diag_list = []
|
||||
for row in self.operators:
|
||||
for op in self._diagonal_operators:
|
||||
# Extend the axis, since `broadcast_matrix_batch_dims` treats all but the
|
||||
# final two dimensions as batch dimensions.
|
||||
diag_list.append(row[-1].diag_part()[..., array_ops.newaxis])
|
||||
diag_list.append(op.diag_part()[..., array_ops.newaxis])
|
||||
diag_list = linear_operator_util.broadcast_matrix_batch_dims(diag_list)
|
||||
diagonal = array_ops.concat(diag_list, axis=-2)
|
||||
return array_ops.squeeze(diagonal, axis=-1)
|
||||
|
||||
def _trace(self):
|
||||
result = self.operators[0][0].trace()
|
||||
for row in self.operators[1:]:
|
||||
result += row[-1].trace()
|
||||
result = self._diagonal_operators[0].trace()
|
||||
for op in self._diagonal_operators[1:]:
|
||||
result += op.trace()
|
||||
return result
|
||||
|
||||
def _to_dense(self):
|
||||
@ -551,37 +853,13 @@ class LinearOperatorBlockLowerTriangular(linear_operator.LinearOperator):
|
||||
|
||||
def _assert_non_singular(self):
|
||||
return control_flow_ops.group([
|
||||
row[-1].assert_non_singular() for row in self.operators])
|
||||
op.assert_non_singular() for op in self._diagonal_operators])
|
||||
|
||||
def _eigvals(self):
|
||||
eig_list = []
|
||||
for row in self.operators:
|
||||
for op in self._diagonal_operators:
|
||||
# Extend the axis for broadcasting.
|
||||
eig_list.append(row[-1].eigvals()[..., array_ops.newaxis])
|
||||
eig_list.append(op.eigvals()[..., array_ops.newaxis])
|
||||
eig_list = linear_operator_util.broadcast_matrix_batch_dims(eig_list)
|
||||
eigs = array_ops.concat(eig_list, axis=-2)
|
||||
return array_ops.squeeze(eigs, axis=-1)
|
||||
|
||||
def _split_input_into_blocks(self, x, axis=-1):
|
||||
"""Split `x` into blocks matching `operators`'s `domain_dimension`.
|
||||
|
||||
Specifically, if we have a blockwise lower-triangular matrix, with block
|
||||
sizes along the diagonal `[M_j, M_j] j = 0,1,2..J`, this method splits `x`
|
||||
on `axis` into `J` tensors, whose shape at `axis` is `M_j`.
|
||||
|
||||
Args:
|
||||
x: `Tensor`. `x` is split into `J` tensors.
|
||||
axis: Python `Integer` representing the axis to split `x` on.
|
||||
|
||||
Returns:
|
||||
A list of `Tensor`s.
|
||||
"""
|
||||
block_sizes = []
|
||||
if self.shape.is_fully_defined():
|
||||
for row in self.operators:
|
||||
block_sizes.append(row[-1].domain_dimension.value)
|
||||
else:
|
||||
for row in self.operators:
|
||||
block_sizes.append(row[-1].domain_dimension_tensor())
|
||||
|
||||
return array_ops.split(x, block_sizes, axis=axis)
|
||||
|
@ -114,6 +114,10 @@ class LinearOperatorDerivedClassTest(test.TestCase):
|
||||
def use_placeholder_options():
|
||||
return [False, True]
|
||||
|
||||
@staticmethod
|
||||
def use_blockwise_arg():
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def operator_shapes_infos():
|
||||
"""Returns list of OperatorShapesInfo, encapsulating the shape to test."""
|
||||
@ -321,6 +325,7 @@ def _test_matmul_base(
|
||||
dtype,
|
||||
adjoint,
|
||||
adjoint_arg,
|
||||
blockwise_arg,
|
||||
with_batch):
|
||||
# If batch dimensions are omitted, but there are
|
||||
# no batch dimensions for the linear operator, then
|
||||
@ -346,8 +351,35 @@ def _test_matmul_base(
|
||||
if not use_placeholder:
|
||||
self.assertAllEqual(op_matmul.shape,
|
||||
mat_matmul.shape)
|
||||
op_matmul_v, mat_matmul_v = sess.run(
|
||||
[op_matmul, mat_matmul])
|
||||
|
||||
# If the operator is blockwise, test both blockwise `x` and `Tensor` `x`;
|
||||
# else test only `Tensor` `x`. In both cases, evaluate all results in a
|
||||
# single `sess.run` call to avoid re-sampling the random `x` in graph mode.
|
||||
if blockwise_arg and len(operator.operators) > 1:
|
||||
split_x = linear_operator_util.split_arg_into_blocks(
|
||||
operator._block_domain_dimensions(), # pylint: disable=protected-access
|
||||
operator._block_domain_dimension_tensors, # pylint: disable=protected-access
|
||||
x, axis=-2)
|
||||
if adjoint_arg:
|
||||
split_x = [linalg.adjoint(y) for y in split_x]
|
||||
split_matmul = operator.matmul(
|
||||
split_x, adjoint=adjoint, adjoint_arg=adjoint_arg)
|
||||
|
||||
self.assertEqual(len(split_matmul), len(operator.operators))
|
||||
split_matmul = linear_operator_util.broadcast_matrix_batch_dims(
|
||||
split_matmul)
|
||||
fused_block_matmul = array_ops.concat(split_matmul, axis=-2)
|
||||
op_matmul_v, mat_matmul_v, fused_block_matmul_v = sess.run([
|
||||
op_matmul, mat_matmul, fused_block_matmul])
|
||||
|
||||
# Check that the operator applied to blockwise input gives the same result
|
||||
# as matrix multiplication.
|
||||
self.assertAC(fused_block_matmul_v, mat_matmul_v)
|
||||
else:
|
||||
op_matmul_v, mat_matmul_v = sess.run([op_matmul, mat_matmul])
|
||||
|
||||
# Check that the operator applied to a `Tensor` gives the same result as
|
||||
# matrix multiplication.
|
||||
self.assertAC(op_matmul_v, mat_matmul_v)
|
||||
|
||||
|
||||
@ -356,7 +388,8 @@ def _test_matmul(
|
||||
shapes_info,
|
||||
dtype,
|
||||
adjoint,
|
||||
adjoint_arg):
|
||||
adjoint_arg,
|
||||
blockwise_arg):
|
||||
def test_matmul(self):
|
||||
_test_matmul_base(
|
||||
self,
|
||||
@ -365,6 +398,7 @@ def _test_matmul(
|
||||
dtype,
|
||||
adjoint,
|
||||
adjoint_arg,
|
||||
blockwise_arg,
|
||||
with_batch=True)
|
||||
return test_matmul
|
||||
|
||||
@ -374,7 +408,8 @@ def _test_matmul_with_broadcast(
|
||||
shapes_info,
|
||||
dtype,
|
||||
adjoint,
|
||||
adjoint_arg):
|
||||
adjoint_arg,
|
||||
blockwise_arg):
|
||||
def test_matmul_with_broadcast(self):
|
||||
_test_matmul_base(
|
||||
self,
|
||||
@ -383,6 +418,7 @@ def _test_matmul_with_broadcast(
|
||||
dtype,
|
||||
adjoint,
|
||||
adjoint_arg,
|
||||
blockwise_arg,
|
||||
with_batch=True)
|
||||
return test_matmul_with_broadcast
|
||||
|
||||
@ -505,6 +541,7 @@ def _test_solve_base(
|
||||
dtype,
|
||||
adjoint,
|
||||
adjoint_arg,
|
||||
blockwise_arg,
|
||||
with_batch):
|
||||
# If batch dimensions are omitted, but there are
|
||||
# no batch dimensions for the linear operator, then
|
||||
@ -532,12 +569,39 @@ def _test_solve_base(
|
||||
if not use_placeholder:
|
||||
self.assertAllEqual(op_solve.shape,
|
||||
mat_solve.shape)
|
||||
op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve])
|
||||
|
||||
# If the operator is blockwise, test both blockwise rhs and `Tensor` rhs;
|
||||
# else test only `Tensor` rhs. In both cases, evaluate all results in a
|
||||
# single `sess.run` call to avoid re-sampling the random rhs in graph mode.
|
||||
if blockwise_arg and len(operator.operators) > 1:
|
||||
split_rhs = linear_operator_util.split_arg_into_blocks(
|
||||
operator._block_domain_dimensions(), # pylint: disable=protected-access
|
||||
operator._block_domain_dimension_tensors, # pylint: disable=protected-access
|
||||
rhs, axis=-2)
|
||||
if adjoint_arg:
|
||||
split_rhs = [linalg.adjoint(y) for y in split_rhs]
|
||||
split_solve = operator.solve(
|
||||
split_rhs, adjoint=adjoint, adjoint_arg=adjoint_arg)
|
||||
self.assertEqual(len(split_solve), len(operator.operators))
|
||||
split_solve = linear_operator_util.broadcast_matrix_batch_dims(
|
||||
split_solve)
|
||||
fused_block_solve = array_ops.concat(split_solve, axis=-2)
|
||||
op_solve_v, mat_solve_v, fused_block_solve_v = sess.run([
|
||||
op_solve, mat_solve, fused_block_solve])
|
||||
|
||||
# Check that the operator and matrix give the same solution when the rhs
|
||||
# is blockwise.
|
||||
self.assertAC(mat_solve_v, fused_block_solve_v)
|
||||
else:
|
||||
op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve])
|
||||
|
||||
# Check that the operator and matrix give the same solution when the rhs is
|
||||
# a `Tensor`.
|
||||
self.assertAC(op_solve_v, mat_solve_v)
|
||||
|
||||
|
||||
def _test_solve(
|
||||
use_placeholder, shapes_info, dtype, adjoint, adjoint_arg):
|
||||
use_placeholder, shapes_info, dtype, adjoint, adjoint_arg, blockwise_arg):
|
||||
def test_solve(self):
|
||||
_test_solve_base(
|
||||
self,
|
||||
@ -546,12 +610,13 @@ def _test_solve(
|
||||
dtype,
|
||||
adjoint,
|
||||
adjoint_arg,
|
||||
blockwise_arg,
|
||||
with_batch=True)
|
||||
return test_solve
|
||||
|
||||
|
||||
def _test_solve_with_broadcast(
|
||||
use_placeholder, shapes_info, dtype, adjoint, adjoint_arg):
|
||||
use_placeholder, shapes_info, dtype, adjoint, adjoint_arg, blockwise_arg):
|
||||
def test_solve_with_broadcast(self):
|
||||
_test_solve_base(
|
||||
self,
|
||||
@ -560,6 +625,7 @@ def _test_solve_with_broadcast(
|
||||
dtype,
|
||||
adjoint,
|
||||
adjoint_arg,
|
||||
blockwise_arg,
|
||||
with_batch=False)
|
||||
return test_solve_with_broadcast
|
||||
|
||||
@ -681,7 +747,8 @@ def add_tests(test_cls):
|
||||
shape_info,
|
||||
dtype,
|
||||
adjoint,
|
||||
adjoint_arg)))
|
||||
adjoint_arg,
|
||||
test_cls.use_blockwise_arg())))
|
||||
else:
|
||||
if hasattr(test_cls, base_test_name):
|
||||
raise RuntimeError("Test %s defined more than once" % base_test_name)
|
||||
|
@ -29,6 +29,7 @@ from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import linalg_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variables as variables_module
|
||||
from tensorflow.python.util import nest
|
||||
|
||||
|
||||
################################################################################
|
||||
@ -135,6 +136,14 @@ def dtype_name(dtype):
|
||||
return str(dtype)
|
||||
|
||||
|
||||
def check_dtype(arg, dtype):
|
||||
"""Check that arg.dtype == self.dtype."""
|
||||
if arg.dtype.base_dtype != dtype:
|
||||
raise TypeError(
|
||||
"Expected argument to have dtype %s. Found: %s in tensor %s" %
|
||||
(dtype, arg.dtype, arg))
|
||||
|
||||
|
||||
def is_ref(x):
|
||||
"""Evaluates if the object has reference semantics.
|
||||
|
||||
@ -500,3 +509,77 @@ def use_operator_or_provided_hint_unless_contradicting(
|
||||
return False
|
||||
# pylint: enable=g-bool-id-comparison
|
||||
return None
|
||||
|
||||
|
||||
################################################################################
|
||||
# Utilities for blockwise operators.
|
||||
################################################################################
|
||||
|
||||
|
||||
def arg_is_blockwise(block_dimensions, arg, arg_split_dim):
|
||||
"""Detect if input should be interpreted as a list of blocks."""
|
||||
# Tuples and lists of length equal to the number of operators may be
|
||||
# blockwise.
|
||||
if (isinstance(arg, (tuple, list)) and len(arg) == len(block_dimensions)):
|
||||
# If the elements of the iterable are not nested, interpret the input as
|
||||
# blockwise.
|
||||
if not any(nest.is_nested(x) for x in arg):
|
||||
return True
|
||||
else:
|
||||
arg_dims = [ops.convert_to_tensor(x).shape[arg_split_dim] for x in arg]
|
||||
self_dims = [dim.value for dim in block_dimensions]
|
||||
|
||||
# If none of the operator dimensions are known, interpret the input as
|
||||
# blockwise if its matching dimensions are unequal.
|
||||
if all(self_d is None for self_d in self_dims):
|
||||
|
||||
# A nested tuple/list with a single outermost element is not blockwise
|
||||
if len(arg_dims) == 1:
|
||||
return False
|
||||
elif any(dim != arg_dims[0] for dim in arg_dims):
|
||||
return True
|
||||
else:
|
||||
raise ValueError(
|
||||
"Parsing of the input structure is ambiguous. Please input "
|
||||
"a blockwise iterable of `Tensor`s or a single `Tensor`.")
|
||||
|
||||
# If input dimensions equal the respective (known) blockwise operator
|
||||
# dimensions, then the input is blockwise.
|
||||
if all(self_d == arg_d or self_d is None
|
||||
for self_d, arg_d in zip(self_dims, arg_dims)):
|
||||
return True
|
||||
|
||||
# If input dimensions equals are all equal, and are greater than or equal
|
||||
# to the sum of the known operator dimensions, interpret the input as
|
||||
# blockwise.
|
||||
# input is not blockwise.
|
||||
self_dim = sum(self_d for self_d in self_dims if self_d is not None)
|
||||
if all(s == arg_dims[0] for s in arg_dims) and arg_dims[0] >= self_dim:
|
||||
return False
|
||||
|
||||
# If none of these conditions is met, the input shape is mismatched.
|
||||
raise ValueError("Input dimension does not match operator dimension.")
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def split_arg_into_blocks(block_dims, block_dims_fn, arg, axis=-1):
|
||||
"""Split `x` into blocks matching `operators`'s `domain_dimension`.
|
||||
|
||||
Specifically, if we have a blockwise lower-triangular matrix, with block
|
||||
sizes along the diagonal `[M_j, M_j] j = 0,1,2..J`, this method splits `arg`
|
||||
on `axis` into `J` tensors, whose shape at `axis` is `M_j`.
|
||||
|
||||
Args:
|
||||
block_dims: Iterable of `TensorShapes`.
|
||||
block_dims_fn: Callable returning an iterable of `Tensor`s.
|
||||
arg: `Tensor`. `arg` is split into `J` tensors.
|
||||
axis: Python `Integer` representing the axis to split `arg` on.
|
||||
|
||||
Returns:
|
||||
A list of `Tensor`s.
|
||||
"""
|
||||
block_sizes = [dim.value for dim in block_dims]
|
||||
if any(d is None for d in block_sizes):
|
||||
block_sizes = block_dims_fn()
|
||||
return array_ops.split(arg, block_sizes, axis=axis)
|
||||
|
Loading…
Reference in New Issue
Block a user