[XLA] Make the inner block kernel of CholeskyExpander overridable.

Add a special case for the degenerate n=1 case of Cholesky decomposition.

PiperOrigin-RevId: 329317576
Change-Id: Ia6a2567576286fbf04fff2b050f1870946f907e2
Author: Peter Hawkins, 2020-08-31 09:30:06 -07:00 (committed by TensorFlower Gardener)
parent 1619f2f19f
commit c9b76de37e
2 changed files with 73 additions and 74 deletions
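
With CholeskyUnblocked now a virtual member (see the header diff below), a backend can subclass the expander to supply its own inner block kernel while reusing the blocked driver in BuildCholesky. A minimal sketch, assuming only what this change adds; the subclass name is hypothetical and the override simply delegates to the base implementation:

#include <utility>

#include "tensorflow/compiler/xla/service/cholesky_expander.h"

// Hypothetical subclass: override the virtual kernel that factors each
// k-by-k diagonal tile. ExpandInstruction and BuildCholesky are inherited
// unchanged from CholeskyExpander.
class BackendCholeskyExpander : public xla::CholeskyExpander {
 protected:
  xla::StatusOr<std::pair<xla::XlaOp, xla::XlaOp>> CholeskyUnblocked(
      xla::XlaOp a, xla::PrecisionConfig::Precision precision) override {
    // A real backend would emit a custom (factor, error-flag) pair here;
    // delegating to the base class keeps this sketch behavior-preserving.
    return xla::CholeskyExpander::CholeskyUnblocked(a, precision);
  }
};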

tensorflow/compiler/xla/service/cholesky_expander.cc

@@ -35,8 +35,6 @@ limitations under the License.
 namespace xla {

-namespace {
-
 // The Cholesky–Banachiewicz algorithm. See
 // https://en.wikipedia.org/wiki/Cholesky_decomposition#The_Cholesky–Banachiewicz_and_Cholesky–Crout_algorithms
 // for a description.
@@ -54,10 +52,9 @@ namespace {
 //   l = temp / l[..., j, j) * mask + l
 //   return l
 // Returns a (result, error) pair.
-std::pair<XlaOp, XlaOp> CholeskyUnblocked(
+StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
     XlaOp a, PrecisionConfig::Precision precision) {
   XlaBuilder* builder = a.builder();
-  auto result = [&]() -> StatusOr<std::pair<XlaOp, XlaOp>> {
-    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
+  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
   const int n_dims = a_shape.rank();
   const int64 n = ShapeUtil::GetDimension(a_shape, -1);
@@ -74,18 +71,17 @@ std::pair<XlaOp, XlaOp> CholeskyUnblocked(
   XlaOp l = ZerosLike(a);

   // Construct the for loop body to iterate over rows.
-  auto body_fn =
-      [&](XlaOp i, absl::Span<const XlaOp> loop_vars,
-          XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
+  auto body_fn = [&](XlaOp i, absl::Span<const XlaOp> loop_vars,
+                     XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
     std::vector<int64> row_shape_dims(major_dims.begin(), major_dims.end());
     std::vector<int64> col_shape_dims(major_dims.begin(), major_dims.end());
     auto body_a = loop_vars[0];
     auto body_l = loop_vars[1];
     auto seen_error = loop_vars[2];
-    auto iota_row = Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims),
-                         n_dims - 1);
-    auto iota_col = Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims),
-                         n_dims - 2);
+    auto iota_row =
+        Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), n_dims - 1);
+    auto iota_col =
+        Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), n_dims - 2);

     auto mask_pred = Ge(iota_col, iota_row);
     mask_pred = And(mask_pred, Eq(iota_row, i));
@@ -116,15 +112,9 @@ std::pair<XlaOp, XlaOp> CholeskyUnblocked(
                                        "unblocked", builder));
   return std::make_pair(cholesky_while[1], cholesky_while[2]);
-  }();
-  if (!result.ok()) {
-    XlaOp error = builder->ReportError(result.status());
-    return {error, error};
-  }
-  return result.ValueOrDie();
 }

-XlaOp BuildCholesky(XlaOp a, int64 block_size,
-                    PrecisionConfig::Precision precision) {
+XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
+                                      PrecisionConfig::Precision precision) {
   XlaBuilder* builder = a.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
@@ -154,6 +144,10 @@ XlaOp BuildCholesky(XlaOp a, int64 block_size,
           "block_size argument to Cholesky must be >= 1; got %d", block_size);
     }

+    if (n == 1) {
+      return Sqrt(a);
+    }
+
     // Blocked left-looking Cholesky factorization.
     // Algorithm 1 from
     // Haidar, Azzam, et al. "High-performance Cholesky factorization for
@@ -162,6 +156,7 @@ XlaOp BuildCholesky(XlaOp a, int64 block_size,
     XlaOp seen_error = ConstantR0<bool>(builder, false);
     for (int64 i = 0; i < n; i += block_size) {
       int64 k = std::min(block_size, n - i);
+      auto panel = SliceInMinorDims(a, {i, i}, {n, i + k});
       if (i > 0) {
         // TODO(phawkins): consider implementing SYRK for the diagonal part of
         // the panel.
@@ -169,24 +164,23 @@ XlaOp BuildCholesky(XlaOp a, int64 block_size,
         auto lhs = SliceInMinorDims(l, {i, 0}, {n, i});
         auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i});
         auto delta = BatchDot(lhs, false, rhs, true, precision);
-        auto before = SliceInMinorDims(a, {i, i}, {n, i + k});
-        a = UpdateSliceInMinorDims(a, before - delta, {i, i});
+        panel = panel - delta;
       }

       // l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k])
-      auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k});
+      auto x = SliceInMinorDims(panel, {0, 0}, {k, k});
       XlaOp factorized;
       XlaOp factorized_error;
-      std::tie(factorized, factorized_error) = CholeskyUnblocked(x, precision);
+      TF_ASSIGN_OR_RETURN(auto tile_output, CholeskyUnblocked(x, precision));
+      std::tie(factorized, factorized_error) = tile_output;
       seen_error = Or(seen_error, factorized_error);
       l = UpdateSliceInMinorDims(l, factorized, {i, i});

       if (i + k < n) {
         // l[i+k:, i:i+k] =
         //     trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k])
-        auto panel = SliceInMinorDims(a, {i + k, i}, {n, i + k});
-        auto update =
-            TriangularSolve(factorized, panel,
-                            /*left_side=*/false,
-                            /*lower=*/true,
-                            /*unit_diagonal=*/false,
+        auto update = TriangularSolve(
+            factorized, SliceInMinorDims(panel, {k, 0}, {n - i, k}),
+            /*left_side=*/false,
+            /*lower=*/true,
+            /*unit_diagonal=*/false,
@@ -199,8 +193,6 @@ XlaOp BuildCholesky(XlaOp a, int64 block_size,
   });
 }

-}  // namespace
-
 bool CholeskyExpander::InstructionMatchesPattern(HloInstruction* instruction) {
   return instruction->opcode() == HloOpcode::kCholesky;
 }
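
For the new n == 1 fast path in BuildCholesky above: a 1x1 (possibly batched) input satisfies A = L * L^T with L = sqrt(A), so the expansion degenerates to an elementwise square root and skips both the blocked loop and the unblocked while-loop kernel. Equivalently, as a standalone sketch (the function name is illustrative, not from this change):

#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// For a [..., 1, 1] operand the Cholesky factor is just the elementwise
// square root, matching the new `if (n == 1) return Sqrt(a);` branch.
xla::XlaOp ExpandScalarCholesky(xla::XlaOp a) { return xla::Sqrt(a); }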

tensorflow/compiler/xla/service/cholesky_expander.h

@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XLA_SERVICE_CHOLESKY_EXPANDER_H_

 #include "absl/container/flat_hash_map.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/service/op_expander_pass.h"

 namespace xla {
@@ -31,7 +32,13 @@ class CholeskyExpander : public OpExpanderPass {
   StatusOr<HloInstruction*> ExpandInstruction(
       HloInstruction* instruction) override;

+  virtual StatusOr<std::pair<XlaOp, XlaOp>> CholeskyUnblocked(
+      XlaOp a, PrecisionConfig::Precision precision);
+
  private:
+  XlaOp BuildCholesky(XlaOp a, int64 block_size,
+                      PrecisionConfig::Precision precision);
+
   // Mapping from op signatures to existing computations.
   absl::flat_hash_map<string, HloComputation*> computation_cache_;
 };
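
As a usage note, the expander (or a subclass overriding CholeskyUnblocked) is registered like any other HLO pass; a sketch, with an illustrative helper-function and pipeline setup:

#include "tensorflow/compiler/xla/service/cholesky_expander.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

// Rewrites every kCholesky instruction matched by InstructionMatchesPattern
// into the expanded HLO built by BuildCholesky.
void AddCholeskyExpansion(xla::HloPassPipeline* pipeline) {
  pipeline->AddPass<xla::CholeskyExpander>();
}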