Remove all remaining references to matrix_triangular_solve_with_broadcast.
PiperOrigin-RevId: 292045093
Change-Id: Ide06b9345c7226c5e0797e44e2eaab878c047589
parent 164f87fd88
commit 032d74b252
@@ -273,74 +273,6 @@ class MatrixSolveWithBroadcastTest(test.TestCase):
     self.assertAllClose(expected, result)


-class MatrixTriangularSolveWithBroadcastTest(test.TestCase):
-
-  def test_static_dims_broadcast_matrix_has_extra_dims(self):
-    # batch_shape = [2]
-    matrix = rng.rand(2, 3, 3)
-    rhs = rng.rand(3, 7)
-    rhs_broadcast = rhs + np.zeros((2, 1, 1))
-
-    result = linear_operator_util.matrix_triangular_solve_with_broadcast(
-        matrix, rhs)
-    self.assertAllEqual((2, 3, 7), result.shape)
-    expected = linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
-    self.assertAllClose(*self.evaluate([expected, result]))
-
-  def test_static_dims_broadcast_rhs_has_extra_dims(self):
-    # Since the second arg has extra dims, and the domain dim of the first arg
-    # is larger than the number of linear equations, code will "flip" the extra
-    # dims of the first arg to the far right, making extra linear equations
-    # (then call the matrix function, then flip back).
-    # We have verified that this optimization indeed happens.  How?  We
-    # stepped through with a debugger.
-    # batch_shape = [2]
-    matrix = rng.rand(3, 3)
-    rhs = rng.rand(2, 3, 2)
-    matrix_broadcast = matrix + np.zeros((2, 1, 1))
-
-    result = linear_operator_util.matrix_triangular_solve_with_broadcast(
-        matrix, rhs)
-    self.assertAllEqual((2, 3, 2), result.shape)
-    expected = linalg_ops.matrix_triangular_solve(matrix_broadcast, rhs)
-    self.assertAllClose(*self.evaluate([expected, result]))
-
-  def test_static_dims_broadcast_rhs_has_extra_dims_and_adjoint(self):
-    # Since the second arg has extra dims, and the domain dim of the first arg
-    # is larger than the number of linear equations, code will "flip" the extra
-    # dims of the first arg to the far right, making extra linear equations
-    # (then call the matrix function, then flip back).
-    # We have verified that this optimization indeed happens.  How?  We
-    # stepped through with a debugger.
-    # batch_shape = [2]
-    matrix = rng.rand(3, 3)
-    rhs = rng.rand(2, 3, 2)
-    matrix_broadcast = matrix + np.zeros((2, 1, 1))
-
-    result = linear_operator_util.matrix_triangular_solve_with_broadcast(
-        matrix, rhs, adjoint=True)
-    self.assertAllEqual((2, 3, 2), result.shape)
-    expected = linalg_ops.matrix_triangular_solve(
-        matrix_broadcast, rhs, adjoint=True)
-    self.assertAllClose(*self.evaluate([expected, result]))
-
-  def test_dynamic_dims_broadcast_64bit(self):
-    # batch_shape = [2]
-    matrix = rng.rand(2, 3, 3)
-    rhs = rng.rand(3, 7)
-    rhs_broadcast = rhs + np.zeros((2, 1, 1))
-
-    matrix_ph = array_ops.placeholder_with_default(matrix, shape=None)
-    rhs_ph = array_ops.placeholder_with_default(rhs, shape=None)
-
-    result, expected = self.evaluate([
-        linear_operator_util.matrix_triangular_solve_with_broadcast(
-            matrix_ph, rhs_ph),
-        linalg_ops.matrix_triangular_solve(matrix, rhs_broadcast)
-    ])
-    self.assertAllClose(expected, result)
-
-
 class DomainDimensionStubOperator(object):

   def __init__(self, domain_dimension):
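For context on what the deleted tests covered: each one checked the helper against linalg_ops.matrix_triangular_solve applied to operands whose batch dims had been broadcast by hand (the `+ np.zeros(...)` trick). A minimal sketch of that same pattern using only the public TF 2.x API, with shapes taken from the first test above; tf.broadcast_to stands in for the np.zeros trick and the `+ np.eye(3)` term is our addition to keep the system well conditioned:

import numpy as np
import tensorflow as tf

rng = np.random.RandomState(42)
matrix = rng.rand(2, 3, 3) + np.eye(3)  # batch_shape = [2]; eye keeps the diagonal away from zero
rhs = rng.rand(3, 7)                    # no batch dims

# Replicate rhs across the batch dimension so both operands agree, then call
# the stock triangular solve (only the lower triangle of matrix is read).
rhs_broadcast = tf.broadcast_to(rhs, [2, 3, 7])
result = tf.linalg.triangular_solve(matrix, rhs_broadcast, lower=True)
print(result.shape)  # (2, 3, 7)

Materializing rhs across the batch dimension is exactly the "inefficient replication of data" that the helper's docstring (removed in the next hunk) warns about.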
@@ -373,57 +373,6 @@ def matrix_solve_with_broadcast(matrix, rhs, adjoint=False, name=None):
     return reshape_inv(solution)


-def matrix_triangular_solve_with_broadcast(matrix,
-                                           rhs,
-                                           lower=True,
-                                           adjoint=False,
-                                           name=None):
-  """Solves triangular systems of linear equations by backsubstitution.
-
-  Works identically to `tf.linalg.triangular_solve`, but broadcasts batch dims
-  of `matrix` and `rhs` (by replicating) if they are determined statically to
-  be different, or if static shapes are not fully defined.  Thus, this may
-  result in an inefficient replication of data.
-
-  Args:
-    matrix: A `Tensor`. Must be one of the following types:
-      `float64`, `float32`, `complex64`, `complex128`. Shape is `[..., M, M]`.
-    rhs: A `Tensor`. Must have the same `dtype` as `matrix`.
-      Shape is `[..., M, K]`.
-    lower: An optional `bool`. Defaults to `True`. Indicates whether the
-      innermost matrices in `matrix` are lower or upper triangular.
-    adjoint: An optional `bool`. Defaults to `False`. Indicates whether to
-      solve with `matrix` or its (block-wise) adjoint.
-    name: A name for the operation (optional).
-
-  Returns:
-    `Tensor` with same `dtype` as `matrix` and shape `[..., M, K]`.
-  """
-  with ops.name_scope(name, "MatrixTriangularSolve", [matrix, rhs]):
-    matrix = ops.convert_to_tensor(matrix, name="matrix")
-    rhs = ops.convert_to_tensor(rhs, name="rhs", dtype=matrix.dtype)
-
-    # If either matrix/rhs has extra dims, we can reshape to get rid of them.
-    matrix, rhs, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(
-        matrix, rhs, adjoint_a=adjoint)
-
-    # `lower` indicates whether the matrix is lower triangular.  If we have
-    # manually taken the adjoint inside _reshape_for_efficiency, it is now
-    # upper triangular.
-    if not still_need_to_transpose and adjoint:
-      lower = not lower
-
-    # This will broadcast by brute force if we still need to.
-    matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])
-
-    solution = linalg_ops.matrix_triangular_solve(
-        matrix,
-        rhs,
-        lower=lower,
-        adjoint=adjoint and still_need_to_transpose)
-
-    return reshape_inv(solution)
-
-
 def _reshape_for_efficiency(a,
                             b,
                             transpose_a=False,
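The `lower = not lower` branch in the removed function rests on a standard identity: for a lower-triangular A, solving adjoint(A) x = b is the same as solving with the explicitly materialized adjoint, which is upper triangular. A quick check of that identity, assuming real-valued inputs (where the adjoint is just the transpose) and TF 2.x eager execution:

import numpy as np
import tensorflow as tf

rng = np.random.RandomState(0)
a = np.tril(rng.rand(3, 3)) + np.eye(3)  # well-conditioned lower triangular
b = rng.rand(3, 2)

# Solve adjoint(a) x = b two ways: via the adjoint flag, and by materializing
# the transpose and flipping `lower` — mirroring what the deleted code did
# after _reshape_for_efficiency had already taken the adjoint.
x1 = tf.linalg.triangular_solve(a, b, lower=True, adjoint=True)
x2 = tf.linalg.triangular_solve(a.T, b, lower=False, adjoint=False)
np.testing.assert_allclose(x1.numpy(), x2.numpy())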