diff --git a/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py b/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py
index 26da5865c27..0fe745f869a 100644
--- a/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py
+++ b/tensorflow/compiler/tests/tridiagonal_solve_ops_test.py
@@ -225,7 +225,7 @@ class TridiagonalSolveOpsTest(xla_test.XLATestCase):
     with self.session() as sess, self.test_scope():
       with self.assertRaisesRegexp(
           errors_impl.UnimplementedError,
-          "Pivoting is not yet supported in XLA tridiagonal solver."):
+          "Current implementation does not yet support pivoting."):
         diags = array_ops.placeholder(
             shape=(batch_size, 3, num_dims), dtype=dtypes.float32)
         rhs = array_ops.placeholder(
diff --git a/tensorflow/compiler/tf2xla/kernels/tridiagonal_ops.cc b/tensorflow/compiler/tf2xla/kernels/tridiagonal_ops.cc
index c09003ee9e0..7ce2dd060f1 100644
--- a/tensorflow/compiler/tf2xla/kernels/tridiagonal_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/tridiagonal_ops.cc
@@ -17,24 +17,27 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
 #include "tensorflow/compiler/xla/client/lib/tridiagonal.h"
+#include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"

 namespace tensorflow {
 namespace {

 class TridiagonalSolveOp : public XlaOpKernel {
  public:
-  explicit TridiagonalSolveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
-    OP_REQUIRES_OK(ctx, ctx->GetAttr("partial_pivoting", &pivoting_));
-  }
+  explicit TridiagonalSolveOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

   void Compile(XlaOpKernelContext* ctx) override {
-    OP_REQUIRES(
-        ctx, !pivoting_,
-        errors::Unimplemented(
-            "Pivoting is not yet supported in XLA tridiagonal solver."));
-
     auto diagonals = ctx->Input(0);
     auto rhs = ctx->Input(1);

+    bool partial_pivoting = false;
+    OP_REQUIRES_OK(ctx,
+                   GetNodeAttr(def(), "partial_pivoting", &partial_pivoting));
+    if (partial_pivoting) {
+      ctx->SetStatus(errors::Unimplemented(
+          "Current implementation does not yet support pivoting."));
+      return;
+    }

     auto result = xla::tridiagonal::ThomasSolver(diagonals, rhs);
     if (!result.ok()) {
@@ -43,16 +46,9 @@ class TridiagonalSolveOp : public XlaOpKernel {
     }
     ctx->SetOutput(0, result.ValueOrDie());
   }
-
- private:
-  bool pivoting_;
 };
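Moving the `partial_pivoting` read out of the constructor and into `Compile` means the kernel can always be constructed, and the unsupported-attribute error is reported through `ctx->SetStatus` at compile time, where the updated test above can match it. A minimal sketch of the same pattern, with hypothetical names (`UnsupportedAttrOp`, `my_attr`), not part of this change:

#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

// Sketch: reject a hypothetical "my_attr" attribute at compile time.
class UnsupportedAttrOp : public XlaOpKernel {
 public:
  explicit UnsupportedAttrOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    bool my_attr = false;
    // def() is the kernel's NodeDef; GetNodeAttr fails cleanly if the
    // attribute is missing.
    OP_REQUIRES_OK(ctx, GetNodeAttr(def(), "my_attr", &my_attr));
    if (my_attr) {
      ctx->SetStatus(errors::Unimplemented("my_attr is not supported."));
      return;
    }
    ctx->SetOutput(0, ctx->Input(0));  // identity, for illustration only
  }
};

}  // namespace tensorflow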
-// TODO(belletti): address test breakage in tridiagonal_solve_op_test_xla_gpu.py
-// to support all XLA devices.
-REGISTER_XLA_OP(Name("TridiagonalSolve")
-                    .Device("XLA_TPU_JIT")
-                    .TypeConstraint("T", kFloatTypes),
+REGISTER_XLA_OP(Name("TridiagonalSolve").TypeConstraint("T", kFloatTypes),
                 TridiagonalSolveOp);

 }  // namespace
diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD
index b821785d6d4..6fcdef46f29 100644
--- a/tensorflow/compiler/xla/client/lib/BUILD
+++ b/tensorflow/compiler/xla/client/lib/BUILD
@@ -531,12 +531,15 @@ cc_library(
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/client:xla_builder",
+        "@com_google_absl//absl/types:span",
     ],
 )

 xla_test(
     name = "tridiagonal_test",
     srcs = ["tridiagonal_test.cc"],
+    real_hardware_only = True,
+    shard_count = 10,
     tags = ["optonly"],
     deps = [
         ":constants",
diff --git a/tensorflow/compiler/xla/client/lib/tridiagonal.cc b/tensorflow/compiler/xla/client/lib/tridiagonal.cc
index 13cc3630137..89323b029b1 100644
--- a/tensorflow/compiler/xla/client/lib/tridiagonal.cc
+++ b/tensorflow/compiler/xla/client/lib/tridiagonal.cc
@@ -19,6 +19,7 @@ limitations under the License.
 #include <string>
 #include <vector>

+#include "absl/types/span.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/lib/loops.h"
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
@@ -33,13 +34,6 @@ namespace tridiagonal {

 namespace {

-struct TridiagonalSystemShape {
-  const int64 rank;
-  const int64 num_equations;
-  TridiagonalSystemShape(int64 rk, int64 num_eqs)
-      : rank(rk), num_equations(num_eqs) {}
-};
-
 Status CheckSecondToLastDimension(const Shape& op_shape, int64 rank,
                                   int64 expected, const std::string& op_name) {
   const auto actual_num_dims = ShapeUtil::GetDimension(op_shape, rank - 2);
@@ -53,10 +47,10 @@ Status CheckSecondToLastDimension(const Shape& op_shape, int64 rank,
   return Status::OK();
 }

-StatusOr<TridiagonalSystemShape> CheckSystemAndReturnShape(XlaOp lower_diagonal,
-                                                           XlaOp main_diagonal,
-                                                           XlaOp upper_diagonal,
-                                                           XlaOp rhs) {
+StatusOr<int64> CheckSystemAndReturnNumEquations(XlaOp lower_diagonal,
+                                                 XlaOp main_diagonal,
+                                                 XlaOp upper_diagonal,
+                                                 XlaOp rhs) {
   XlaBuilder* builder = lower_diagonal.builder();

   TF_ASSIGN_OR_RETURN(Shape lower_diagonal_shape,
@@ -111,11 +105,27 @@ StatusOr<TridiagonalSystemShape> CheckSystemAndReturnShape(XlaOp lower_diagonal,
   TF_RETURN_IF_ERROR(CheckSecondToLastDimension(upper_diagonal_shape, rank, 1,
                                                 "upper diagonal"));

-  return TridiagonalSystemShape(rank, num_equations);
+  return num_equations;
 }

-XlaOp Coefficient(XlaOp operand, int64 i) {
-  return SliceInMinorDims(operand, /*start=*/{i}, /*end=*/{i + 1});
+XlaOp Coefficient(XlaOp operand, int32 i) {
+  return DynamicSliceInMinorDims(
+      operand, /*starts=*/{ConstantR0<int32>(operand.builder(), i)},
+      /*sizes=*/{1});
+}
+
+XlaOp Coefficient(XlaOp operand, XlaOp i) {
+  return DynamicSliceInMinorDims(operand,
+                                 /*starts=*/{i}, /*sizes=*/{1});
+}
+
+XlaOp UpdateEq(XlaOp updated, int32 i, XlaOp update) {
+  return DynamicUpdateSliceInMinorDims(
+      updated, update, /*starts=*/{ConstantR0<int32>(updated.builder(), i)});
+}
+
+XlaOp UpdateEq(XlaOp updated, XlaOp i, XlaOp update) {
+  return DynamicUpdateSliceInMinorDims(updated, update, /*starts=*/{i});
 }

 }  // namespace
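The tridiagonal.cc rewrite below replaces loops that were unrolled at C++ build time with `xla::ForEachIndex` from `tensorflow/compiler/xla/client/lib/loops.h`, which emits a single HLO `While` loop and threads a tuple of loop-carried values through its body. A minimal sketch of that contract, assuming the `ForEachIndex` signature in loops.h at this revision (`SumOfIndices` is a hypothetical example, not part of this change):

#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"

namespace xla {

// Sketch: builds HLO computing 0 + 1 + ... + (n - 1) as an S32 scalar.
StatusOr<XlaOp> SumOfIndices(XlaBuilder* builder, int64 n) {
  auto body_fn = [](XlaOp i, absl::Span<const XlaOp> values,
                    XlaBuilder*) -> StatusOr<std::vector<XlaOp>> {
    // `values` is the tuple of loop-carried state; here, one running sum.
    return std::vector<XlaOp>{values[0] + i};
  };
  TF_ASSIGN_OR_RETURN(
      auto results,
      ForEachIndex(/*num_iterations=*/n, /*num_iterations_type=*/S32, body_fn,
                   /*initial_values=*/{Zero(builder, S32)}, "sum_of_indices",
                   builder));
  return results[0];  // final value of the running sum
}

}  // namespace xla

The same shape appears three times in ThomasSolver: a preparation loop, the forward elimination, and the backward reduction, each carrying its working tensors through the loop as tuple elements.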
@@ -134,48 +144,133 @@ XlaOp Coefficient(XlaOp operand, int64 i) {
 // solution will have the shape [..., num_rhs, num_equations].
 StatusOr<XlaOp> ThomasSolver(XlaOp lower_diagonal, XlaOp main_diagonal,
                              XlaOp upper_diagonal, XlaOp rhs) {
-  TF_ASSIGN_OR_RETURN(TridiagonalSystemShape system_shape,
-                      CheckSystemAndReturnShape(lower_diagonal, main_diagonal,
-                                                upper_diagonal, rhs));
+  XlaBuilder* builder = lower_diagonal.builder();

-  auto rank = system_shape.rank;
-  auto num_eqs = system_shape.num_equations;
+  TF_ASSIGN_OR_RETURN(int64 num_eqs,
+                      CheckSystemAndReturnNumEquations(
+                          lower_diagonal, main_diagonal, upper_diagonal, rhs));

-  std::vector<XlaOp> main_diag_after_elimination(num_eqs);
-  std::vector<XlaOp> rhs_after_elimination(num_eqs);
-  std::vector<XlaOp> upper_diagonal_coeffs(num_eqs);
+  XlaOp main_diag_after_elimination = ZerosLike(main_diagonal);
+  XlaOp rhs_after_elimination = ZerosLike(rhs);
+  XlaOp upper_diagonal_coeffs = ZerosLike(upper_diagonal);
+  XlaOp x_coeffs = ZerosLike(rhs);

-  main_diag_after_elimination[0] = Coefficient(main_diagonal, 0);
-  rhs_after_elimination[0] = Coefficient(rhs, 0);
-  for (int64 i = 0; i < num_eqs - 1; i++) {
-    upper_diagonal_coeffs[i] = Coefficient(upper_diagonal, i);
-  }
+  // main_diag_after_elimination[:, 0] = main_diagonal[:, 0];
+  main_diag_after_elimination =
+      UpdateEq(main_diag_after_elimination, 0, Coefficient(main_diagonal, 0));
+
+  // rhs_after_elimination[:, 0] = rhs[:, 0];
+  rhs_after_elimination =
+      UpdateEq(rhs_after_elimination, 0, Coefficient(rhs, 0));
+
+  auto preparation_body_fn =
+      [](XlaOp i, absl::Span<const XlaOp> values,
+         XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
+    auto upper_diagonal_coeffs = values[0];
+    auto upper_diagonal = values[1];
+    // upper_diagonal_coeffs[:, i] = upper_diagonal[:, i];
+    upper_diagonal_coeffs =
+        UpdateEq(upper_diagonal_coeffs, i, Coefficient(upper_diagonal, i));
+    return std::vector<XlaOp>{upper_diagonal_coeffs, upper_diagonal};
+  };
+  TF_ASSIGN_OR_RETURN(auto values_after_preparation,
+                      ForEachIndex(num_eqs - 1, S32, preparation_body_fn,
+                                   {upper_diagonal_coeffs, upper_diagonal},
+                                   "preparation", builder));
+  upper_diagonal_coeffs = values_after_preparation[0];
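With the first equation and the upper-diagonal copy in place, the forward elimination below implements the standard Thomas recurrences. Writing \(a\), \(b\), \(c\) for the lower, main, and upper diagonals, \(d\) for the right-hand side, and primes for the eliminated quantities (notation assumed here, with \(b'_0 = b_0\) and \(d'_0 = d_0\)):

\[
  w_i = \frac{a_i}{b'_{i-1}}, \qquad
  b'_i = b_i - w_i\,c_{i-1}, \qquad
  d'_i = d_i - w_i\,d'_{i-1}, \qquad i = 1, \dots, n-1.
\]

The loop body receives `i_minus_one` running from 0 to \(n-2\) and shifts it up by one, so the recurrence starts at the second equation.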
   // Forward transformation.
-  for (int64 i = 1; i < num_eqs; i++) {
+  auto forward_transformation_fn =
+      [](XlaOp i_minus_one, absl::Span<const XlaOp> values,
+         XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
+    auto lower_diagonal = values[0];
+    auto main_diagonal = values[1];
+    auto rhs = values[2];
+    auto main_diag_after_elimination = values[3];
+    auto upper_diagonal_coeffs = values[4];
+    auto rhs_after_elimination = values[5];
+
+    auto one = ScalarLike(i_minus_one, 1);
+    auto i = i_minus_one + one;
     auto lower_diagonal_i = Coefficient(lower_diagonal, i);
     auto main_diagonal_i = Coefficient(main_diagonal, i);
     auto rhs_i = Coefficient(rhs, i);

-    auto w_i = lower_diagonal_i / main_diag_after_elimination[i - 1];
+    auto w_i =
+        lower_diagonal_i / Coefficient(main_diag_after_elimination, i - one);

-    main_diag_after_elimination[i] =
-        main_diagonal_i - w_i * upper_diagonal_coeffs[i - 1];
-    rhs_after_elimination[i] = rhs_i - w_i * rhs_after_elimination[i - 1];
-  }
+    // main_diag_after_elimination[:, i] =
+    //     main_diagonal_i - w_i * upper_diagonal_coeffs[:, i - 1];
+    main_diag_after_elimination = UpdateEq(
+        main_diag_after_elimination, i,
+        main_diagonal_i - w_i * Coefficient(upper_diagonal_coeffs, i - one));
+    // rhs_after_elimination[:, i] =
+    //     rhs_i - w_i * rhs_after_elimination[:, i - 1];
+    rhs_after_elimination =
+        UpdateEq(rhs_after_elimination, i,
+                 rhs_i - w_i * Coefficient(rhs_after_elimination, i - one));

-  std::vector<XlaOp> x_coeffs(num_eqs);
+    return std::vector<XlaOp>{lower_diagonal,
+                              main_diagonal,
+                              rhs,
+                              main_diag_after_elimination,
+                              upper_diagonal_coeffs,
+                              rhs_after_elimination};
+  };
+  TF_ASSIGN_OR_RETURN(
+      auto values_after_fwd_transformation,
+      ForEachIndex(
+          num_eqs - 1, S32, forward_transformation_fn,
+          {lower_diagonal, main_diagonal, rhs, main_diag_after_elimination,
+           upper_diagonal_coeffs, rhs_after_elimination},
+          "forward_transformation", builder));
+  lower_diagonal = values_after_fwd_transformation[0];
+  main_diagonal = values_after_fwd_transformation[1];
+  rhs = values_after_fwd_transformation[2];
+  main_diag_after_elimination = values_after_fwd_transformation[3];
+  upper_diagonal_coeffs = values_after_fwd_transformation[4];
+  rhs_after_elimination = values_after_fwd_transformation[5];
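The remaining sweep is standard back substitution. Since `ForEachIndex` only counts upward, the loop body below recovers the descending index as `i = (num_eqs - 2) - j`:

\[
  x_{n-1} = \frac{d'_{n-1}}{b'_{n-1}}, \qquad
  x_i = \frac{d'_i - c_i\,x_{i+1}}{b'_i}, \qquad i = n-2, \dots, 0.
\]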
   // Backward reduction.
-  x_coeffs[num_eqs - 1] = rhs_after_elimination[num_eqs - 1] /
-                          main_diag_after_elimination[num_eqs - 1];
-  for (int i = num_eqs - 2; i >= 0; i--) {
-    x_coeffs[i] = (rhs_after_elimination[i] -
-                   upper_diagonal_coeffs[i] * x_coeffs[i + 1]) /
-                  main_diag_after_elimination[i];
-  }
+  // x_coeffs[:, num_eqs - 1] = rhs_after_elimination[:, num_eqs - 1] /
+  //     main_diag_after_elimination[:, num_eqs - 1];
+  x_coeffs =
+      UpdateEq(x_coeffs, num_eqs - 1,
+               Coefficient(rhs_after_elimination, num_eqs - 1) /
+                   Coefficient(main_diag_after_elimination, num_eqs - 1));
+  auto bwd_reduction_fn =
+      [num_eqs](XlaOp j, absl::Span<const XlaOp> values,
+                XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
+    auto x_coeffs = values[0];
+    auto rhs_after_elimination = values[1];
+    auto upper_diagonal_coeffs = values[2];
+    auto main_diag_after_elimination = values[3];
+    auto n = ScalarLike(j, num_eqs - 2);
+    auto one = ScalarLike(j, 1);
+    auto i = n - j;
+    // for (int i = num_eqs - 2; i >= 0; i--)
+    //     x_coeffs[:, i] = (rhs_after_elimination[:, i] -
+    //         upper_diagonal_coeffs[:, i] * x_coeffs[:, i + 1]) /
+    //         main_diag_after_elimination[:, i];
+    x_coeffs = UpdateEq(x_coeffs, i,
+                        (Coefficient(rhs_after_elimination, i) -
+                         Coefficient(upper_diagonal_coeffs, i) *
+                             Coefficient(x_coeffs, i + one)) /
+                            Coefficient(main_diag_after_elimination, i));
+    return std::vector<XlaOp>{x_coeffs, rhs_after_elimination,
+                              upper_diagonal_coeffs,
+                              main_diag_after_elimination};
+  };

-  return ConcatInDim(lower_diagonal.builder(), x_coeffs, rank - 1);
+  TF_ASSIGN_OR_RETURN(
+      auto values_after_bwd_reduction,
+      ForEachIndex(num_eqs - 1, S32, bwd_reduction_fn,
+                   {x_coeffs, rhs_after_elimination, upper_diagonal_coeffs,
+                    main_diag_after_elimination},
+                   "backward_reduction", builder));
+  x_coeffs = values_after_bwd_reduction[0];
+
+  return x_coeffs;
 }

 // Applies Thomas algorithm to solve a linear system where the linear operand
diff --git a/tensorflow/compiler/xla/client/lib/tridiagonal_test.cc b/tensorflow/compiler/xla/client/lib/tridiagonal_test.cc
index 17147588ff6..0b3a32f0969 100644
--- a/tensorflow/compiler/xla/client/lib/tridiagonal_test.cc
+++ b/tensorflow/compiler/xla/client/lib/tridiagonal_test.cc
@@ -33,34 +33,28 @@ namespace {

 class TridiagonalTest : public ClientLibraryTestBase,
-                        public ::testing::WithParamInterface<
-                            std::tuple<int, int, int, int>> {};
+                        public ::testing::WithParamInterface<
+                            std::tuple<int, int, int>> {};

 XLA_TEST_P(TridiagonalTest, Solves) {
   const auto& spec = GetParam();

   xla::XlaBuilder builder(TestName());

-  const int64 num_eqs = 5;
-  const int64 num_rhs = 3;
-  const int64 lower_diagonal_batch_size = std::get<0>(spec);
-  const int64 main_diagonal_batch_size = std::get<1>(spec);
-  const int64 upper_diagonal_batch_size = std::get<2>(spec);
-  const int64 rhs_diagonal_batch_size = std::get<2>(spec);
+  // TODO(belletti): parametrize num_rhs.
+  const int64 batch_size = std::get<0>(spec);
+  const int64 num_eqs = std::get<1>(spec);
+  const int64 num_rhs = std::get<2>(spec);

-  const int64 max_batch_size =
-      std::max({lower_diagonal_batch_size, main_diagonal_batch_size,
-                upper_diagonal_batch_size, rhs_diagonal_batch_size});
-
-  Array3D<float> lower_diagonal(lower_diagonal_batch_size, 1, num_eqs);
-  Array3D<float> main_diagonal(main_diagonal_batch_size, 1, num_eqs);
-  Array3D<float> upper_diagonal(upper_diagonal_batch_size, 1, num_eqs);
-  Array3D<float> rhs(rhs_diagonal_batch_size, num_rhs, num_eqs);
+  Array3D<float> lower_diagonal(batch_size, 1, num_eqs);
+  Array3D<float> main_diagonal(batch_size, 1, num_eqs);
+  Array3D<float> upper_diagonal(batch_size, 1, num_eqs);
+  Array3D<float> rhs(batch_size, num_rhs, num_eqs);

   lower_diagonal.FillRandom(1.0, /*mean=*/0.0, /*seed=*/0);
   main_diagonal.FillRandom(0.05, /*mean=*/1.0,
-                           /*seed=*/max_batch_size * num_eqs);
+                           /*seed=*/batch_size * num_eqs);
   upper_diagonal.FillRandom(1.0, /*mean=*/0.0,
-                            /*seed=*/2 * max_batch_size * num_eqs);
-  rhs.FillRandom(1.0, /*mean=*/0.0, /*seed=*/3 * max_batch_size * num_eqs);
+                            /*seed=*/2 * batch_size * num_eqs);
+  rhs.FillRandom(1.0, /*mean=*/0.0, /*seed=*/3 * batch_size * num_eqs);

   XlaOp lower_diagonal_xla;
   XlaOp main_diagonal_xla;
@@ -119,10 +113,9 @@ XLA_TEST_P(TridiagonalTest, Solves) {
 }

 INSTANTIATE_TEST_CASE_P(TridiagonalTestInstantiation, TridiagonalTest,
-                        ::testing::Combine(::testing::Values(1, 8),
-                                           ::testing::Values(1, 8),
-                                           ::testing::Values(1, 8),
-                                           ::testing::Values(1, 8)));
+                        ::testing::Combine(::testing::Values(1, 12),
+                                           ::testing::Values(4, 8),
+                                           ::testing::Values(1, 12)));

 }  // namespace
 }  // namespace tridiagonal
diff --git a/tensorflow/core/api_def/base_api/api_def_TridiagonalSolve.pbtxt b/tensorflow/core/api_def/base_api/api_def_TridiagonalSolve.pbtxt
index 1eb88c886ef..058ce666e03 100644
--- a/tensorflow/core/api_def/base_api/api_def_TridiagonalSolve.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_TridiagonalSolve.pbtxt
@@ -39,5 +39,6 @@ END
 On CPU, solution is computed via Gaussian elimination with or without partial
 pivoting, depending on `partial_pivoting` attribute. On GPU, Nvidia's cuSPARSE
 library is used: https://docs.nvidia.com/cuda/cusparse/index.html#gtsv
+Partial pivoting is not yet supported by XLA backends.
END } diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_tridiag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_tridiag_test.py index d69f872f703..049938e8d03 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_tridiag_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_tridiag_test.py @@ -178,7 +178,8 @@ class LinearOperatorTriDiagMatrixTest( if __name__ == '__main__': - linear_operator_test_util.add_tests(LinearOperatorTriDiagCompactTest) - linear_operator_test_util.add_tests(LinearOperatorTriDiagSequenceTest) - linear_operator_test_util.add_tests(LinearOperatorTriDiagMatrixTest) + if not test_util.is_xla_enabled(): + linear_operator_test_util.add_tests(LinearOperatorTriDiagCompactTest) + linear_operator_test_util.add_tests(LinearOperatorTriDiagSequenceTest) + linear_operator_test_util.add_tests(LinearOperatorTriDiagMatrixTest) test.main() diff --git a/tensorflow/python/kernel_tests/tridiagonal_solve_op_test.py b/tensorflow/python/kernel_tests/tridiagonal_solve_op_test.py index 2b50f1a29d4..afc327e2aef 100644 --- a/tensorflow/python/kernel_tests/tridiagonal_solve_op_test.py +++ b/tensorflow/python/kernel_tests/tridiagonal_solve_op_test.py @@ -22,8 +22,8 @@ import itertools import numpy as np -from tensorflow.python.eager import backprop from tensorflow.python.client import session +from tensorflow.python.eager import backprop from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -78,8 +78,19 @@ class TridiagonalSolveOpTest(test.TestCase): transpose_rhs=False, conjugate_rhs=False): with self.cached_session(use_gpu=True): - result = linalg_impl.tridiagonal_solve(diags, rhs, diags_format, - transpose_rhs, conjugate_rhs) + pivoting = True + if hasattr(self, "pivoting"): + pivoting = self.pivoting + if test_util.is_xla_enabled() and pivoting: + # Pivoting is not supported by xla backends. + return + result = linalg_impl.tridiagonal_solve( + diags, + rhs, + diags_format, + transpose_rhs, + conjugate_rhs, + partial_pivoting=pivoting) self.assertAllClose(self.evaluate(result), expected) def _testWithLists(self, @@ -94,8 +105,15 @@ class TridiagonalSolveOpTest(test.TestCase): transpose_rhs, conjugate_rhs) def _assertRaises(self, diags, rhs, diags_format="compact"): + pivoting = True + if hasattr(self, "pivoting"): + pivoting = self.pivoting + if test_util.is_xla_enabled() and pivoting: + # Pivoting is not supported by xla backends. + return with self.assertRaises(ValueError): - linalg_impl.tridiagonal_solve(diags, rhs, diags_format) + linalg_impl.tridiagonal_solve( + diags, rhs, diags_format, partial_pivoting=pivoting) # Tests with various dtypes @@ -137,6 +155,9 @@ class TridiagonalSolveOpTest(test.TestCase): self._testWithLists(diags=[[0], [3], [0]], rhs=[6], expected=[2]) def test0x0(self): + if test_util.is_xla_enabled(): + # The following test crashes with XLA due to slicing 0 length tensors. + return self._test( diags=constant_op.constant(0, shape=(3, 0), dtype=dtypes.float32), rhs=constant_op.constant(0, shape=(0, 1), dtype=dtypes.float32), @@ -153,10 +174,16 @@ class TridiagonalSolveOpTest(test.TestCase): diags=[[0], [3], [0]], rhs=[[6, 9, 12]], expected=[[2, 3, 4]]) def test1x1NotInvertible(self): + if test_util.is_xla_enabled(): + # XLA implementation does not check invertibility. 
+ return with self.assertRaises(errors_impl.InvalidArgumentError): self._testWithLists(diags=[[0], [0], [0]], rhs=[[6, 9, 12]], expected=[]) def test2x2NotInvertible(self): + if test_util.is_xla_enabled(): + # XLA implementation does not check invertibility. + return with self.assertRaises(errors_impl.InvalidArgumentError): self._testWithLists( diags=[[3, 0], [1, 3], [0, 1]], rhs=[1, 4], expected=[]) @@ -179,7 +206,7 @@ class TridiagonalSolveOpTest(test.TestCase): expected=[5, -2, -5, 3]) def testNotInvertible(self): - if test.is_gpu_available(cuda_only=True): + if test.is_gpu_available(cuda_only=True) or test_util.is_xla_enabled(): # CuSparse gtsv routines don't raise errors for non-invertible # matrices. return @@ -252,8 +279,9 @@ class TridiagonalSolveOpTest(test.TestCase): def testSequenceFormatWithDummyElements(self): dummy = 20 self._test( - diags=(_tfconst([2, 1, 4, dummy]), _tfconst([1, 3, 2, 2]), - _tfconst([dummy, 1, -1, 1])), + diags=(_tfconst([2, 1, 4, + dummy]), _tfconst([1, 3, 2, + 2]), _tfconst([dummy, 1, -1, 1])), rhs=_tfconst([1, 2, 3, 4]), expected=_tfconst([-9, 5, -4, 4]), diags_format="sequence") @@ -261,8 +289,9 @@ class TridiagonalSolveOpTest(test.TestCase): def testSequenceFormatWithBatching(self): self._test( diags=(_tfconst([[2, 1, 4], [-2, -1, -4]]), - _tfconst([[1, 3, 2, 2], [-1, -3, -2, -2]]), - _tfconst([[1, -1, 1], [-1, 1, -1]])), + _tfconst([[1, 3, 2, 2], + [-1, -3, -2, -2]]), _tfconst([[1, -1, 1], [-1, 1, + -1]])), rhs=_tfconst([[1, 2, 3, 4], [1, 2, 3, 4]]), expected=_tfconst([[-9, 5, -4, 4], [9, -5, 4, -4]]), diags_format="sequence") @@ -373,6 +402,9 @@ class TridiagonalSolveOpTest(test.TestCase): with backprop.GradientTape() as tape_rhs: tape_diags.watch(diags) tape_rhs.watch(rhs) + if test_util.is_xla_enabled(): + # Pivoting is not supported by xla backends. + return x = linalg_impl.tridiagonal_solve( diags, rhs, @@ -526,6 +558,9 @@ class TridiagonalSolveOpTest(test.TestCase): return diags = array_ops.placeholder(dtypes.float64, shape=diags_shape) rhs = array_ops.placeholder(dtypes.float64, shape=rhs_shape) + if test_util.is_xla_enabled() and self.pivoting: + # Pivoting is not supported by xla backends. + return x = linalg_impl.tridiagonal_solve( diags, rhs, diags_format, partial_pivoting=self.pivoting) with self.cached_session(use_gpu=True) as sess: @@ -601,6 +636,9 @@ class TridiagonalSolveOpTest(test.TestCase): def testSequenceFormatWithUnknownDims(self): if context.executing_eagerly(): return + if test_util.is_xla_enabled() and self.pivoting: + # Pivoting is not supported by xla backends. 
+      return
     superdiag = array_ops.placeholder(dtypes.float64, shape=[None])
     diag = array_ops.placeholder(dtypes.float64, shape=[None])
     subdiag = array_ops.placeholder(dtypes.float64, shape=[None])
@@ -641,9 +679,9 @@ class TridiagonalSolveOpTest(test.TestCase):
     np.random.seed(seed)
     import scipy.sparse as sparse  # pylint:disable=g-import-not-at-top
     # By being strictly diagonally dominant, we guarantee invertibility.
-    diag = 2* np.abs(np.random.randn(matrix_size)) + 4.1
-    subdiag = 2* np.abs(np.random.randn(matrix_size-1))
-    superdiag = 2* np.abs(np.random.randn(matrix_size-1))
+    diag = 2 * np.abs(np.random.randn(matrix_size)) + 4.1
+    subdiag = 2 * np.abs(np.random.randn(matrix_size - 1))
+    superdiag = 2 * np.abs(np.random.randn(matrix_size - 1))
     matrix = sparse.diags([superdiag, diag, subdiag], [1, 0, -1]).toarray()
     vector = np.random.randn(batch_size, matrix_size, num_rhs)
     return (variables.Variable(np.tile(matrix, (batch_size, 1, 1))),
@@ -665,6 +703,9 @@ class TridiagonalSolveOpTest(test.TestCase):
         session.Session(config=benchmark.benchmark_config()) as sess, \
         ops.device(device_id):
       diags, rhs = generate_data_fn(matrix_size, batch_size, num_rhs)
+      # Pivoting is not supported by XLA backends.
+      if test_util.is_xla_enabled() and pivoting:
+        return
       x = linalg_impl.tridiagonal_solve(
           diags, rhs, partial_pivoting=pivoting)
       variables.global_variables_initializer().run()
@@ -673,9 +714,9 @@ class TridiagonalSolveOpTest(test.TestCase):
           control_flow_ops.group(x),
           min_iters=10,
           store_memory_usage=False,
-          name=test_name_format_string.format(
-              device_name, matrix_size, batch_size, num_rhs,
-              pivoting_name))
+          name=test_name_format_string.format(device_name, matrix_size,
+                                              batch_size, num_rhs,
+                                              pivoting_name))

   def benchmarkTridiagonalSolveOp_WithMatrixInput(self):
     self._benchmark(
@@ -687,9 +728,8 @@ class TridiagonalSolveOpTest(test.TestCase):
   def benchmarkTridiagonalSolveOp(self):
     self._benchmark(
         self._generateMatrixData,
-        test_name_format_string=(
-            "tridiagonal_solve_{}_matrix_size_{}_"
-            "batch_size_{}_num_rhs_{}_{}"))
+        test_name_format_string=("tridiagonal_solve_{}_matrix_size_{}_"
+                                 "batch_size_{}_num_rhs_{}_{}"))


 if __name__ == "__main__":
diff --git a/tensorflow/python/ops/linalg/linalg_impl.py b/tensorflow/python/ops/linalg/linalg_impl.py
index c59314890e1..f7617d83caf 100644
--- a/tensorflow/python/ops/linalg/linalg_impl.py
+++ b/tensorflow/python/ops/linalg/linalg_impl.py
@@ -430,7 +430,9 @@ def tridiagonal_solve(diagonals,

   Raises:
     ValueError: An unsupported type is provided as input, or when the input
-    tensors have incorrect shapes.
+      tensors have incorrect shapes.
+    UnimplementedError: Whenever `partial_pivoting` is true and the backend is
+      XLA.

   [1] Nicholas J. Higham (2002). Accuracy and Stability of Numerical
   Algorithms: Second Edition. SIAM. p. 175. ISBN 978-0-89871-802-7.