[XLA] Add complex number support to HLO cholesky decomposition implementation.

Fix bug where errors in one batch element would cause other batch elements to fail.

PiperOrigin-RevId: 332443397
Change-Id: I868accebbad9df2fa759525f6f0b0b3df6a481c1
This commit is contained in:
Peter Hawkins 2020-09-18 07:01:55 -07:00 committed by TensorFlower Gardener
parent ef1567488a
commit 85404d1a85
3 changed files with 136 additions and 27 deletions

View File

@ -56,17 +56,21 @@ StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
XlaOp a, PrecisionConfig::Precision precision) {
XlaBuilder* builder = a.builder();
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
const int n_dims = a_shape.rank();
const int ndims = a_shape.rank();
const int64 n = ShapeUtil::GetDimension(a_shape, -1);
std::vector<int64> error_dims(a_shape.dimensions().begin(),
a_shape.dimensions().end());
error_dims.back() = error_dims.at(ndims - 2) = 1;
auto major_dims = AsInt64Slice(a_shape.dimensions())
.subspan(
/*pos=*/0,
/*len=*/n_dims - 2);
/*len=*/ndims - 2);
auto matrix_dims = AsInt64Slice(a_shape.dimensions())
.subspan(
/*pos=*/0,
/*len=*/n_dims);
/*len=*/ndims);
XlaOp l = ZerosLike(a);
@ -79,9 +83,9 @@ StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
auto body_l = loop_vars[1];
auto seen_error = loop_vars[2];
auto iota_row =
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), n_dims - 1);
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 1);
auto iota_col =
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), n_dims - 2);
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 2);
auto mask_pred = Ge(iota_col, iota_row);
mask_pred = And(mask_pred, Eq(iota_row, i));
@ -91,25 +95,32 @@ StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
// L * L.T. This matrix involves a lot of multiplication by zero
// (namely, L[:, j:] = 0) and redundant computation, but it is faster
// than slicing.
auto l_square = BatchDot(body_l, false, body_l, true, precision);
auto l_square =
BatchDot(body_l, false, MaybeConjugate(body_l, true), true, precision);
// A - L*L.T
l_square = body_a - l_square;
auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1});
l_ii = Sqrt(l_ii);
if (ShapeUtil::ElementIsComplex(a_shape)) {
auto sqrt = Sqrt(Real(l_ii));
l_ii = Complex(sqrt, ZerosLike(sqrt));
seen_error = Or(seen_error, IsNan(sqrt));
} else {
l_ii = Sqrt(l_ii);
seen_error = Or(seen_error, IsNan(l_ii));
}
// L = (A - L*L.T) / l_ii * mask + L
body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l;
seen_error =
Or(seen_error, Any(Or(Le(l_ii, ZerosLike(l_ii)), IsNan(l_ii))));
return std::vector<XlaOp>{body_a, body_l, seen_error};
};
TF_ASSIGN_OR_RETURN(
auto cholesky_while,
ForEachIndex(n, S32, body_fn, {a, l, ConstantR0<bool>(builder, false)},
"unblocked", builder));
ForEachIndex(
n, S32, body_fn,
{a, l, Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims))},
"unblocked", builder));
return std::make_pair(cholesky_while[1], cholesky_while[2]);
}
@ -133,23 +144,23 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
ShapeUtil::HumanString(a_shape));
}
if (primitive_util::IsComplexType(a_shape.element_type())) {
return Unimplemented(
"Complex types are not implemented in Cholesky; got shape %s",
ShapeUtil::HumanString(a_shape));
}
if (block_size < 1) {
return InvalidArgument(
"block_size argument to Cholesky must be >= 1; got %d", block_size);
}
std::vector<int64> error_dims(a_shape.dimensions().begin(),
a_shape.dimensions().end());
error_dims.back() = error_dims.at(ndims - 2) = 1;
std::vector<int64> error_dim_indices(ndims);
absl::c_iota(error_dim_indices, 0);
// Blocked left-looking Cholesky factorization.
// Algorithm 1 from
// Haidar, Azzam, et al. "High-performance Cholesky factorization for
// GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
XlaOp l = ZerosLike(a);
XlaOp seen_error = ConstantR0<bool>(builder, false);
XlaOp seen_error = Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims));
for (int64 i = 0; i < n; i += block_size) {
int64 k = std::min(block_size, n - i);
auto panel = SliceInMinorDims(a, {i, i}, {n, i + k});
@ -159,7 +170,8 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
// a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i]))
auto lhs = SliceInMinorDims(l, {i, 0}, {n, i});
auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i});
auto delta = BatchDot(lhs, false, rhs, true, precision);
auto delta =
BatchDot(lhs, false, MaybeConjugate(rhs, true), true, precision);
panel = panel - delta;
}
@ -170,8 +182,14 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
// other elements.
XlaOp factorized_error;
if (k == 1) {
factorized = Sqrt(x);
factorized_error = Any(IsNan(factorized));
if (ShapeUtil::ElementIsComplex(a_shape)) {
auto sqrt = Sqrt(Real(x));
factorized = Complex(sqrt, ZerosLike(sqrt));
factorized_error = IsNan(sqrt);
} else {
factorized = Sqrt(x);
factorized_error = IsNan(factorized);
}
} else {
TF_ASSIGN_OR_RETURN(auto tile_output, CholeskyUnblocked(x, precision));
std::tie(factorized, factorized_error) = tile_output;
@ -187,12 +205,13 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
/*left_side=*/false,
/*lower=*/true,
/*unit_diagonal=*/false,
/*transpose_a=*/TriangularSolveOptions::TRANSPOSE);
/*transpose_a=*/TriangularSolveOptions::ADJOINT);
l = UpdateSliceInMinorDims(l, update, {i + k, i});
}
}
return Select(seen_error,
FullLike(l, std::numeric_limits<float>::quiet_NaN()), l);
return Select(
BroadcastInDim(seen_error, a_shape.dimensions(), error_dim_indices),
FullLike(l, std::numeric_limits<float>::quiet_NaN()), l);
});
}

