From c9b76de37efbc1abd3521a2d24d1309de16eedc9 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Mon, 31 Aug 2020 09:30:06 -0700
Subject: [PATCH] [XLA] Make the inner block kernel of CholeskyExpander
 override-able.

Add a special case for the degenerate n=1 case of Cholesky decomposition.

PiperOrigin-RevId: 329317576
Change-Id: Ia6a2567576286fbf04fff2b050f1870946f907e2
---
 .../compiler/xla/service/cholesky_expander.cc | 140 +++++++++---------
 .../compiler/xla/service/cholesky_expander.h  |   7 +
 2 files changed, 73 insertions(+), 74 deletions(-)

diff --git a/tensorflow/compiler/xla/service/cholesky_expander.cc b/tensorflow/compiler/xla/service/cholesky_expander.cc
index 20576cdc52d..8d54b02ad52 100644
--- a/tensorflow/compiler/xla/service/cholesky_expander.cc
+++ b/tensorflow/compiler/xla/service/cholesky_expander.cc
@@ -35,8 +35,6 @@ limitations under the License.
 
 namespace xla {
 
-namespace {
-
 // The Cholesky–Banachiewicz algorithm. See
 // https://en.wikipedia.org/wiki/Cholesky_decomposition#The_Cholesky–Banachiewicz_and_Cholesky–Crout_algorithms
 // for a description.
@@ -54,78 +52,70 @@ namespace {
 //   l = temp / l[..., j, j) * mask + l
 //   return l
 // Returns a (result, error) pair.
-std::pair<XlaOp, XlaOp> CholeskyUnblocked(
+StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
     XlaOp a, PrecisionConfig::Precision precision) {
   XlaBuilder* builder = a.builder();
-  auto result = [&]() -> StatusOr<std::pair<XlaOp, XlaOp>> {
-    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
-    const int n_dims = a_shape.rank();
-    const int64 n = ShapeUtil::GetDimension(a_shape, -1);
-    auto major_dims = AsInt64Slice(a_shape.dimensions())
-                          .subspan(
-                              /*pos=*/0,
-                              /*len=*/n_dims - 2);
+  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
+  const int n_dims = a_shape.rank();
+  const int64 n = ShapeUtil::GetDimension(a_shape, -1);
+  auto major_dims = AsInt64Slice(a_shape.dimensions())
+                        .subspan(
+                            /*pos=*/0,
+                            /*len=*/n_dims - 2);
 
-    auto matrix_dims = AsInt64Slice(a_shape.dimensions())
-                           .subspan(
-                               /*pos=*/0,
-                               /*len=*/n_dims);
+  auto matrix_dims = AsInt64Slice(a_shape.dimensions())
+                         .subspan(
+                             /*pos=*/0,
+                             /*len=*/n_dims);
 
-    XlaOp l = ZerosLike(a);
+  XlaOp l = ZerosLike(a);
 
-    // Construct the for loop body to iterate over rows.
-    auto body_fn =
-        [&](XlaOp i, absl::Span<const XlaOp> loop_vars,
-            XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
-      std::vector<int64> row_shape_dims(major_dims.begin(), major_dims.end());
-      std::vector<int64> col_shape_dims(major_dims.begin(), major_dims.end());
-      auto body_a = loop_vars[0];
-      auto body_l = loop_vars[1];
-      auto seen_error = loop_vars[2];
-      auto iota_row = Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims),
-                           n_dims - 1);
-      auto iota_col = Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims),
-                           n_dims - 2);
+  // Construct the for loop body to iterate over rows.
+  auto body_fn = [&](XlaOp i, absl::Span<const XlaOp> loop_vars,
+                     XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
+    std::vector<int64> row_shape_dims(major_dims.begin(), major_dims.end());
+    std::vector<int64> col_shape_dims(major_dims.begin(), major_dims.end());
+    auto body_a = loop_vars[0];
+    auto body_l = loop_vars[1];
+    auto seen_error = loop_vars[2];
+    auto iota_row =
+        Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), n_dims - 1);
+    auto iota_col =
+        Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), n_dims - 2);
 
-      auto mask_pred = Ge(iota_col, iota_row);
-      mask_pred = And(mask_pred, Eq(iota_row, i));
-      auto mask_zeros =
-          Zeros(body_builder,
-                ShapeUtil::MakeShape(a_shape.element_type(), matrix_dims));
-      // L * L.T, This matrix has of a lot of multiplying with zero
-      // (namely, L[:, j:] = 0) and redundant computation, but it is faster
-      // than slice.
-      auto l_square = BatchDot(body_l, false, body_l, true, precision);
+    auto mask_pred = Ge(iota_col, iota_row);
+    mask_pred = And(mask_pred, Eq(iota_row, i));
+    auto mask_zeros =
+        Zeros(body_builder,
+              ShapeUtil::MakeShape(a_shape.element_type(), matrix_dims));
+    // L * L.T, This matrix has of a lot of multiplying with zero
+    // (namely, L[:, j:] = 0) and redundant computation, but it is faster
+    // than slice.
+    auto l_square = BatchDot(body_l, false, body_l, true, precision);
 
-      // A - L*L.T
-      l_square = body_a - l_square;
-      auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1});
-      l_ii = Sqrt(l_ii);
-      // L = (A - L*L.T) / l_ii * mask + L
-      body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l;
+    // A - L*L.T
+    l_square = body_a - l_square;
+    auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1});
+    l_ii = Sqrt(l_ii);
+    // L = (A - L*L.T) / l_ii * mask + L
+    body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l;
 
