[XLA] Make the inner block kernel of CholeskyExpander override-able.
Add a special case for the degenerate n=1 case of Cholesky decomposition.

PiperOrigin-RevId: 329317576
Change-Id: Ia6a2567576286fbf04fff2b050f1870946f907e2
This commit is contained in:
parent
1619f2f19f
commit
c9b76de37e
@ -35,8 +35,6 @@ limitations under the License.
|
||||
|
||||
namespace xla {
|
||||
|
||||
namespace {
|
||||
|
||||
// The Cholesky–Banachiewicz algorithm. See
|
||||
// https://en.wikipedia.org/wiki/Cholesky_decomposition#The_Cholesky–Banachiewicz_and_Cholesky–Crout_algorithms
|
||||
// for a description.
|
||||
@ -54,78 +52,70 @@ namespace {
|
||||
// l = temp / l[..., j, j] * mask + l
|
||||
// return l
|
||||
// Returns a (result, error) pair.
|
||||
std::pair<XlaOp, XlaOp> CholeskyUnblocked(
|
||||
StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
|
||||
XlaOp a, PrecisionConfig::Precision precision) {
|
||||
XlaBuilder* builder = a.builder();
|
||||
auto result = [&]() -> StatusOr<std::pair<XlaOp, XlaOp>> {
|
||||
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
|
||||
const int n_dims = a_shape.rank();
|
||||
const int64 n = ShapeUtil::GetDimension(a_shape, -1);
|
||||
auto major_dims = AsInt64Slice(a_shape.dimensions())
|
||||
.subspan(
|
||||
/*pos=*/0,
|
||||
/*len=*/n_dims - 2);
|
||||
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
|
||||
const int n_dims = a_shape.rank();
|
||||
const int64 n = ShapeUtil::GetDimension(a_shape, -1);
|
||||
auto major_dims = AsInt64Slice(a_shape.dimensions())
|
||||
.subspan(
|
||||
/*pos=*/0,
|
||||
/*len=*/n_dims - 2);
|
||||
|
||||
auto matrix_dims = AsInt64Slice(a_shape.dimensions())
|
||||
.subspan(
|
||||
/*pos=*/0,
|
||||
/*len=*/n_dims);
|
||||
auto matrix_dims = AsInt64Slice(a_shape.dimensions())
|
||||
.subspan(
|
||||
/*pos=*/0,
|
||||
/*len=*/n_dims);
|
||||
|
||||
XlaOp l = ZerosLike(a);
|
||||
XlaOp l = ZerosLike(a);
|
||||
|
||||
// Construct the for loop body to iterate over rows.
|
||||
auto body_fn =
|
||||
[&](XlaOp i, absl::Span<const XlaOp> loop_vars,
|
||||
XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
|
||||
std::vector<int64> row_shape_dims(major_dims.begin(), major_dims.end());
|
||||
std::vector<int64> col_shape_dims(major_dims.begin(), major_dims.end());
|
||||
auto body_a = loop_vars[0];
|
||||
auto body_l = loop_vars[1];
|
||||
auto seen_error = loop_vars[2];
|
||||
auto iota_row = Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims),
|
||||
n_dims - 1);
|
||||
auto iota_col = Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims),
|
||||
n_dims - 2);
|
||||
// Construct the for loop body to iterate over rows.
|
||||
auto body_fn = [&](XlaOp i, absl::Span<const XlaOp> loop_vars,
|
||||
XlaBuilder* body_builder) -> StatusOr<std::vector<XlaOp>> {
|
||||
std::vector<int64> row_shape_dims(major_dims.begin(), major_dims.end());
|
||||
std::vector<int64> col_shape_dims(major_dims.begin(), major_dims.end());
|
||||
auto body_a = loop_vars[0];
|
||||
auto body_l = loop_vars[1];
|
||||
auto seen_error = loop_vars[2];
|
||||
auto iota_row =
|
||||
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), n_dims - 1);
|
||||
auto iota_col =
|
||||
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), n_dims - 2);
|
||||
|
||||
auto mask_pred = Ge(iota_col, iota_row);
|
||||
mask_pred = And(mask_pred, Eq(iota_row, i));
|
||||
auto mask_zeros =
|
||||
Zeros(body_builder,
|
||||
ShapeUtil::MakeShape(a_shape.element_type(), matrix_dims));
|
||||
// L * L.T. This matrix has a lot of multiplications by zero
|
||||
// (namely, L[:, j:] = 0) and redundant computation, but it is faster
|
||||
// than slice.
|
||||
auto l_square = BatchDot(body_l, false, body_l, true, precision);
|
||||
auto mask_pred = Ge(iota_col, iota_row);
|
||||
mask_pred = And(mask_pred, Eq(iota_row, i));
|
||||
auto mask_zeros =
|
||||
Zeros(body_builder,
|
||||
ShapeUtil::MakeShape(a_shape.element_type(), matrix_dims));
|
||||
// L * L.T. This matrix has a lot of multiplications by zero
|
||||
// (namely, L[:, j:] = 0) and redundant computation, but it is faster
|
||||
// than slice.
|
||||
auto l_square = BatchDot(body_l, false, body_l, true, precision);
|
||||
|
||||
// A - L*L.T
|
||||
l_square = body_a - l_square;
|
||||
auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1});
|
||||
l_ii = Sqrt(l_ii);
|
||||
// L = (A - L*L.T) / l_ii * mask + L
|
||||
body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l;
|
||||
// A - L*L.T
|
||||
l_square = body_a - l_square;
|
||||
auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1});
|
||||
l_ii = Sqrt(l_ii);
|
||||
// L = (A - L*L.T) / l_ii * mask + L
|
||||
body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l;
|
||||
|
||||
seen_error =
|
||||
Or(seen_error, Any(Or(Le(l_ii, ZerosLike(l_ii)), IsNan(l_ii))));
|
||||
seen_error =
|
||||
Or(seen_error, Any(Or(Le(l_ii, ZerosLike(l_ii)), IsNan(l_ii))));
|
||||
|
||||
return std::vector<XlaOp>{body_a, body_l, seen_error};
|
||||
};
|
||||
return std::vector<XlaOp>{body_a, body_l, seen_error};
|
||||
};
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto cholesky_while,
|
||||
ForEachIndex(n, S32, body_fn, {a, l, ConstantR0<bool>(builder, false)},
|
||||
"unblocked", builder));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto cholesky_while,
|
||||
ForEachIndex(n, S32, body_fn, {a, l, ConstantR0<bool>(builder, false)},
|
||||
"unblocked", builder));
|
||||
|
||||
return std::make_pair(cholesky_while[1], cholesky_while[2]);
|
||||
}();
|
||||
if (!result.ok()) {
|
||||
XlaOp error = builder->ReportError(result.status());
|
||||
return {error, error};
|
||||
}
|
||||
return result.ValueOrDie();
|
||||
return std::make_pair(cholesky_while[1], cholesky_while[2]);
|
||||
}
|
||||
|
||||
XlaOp BuildCholesky(XlaOp a, int64 block_size,
|
||||
PrecisionConfig::Precision precision) {
|
||||
XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
|
||||
PrecisionConfig::Precision precision) {
|
||||
XlaBuilder* builder = a.builder();
|
||||
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
|
||||
@ -154,6 +144,10 @@ XlaOp BuildCholesky(XlaOp a, int64 block_size,
|
||||
"block_size argument to Cholesky must be >= 1; got %d", block_size);
|
||||
}
|
||||
|
||||
if (n == 1) {
|
||||
return Sqrt(a);
|
||||
}
|
||||
|
||||
// Blocked left-looking Cholesky factorization.
|
||||
// Algorithm 1 from
|
||||
// Haidar, Azzam, et al. "High-performance Cholesky factorization for
|
||||
@ -162,6 +156,7 @@ XlaOp BuildCholesky(XlaOp a, int64 block_size,
|
||||
XlaOp seen_error = ConstantR0<bool>(builder, false);
|
||||
for (int64 i = 0; i < n; i += block_size) {
|
||||
int64 k = std::min(block_size, n - i);
|
||||
auto panel = SliceInMinorDims(a, {i, i}, {n, i + k});
|
||||
if (i > 0) {
|
||||
// TODO(phawkins): consider implementing SYRK for the diagonal part of
|
||||
// the panel.
|
||||
@ -169,28 +164,27 @@ XlaOp BuildCholesky(XlaOp a, int64 block_size,
|
||||
auto lhs = SliceInMinorDims(l, {i, 0}, {n, i});
|
||||
auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i});
|
||||
auto delta = BatchDot(lhs, false, rhs, true, precision);
|
||||
auto before = SliceInMinorDims(a, {i, i}, {n, i + k});
|
||||
a = UpdateSliceInMinorDims(a, before - delta, {i, i});
|
||||
panel = panel - delta;
|
||||
}
|
||||
|
||||
// l[i:i+k, i:i+k] = cholesky_unblocked(a[i:i+k, i:i+k])
|
||||
auto x = SliceInMinorDims(a, {i, i}, {i + k, i + k});
|
||||
auto x = SliceInMinorDims(panel, {0, 0}, {k, k});
|
||||
XlaOp factorized;
|
||||
XlaOp factorized_error;
|
||||
std::tie(factorized, factorized_error) = CholeskyUnblocked(x, precision);
|
||||
TF_ASSIGN_OR_RETURN(auto tile_output, CholeskyUnblocked(x, precision));
|
||||
std::tie(factorized, factorized_error) = tile_output;
|
||||
seen_error = Or(seen_error, factorized_error);
|
||||
l = UpdateSliceInMinorDims(l, factorized, {i, i});
|
||||
|
||||
if (i + k < n) {
|
||||
// l[i+k:, i:i+k] =
|
||||
// trsm_right_transpose(l[i:i+k, i:i+k], a[i+k:, i:i+k])
|
||||
auto panel = SliceInMinorDims(a, {i + k, i}, {n, i + k});
|
||||
auto update =
|
||||
TriangularSolve(factorized, panel,
|
||||
/*left_side=*/false,
|
||||
/*lower=*/true,
|
||||
/*unit_diagonal=*/false,
|
||||
/*transpose_a=*/TriangularSolveOptions::TRANSPOSE);
|
||||
auto update = TriangularSolve(
|
||||
factorized, SliceInMinorDims(panel, {k, 0}, {n - i, k}),
|
||||
/*left_side=*/false,
|
||||
/*lower=*/true,
|
||||
/*unit_diagonal=*/false,
|
||||
/*transpose_a=*/TriangularSolveOptions::TRANSPOSE);
|
||||
l = UpdateSliceInMinorDims(l, update, {i + k, i});
|
||||
}
|
||||
}
|
||||
@ -199,8 +193,6 @@ XlaOp BuildCholesky(XlaOp a, int64 block_size,
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool CholeskyExpander::InstructionMatchesPattern(HloInstruction* instruction) {
|
||||
return instruction->opcode() == HloOpcode::kCholesky;
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_CHOLESKY_EXPANDER_H_
|
||||
|
||||
#include "absl/container/flat_hash_map.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/service/op_expander_pass.h"
|
||||
|
||||
namespace xla {
|
||||
@ -31,7 +32,13 @@ class CholeskyExpander : public OpExpanderPass {
|
||||
StatusOr<HloInstruction*> ExpandInstruction(
|
||||
HloInstruction* instruction) override;
|
||||
|
||||
virtual StatusOr<std::pair<XlaOp, XlaOp>> CholeskyUnblocked(
|
||||
XlaOp a, PrecisionConfig::Precision precision);
|
||||
|
||||
private:
|
||||
XlaOp BuildCholesky(XlaOp a, int64 block_size,
|
||||
PrecisionConfig::Precision precision);
|
||||
|
||||
// Mapping from op signatures to existing computations.
|
||||
absl::flat_hash_map<string, HloComputation*> computation_cache_;
|
||||
};
|
||||
|
Loading…
x
Reference in New Issue
Block a user