View File

@ -2676,6 +2676,7 @@ xla_test(
xla_test(
name = "cholesky_test",
srcs = ["cholesky_test.cc"],
real_hardware_only = True,
tags = [
"no_rocm",
"optonly",

View File

@ -61,6 +61,44 @@ XLA_TEST_F(CholeskyTest, NonPSDInput) {
ErrorSpec(1e-4, 1e-4));
}
// Runs a lower Cholesky on a batch of two 3x3 matrices. The first batch
// element is expected to factorize to finite values; the second (all ones,
// not positive definite) is expected to come back entirely as NaN without
// affecting the first.
XLA_TEST_F(CholeskyTest, NonPSDBatched) {
  const float kNan = std::numeric_limits<float>::quiet_NaN();
  XlaBuilder builder(TestName());

  Array3D<float> input_values({
      {
          {10, 0, 0},
          {1, 20, 0},
          {1, 1, 30},
      },
      {
          {1, 1, 1},
          {1, 1, 1},
          {1, 1, 1},
      },
  });

  XlaOp operand;
  auto operand_data =
      CreateR3Parameter<float>(input_values, 0, "a", &builder, &operand);
  Cholesky(operand, /*lower=*/true);

  // Per-batch expectation: a finite lower-triangular factor for element 0,
  // and a matrix filled with NaN for element 1.
  Array3D<float> expected_values({
      {
          {3.16227766, 0., 0.},
          {0.31622777, 4.4609416, 0.},
          {0.31622777, 0.20175113, 5.46436606},
      },
      {
          {kNan, kNan, kNan},
          {kNan, kNan, kNan},
          {kNan, kNan, kNan},
      },
  });

  ComputeAndCompareR3<float>(&builder, expected_values, {operand_data.get()},
                             ErrorSpec(1e-4, 1e-4));
}
XLA_TEST_F(CholeskyTest, Lower) {
XlaBuilder builder(TestName());
@ -181,7 +219,7 @@ class RandomCholeskyTest
: public ClientLibraryTestBase,
public ::testing::WithParamInterface<CholeskyTestCase> {};
XLA_TEST_P(RandomCholeskyTest, Random) {
XLA_TEST_P(RandomCholeskyTest, Real) {
// Test fails with TensorFloat-32 enabled
tensorflow::enable_tensor_float_32_execution(false);
XlaBuilder builder(TestName());
@ -220,14 +258,65 @@ XLA_TEST_P(RandomCholeskyTest, Random) {
ErrorSpec(1e-4, 1e-4));
}
// Builds a batch of complex matrices as input * conj(input)^T from random
// real and imaginary parts, factorizes them with Cholesky, and checks that the
// summed squared magnitude of (matrix - reconstruction) is approximately zero.
XLA_TEST_P(RandomCholeskyTest, Complex) {
  // Test fails with TensorFloat-32 enabled
  tensorflow::enable_tensor_float_32_execution(false);

  XlaBuilder builder(TestName());

  // The test parameter supplies the leading (batch) dimension, the size of
  // the two minor (square matrix) dimensions, and the `lower` flag.
  const auto test_case = GetParam();
  const auto batch_size = std::get<0>(test_case);
  const auto matrix_size = std::get<1>(test_case);
  const bool lower = std::get<2>(test_case);
  std::vector<int64> dims = {batch_size, matrix_size, matrix_size};
  Shape shape = ShapeUtil::MakeShape(F32, dims);

  TF_ASSERT_OK_AND_ASSIGN(
      auto real_literal,
      LiteralUtil::CreateRandomLiteral<F32>(shape, 0.0, 1.0));
  TF_ASSERT_OK_AND_ASSIGN(
      auto imag_literal,
      LiteralUtil::CreateRandomLiteral<F32>(shape, 0.0, 1.0));

  XlaOp real_part = Parameter(&builder, 0, shape, "input_real");
  XlaOp imag_part = Parameter(&builder, 1, shape, "input_imag");
  XlaOp input = Complex(real_part, imag_part);

  // Form a random positive definite matrix.
  XlaOp input_adjoint = TransposeInMinorDims(Conj(input));
  XlaOp matrix = BatchDot(input, input_adjoint, PrecisionConfig::HIGHEST);

  XlaOp factor = Triangle(Cholesky(matrix, lower), lower);

  // Verify that ||matrix - cholesky * cholesky_t||_2 ~= 0
  // The conjugate transpose goes on the right for a lower factor and on the
  // left for an upper factor.
  XlaOp factor_adjoint = TransposeInMinorDims(Conj(factor));
  XlaOp reconstructed =
      lower ? BatchDot(factor, factor_adjoint, PrecisionConfig::HIGHEST)
            : BatchDot(factor_adjoint, factor, PrecisionConfig::HIGHEST);

  XlaOp residual = matrix - reconstructed;
  Reduce(Abs(residual * Conj(residual)), ConstantR0<float>(&builder, 0.0),
         CreateScalarAddComputation(F32, &builder), {0, 1, 2});

  TF_ASSERT_OK_AND_ASSIGN(auto real_data,
                          client_->TransferToServer(real_literal));
  TF_ASSERT_OK_AND_ASSIGN(auto imag_data,
                          client_->TransferToServer(imag_literal));
  ComputeAndCompareR0<float>(&builder, 0.0,
                             {real_data.get(), imag_data.get()},
                             ErrorSpec(1e-4, 1e-4));
}
INSTANTIATE_TEST_SUITE_P(RandomCholeskyTestInstance, RandomCholeskyTest,
::testing::Values(CholeskyTestCase{1, 1, true},
CholeskyTestCase{1, 2, true},
CholeskyTestCase{1, 50, true},
CholeskyTestCase{1, 50, false},
CholeskyTestCase{1, 255, false},
CholeskyTestCase{10, 5, true},
CholeskyTestCase{5, 10, false},
CholeskyTestCase{2, 20, true}));
CholeskyTestCase{2, 20, true},
CholeskyTestCase{2, 129, true}));
} // namespace
} // namespace xla