Remove all remaining references to matrix_triangular_solve_with_broadcast.

PiperOrigin-RevId: 292045093
Change-Id: Ide06b9345c7226c5e0797e44e2eaab878c047589
Srinivas Vasudevan 2020-01-28 17:04:53 -08:00 committed by TensorFlower Gardener
parent 164f87fd88
commit 032d74b252
2 changed files with 0 additions and 119 deletions
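
The helper deleted here existed to paper over a gap in the op itself: `tf.linalg.triangular_solve` becomes a drop-in replacement once the underlying kernel broadcasts batch dimensions on its own. A minimal sketch of the direct call, assuming a TensorFlow version whose `triangular_solve` broadcasts (the shapes mirror the deleted tests below):

import numpy as np
import tensorflow as tf

# Batch of lower-triangular matrices; adding eye() keeps the diagonal nonzero.
matrix = np.tril(np.random.rand(2, 3, 3)) + np.eye(3)  # shape (2, 3, 3)
rhs = np.random.rand(1, 3, 7)  # batch dim 1 broadcasts against batch dim 2

# Assumption: the op broadcasts batch dims, so no manual tiling is needed.
result = tf.linalg.triangular_solve(matrix, rhs, lower=True)
print(result.shape)  # (2, 3, 7)

On versions without native broadcasting, the batch dims have to be replicated by hand, e.g. `rhs + np.zeros((2, 1, 1))`, exactly as the deleted tests below do.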


@@ -273,74 +273,6 @@ class MatrixSolveWithBroadcastTest(test.TestCase):
    self.assertAllClose(expected, result)


class MatrixTriangularSolveWithBroadcastTest(test.TestCase):

  def test_static_dims_broadcast_matrix_has_extra_dims(self):
    # batch_shape = [2]
    matrix = rng.rand(2, 3, 3)
    rhs = rng.rand(3, 7)
    rhs_broadcast = rhs + np.zeros((2, 1, 1))

    result = linear_operator_util.matrix_triangular_solve_with_broadcast(
        matrix, rhs)
    self.assertAllEqual((2, 3, 7), result.shape)
    expected = linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
    self.assertAllClose(*self.evaluate([expected, result]))
  def test_static_dims_broadcast_rhs_has_extra_dims(self):
    # Since the second arg has extra dims, and the domain dim of the first arg
    # is larger than the number of linear equations, code will "flip" the extra
    # dims of the second arg to the far right, making extra linear equations
    # (then call the matrix function, then flip back).
    # We have verified that this optimization indeed happens. How? We stepped
    # through with a debugger.
    # (A standalone NumPy sketch of this flip appears right after this test.)
    # batch_shape = [2]
    matrix = rng.rand(3, 3)
    rhs = rng.rand(2, 3, 2)
    matrix_broadcast = matrix + np.zeros((2, 1, 1))

    result = linear_operator_util.matrix_triangular_solve_with_broadcast(
        matrix, rhs)
    self.assertAllEqual((2, 3, 2), result.shape)
    expected = linalg_ops.matrix_triangular_solve(matrix_broadcast, rhs)
    self.assertAllClose(*self.evaluate([expected, result]))
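
The "flip" described in the comment above can be reproduced outside TensorFlow. A minimal NumPy/SciPy sketch of the trick, illustration only (`_reshape_for_efficiency` in the second file is the real implementation):

import numpy as np
from scipy.linalg import solve_triangular

# One (3, 3) lower-triangular system, a batch of two (3, 2) right-hand sides.
matrix = np.tril(np.random.rand(3, 3)) + np.eye(3)
rhs = np.random.rand(2, 3, 2)

# Flip the rhs batch dim into extra columns: a single un-batched solve with
# 2 * 2 = 4 right-hand-side columns instead of two batched solves.
flipped = np.transpose(rhs, (1, 0, 2)).reshape(3, 4)
flat = solve_triangular(matrix, flipped, lower=True)
solution = flat.reshape(3, 2, 2).transpose(1, 0, 2)  # flip back to (2, 3, 2)

# Agrees with solving each batch member separately.
expected = np.stack(
    [solve_triangular(matrix, rhs[i], lower=True) for i in range(2)])
assert np.allclose(solution, expected)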
  def test_static_dims_broadcast_rhs_has_extra_dims_and_adjoint(self):
    # Since the second arg has extra dims, and the domain dim of the first arg
    # is larger than the number of linear equations, code will "flip" the extra
    # dims of the second arg to the far right, making extra linear equations
    # (then call the matrix function, then flip back).
    # We have verified that this optimization indeed happens. How? We stepped
    # through with a debugger.
    # batch_shape = [2]
    matrix = rng.rand(3, 3)
    rhs = rng.rand(2, 3, 2)
    matrix_broadcast = matrix + np.zeros((2, 1, 1))

    result = linear_operator_util.matrix_triangular_solve_with_broadcast(
        matrix, rhs, adjoint=True)
    self.assertAllEqual((2, 3, 2), result.shape)
    expected = linalg_ops.matrix_triangular_solve(
        matrix_broadcast, rhs, adjoint=True)
    self.assertAllClose(*self.evaluate([expected, result]))
  def test_dynamic_dims_broadcast_64bit(self):
    # batch_shape = [2]
    matrix = rng.rand(2, 3, 3)
    rhs = rng.rand(3, 7)
    rhs_broadcast = rhs + np.zeros((2, 1, 1))

    matrix_ph = array_ops.placeholder_with_default(matrix, shape=None)
    rhs_ph = array_ops.placeholder_with_default(rhs, shape=None)

    result, expected = self.evaluate([
        linear_operator_util.matrix_triangular_solve_with_broadcast(
            matrix_ph, rhs_ph),
        linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
    ])
    self.assertAllClose(expected, result)
class DomainDimensionStubOperator(object):

  def __init__(self, domain_dimension):


@@ -373,57 +373,6 @@ def matrix_solve_with_broadcast(matrix, rhs, adjoint=False, name=None):
    return reshape_inv(solution)


def matrix_triangular_solve_with_broadcast(matrix,
                                           rhs,
                                           lower=True,
                                           adjoint=False,
                                           name=None):
  """Solves triangular systems of linear equations by back substitution.

  Works identically to `tf.linalg.triangular_solve`, but broadcasts batch dims
  of `matrix` and `rhs` (by replicating) if they are determined statically to
  be different, or if static shapes are not fully defined. Thus, this may
  result in an inefficient replication of data.

  Args:
    matrix: A `Tensor`. Must be one of the following types:
      `float64`, `float32`, `complex64`, `complex128`. Shape is `[..., M, M]`.
    rhs: A `Tensor`. Must have the same `dtype` as `matrix`.
      Shape is `[..., M, K]`.
    lower: An optional `bool`. Defaults to `True`. Indicates whether the
      innermost matrices in `matrix` are lower or upper triangular.
    adjoint: An optional `bool`. Defaults to `False`. Indicates whether to
      solve with `matrix` or its (block-wise) adjoint.
    name: A name for the operation (optional).

  Returns:
    `Tensor` with same `dtype` as `matrix` and shape `[..., M, K]`.
  """
  with ops.name_scope(name, "MatrixTriangularSolve", [matrix, rhs]):
    matrix = ops.convert_to_tensor(matrix, name="matrix")
    rhs = ops.convert_to_tensor(rhs, name="rhs", dtype=matrix.dtype)

    # If either matrix/rhs has extra dims, we can reshape to get rid of them.
    matrix, rhs, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(
        matrix, rhs, adjoint_a=adjoint)

    # lower indicates whether the matrix is lower triangular. If we have
    # manually taken the adjoint inside _reshape_for_efficiency, it is now
    # upper triangular.
    if not still_need_to_transpose and adjoint:
      lower = not lower

    # This will broadcast by brute force if we still need to.
    matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])

    solution = linalg_ops.matrix_triangular_solve(
        matrix,
        rhs,
        lower=lower,
        adjoint=adjoint and still_need_to_transpose)

    return reshape_inv(solution)
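
The `lower = not lower` flip above rests on a simple identity: the adjoint of a lower-triangular matrix is upper triangular, so once the adjoint has been taken manually, the solver must be told the triangle changed sides. A small NumPy/SciPy check, illustration only:

import numpy as np
from scipy.linalg import solve_triangular

A = np.tril(np.random.rand(3, 3)) + np.eye(3)  # lower triangular
b = np.random.rand(3, 1)

# Solving A^T x = b via the transpose flag equals solving with the
# pre-transposed (now upper-triangular) matrix and lower=False.
via_flag = solve_triangular(A, b, lower=True, trans='T')
via_transpose = solve_triangular(A.T, b, lower=False)
assert np.allclose(via_flag, via_transpose)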
def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,