-      seen_error =
-          Or(seen_error, Any(Or(Le(l_ii, ZerosLike(l_ii)), IsNan(l_ii))));
+    seen_error =
+        Or(seen_error, Any(Or(Le(l_ii, ZerosLike(l_ii)), IsNan(l_ii))));
 
-      return std::vector<XlaOp>{body_a, body_l, seen_error};
-    };
+    return std::vector<XlaOp>{body_a, body_l, seen_error};
+  };
 
-    TF_ASSIGN_OR_RETURN(
-        auto cholesky_while,
-        ForEachIndex(n, S32, body_fn, {a, l, ConstantR0<bool>(builder, false)},
-                     "unblocked", builder));
+  TF_ASSIGN_OR_RETURN(
+      auto cholesky_while,
+      ForEachIndex(n, S32, body_fn, {a, l, ConstantR0<bool>(builder, false)},
+                   "unblocked", builder));
 
-    return std::make_pair(cholesky_while[1], cholesky_while[2]);
-  }();
-  if (!result.ok()) {
-    XlaOp error = builder->ReportError(result.status());
-    return {error, error};
-  }
-  return result.ValueOrDie();
+  return std::make_pair(cholesky_while[1], cholesky_while[2]);
 }
 
-XlaOp BuildCholesky(XlaOp a, int64 block_size,
-                    PrecisionConfig::Precision precision) {
+XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
+                                      PrecisionConfig::Precision precision) {
   XlaBuilder* builder = a.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
@@ -154,6 +144,10 @@ XlaOp BuildCholesky(XlaOp a, int64 block_size,
           "block_size argument to Cholesky must be >= 1; got %d", block_size);
     }
 
+    if (n == 1) {
+      return Sqrt(a);
+    }
+
     // Blocked left-looking Cholesky factorization.
     // Algorithm 1 from
     // Haidar, Azzam, et al. "High-performance Cholesky factorization for
"High-performance Cholesky factorization for @@ -162,6 +156,7 @@ XlaOp BuildCholesky(XlaOp a, int64 block_size, XlaOp seen_error = ConstantR0(builder, false); for (int64 i = 0; i < n; i += block_size) { int64 k = std::min(block_size, n - i); + auto panel = SliceInMinorDims(a, {i, i}, {n, i + k}); if (i > 0) { // TODO(phawkins): consider implementing SYRK for the diagonal part of // the panel. @@ -169,28 +164,27 @@ XlaOp BuildCholesky(XlaOp a, int64 block_size, auto lhs = SliceInMinorDims(l, {i, 0}, {n, i}); auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i}); auto delta = BatchDot(lhs, false, rhs, true, precision); - auto before = SliceInMinorDims(a, {i, i}, {n, i + k}); - a = UpdateSliceInMinorDims(a, before - delta, {i, i}); + panel = panel - delta; } // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k]) - auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k}); + auto x = SliceInMinorDims(panel, {0, 0}, {k, k}); XlaOp factorized; XlaOp factorized_error; - std::tie(factorized, factorized_error) = CholeskyUnblocked(x, precision); + TF_ASSIGN_OR_RETURN(auto tile_output, CholeskyUnblocked(x, precision)); + std::tie(factorized, factorized_error) = tile_output; seen_error = Or(seen_error, factorized_error); l = UpdateSliceInMinorDims(l, factorized, {i, i}); if (i + k < n) { // l[i+k:, i:i+k] = // trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k]) - auto panel = SliceInMinorDims(a, {i + k, i}, {n, i + k}); - auto update = - TriangularSolve(factorized, panel, - /*left_side=*/false, - /*lower=*/true, - /*unit_diagonal=*/false, - /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); + auto update = TriangularSolve( + factorized, SliceInMinorDims(panel, {k, 0}, {n - i, k}), + /*left_side=*/false, + /*lower=*/true, + /*unit_diagonal=*/false, + /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); l = UpdateSliceInMinorDims(l, update, {i + k, i}); } } @@ -199,8 +193,6 @@ XlaOp BuildCholesky(XlaOp a, int64 block_size, }); } -} // namespace - bool CholeskyExpander::InstructionMatchesPattern(HloInstruction* instruction) { return instruction->opcode() == HloOpcode::kCholesky; } diff --git a/tensorflow/compiler/xla/service/cholesky_expander.h b/tensorflow/compiler/xla/service/cholesky_expander.h index d2958db1b8c..ee8531d0f48 100644 --- a/tensorflow/compiler/xla/service/cholesky_expander.h +++ b/tensorflow/compiler/xla/service/cholesky_expander.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_CHOLESKY_EXPANDER_H_ #include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/op_expander_pass.h" namespace xla { @@ -31,7 +32,13 @@ class CholeskyExpander : public OpExpanderPass { StatusOr ExpandInstruction( HloInstruction* instruction) override; + virtual StatusOr> CholeskyUnblocked( + XlaOp a, PrecisionConfig::Precision precision); + private: + XlaOp BuildCholesky(XlaOp a, int64 block_size, + PrecisionConfig::Precision precision); + // Mapping from op signatures to existing computations. absl::flat_hash_map computation_cache_; };