[XLA] Add complex number support to HLO cholesky decomposition implementation.
Fix bug where errors in one batch element would cause other batch elements to fail. PiperOrigin-RevId: 332443397 Change-Id: I868accebbad9df2fa759525f6f0b0b3df6a481c1
This commit is contained in:
parent
ef1567488a
commit
85404d1a85
@ -56,17 +56,21 @@ StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
|
||||
XlaOp a, PrecisionConfig::Precision precision) {
|
||||
XlaBuilder* builder = a.builder();
|
||||
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
|
||||
const int n_dims = a_shape.rank();
|
||||
const int ndims = a_shape.rank();
|
||||
const int64 n = ShapeUtil::GetDimension(a_shape, -1);
|
||||
std::vector<int64> error_dims(a_shape.dimensions().begin(),
|
||||
a_shape.dimensions().end());
|
||||
error_dims.back() = error_dims.at(ndims - 2) = 1;
|
||||
|
||||
auto major_dims = AsInt64Slice(a_shape.dimensions())
|
||||
.subspan(
|
||||
/*pos=*/0,
|
||||
/*len=*/n_dims - 2);
|
||||
/*len=*/ndims - 2);
|
||||
|
||||
auto matrix_dims = AsInt64Slice(a_shape.dimensions())
|
||||
.subspan(
|
||||
/*pos=*/0,
|
||||
/*len=*/n_dims);
|
||||
/*len=*/ndims);
|
||||
|
||||
XlaOp l = ZerosLike(a);
|
||||
|
||||
@ -79,9 +83,9 @@ StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
|
||||
auto body_l = loop_vars[1];
|
||||
auto seen_error = loop_vars[2];
|
||||
auto iota_row =
|
||||
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), n_dims - 1);
|
||||
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 1);
|
||||
auto iota_col =
|
||||
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), n_dims - 2);
|
||||
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 2);
|
||||
|
||||
auto mask_pred = Ge(iota_col, iota_row);
|
||||
mask_pred = And(mask_pred, Eq(iota_row, i));
|
||||
@ -91,25 +95,32 @@ StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
|
||||
// L * L.T, This matrix has of a lot of multiplying with zero
|
||||
// (namely, L[:, j:] = 0) and redundant computation, but it is faster
|
||||
// than slice.
|
||||
auto l_square = BatchDot(body_l, false, body_l, true, precision);
|
||||
auto l_square =
|
||||
BatchDot(body_l, false, MaybeConjugate(body_l, true), true, precision);
|
||||
|
||||
// A - L*L.T
|
||||
l_square = body_a - l_square;
|
||||
auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1});
|
||||
l_ii = Sqrt(l_ii);
|
||||
if (ShapeUtil::ElementIsComplex(a_shape)) {
|
||||
auto sqrt = Sqrt(Real(l_ii));
|
||||
l_ii = Complex(sqrt, ZerosLike(sqrt));
|
||||
seen_error = Or(seen_error, IsNan(sqrt));
|
||||
} else {
|
||||
l_ii = Sqrt(l_ii);
|
||||
seen_error = Or(seen_error, IsNan(l_ii));
|
||||
}
|
||||
// L = (A - L*L.T) / l_ii * mask + L
|
||||
body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l;
|
||||
|
||||
seen_error =
|
||||
Or(seen_error, Any(Or(Le(l_ii, ZerosLike(l_ii)), IsNan(l_ii))));
|
||||
|
||||
return std::vector<XlaOp>{body_a, body_l, seen_error};
|
||||
};
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto cholesky_while,
|
||||
ForEachIndex(n, S32, body_fn, {a, l, ConstantR0<bool>(builder, false)},
|
||||
"unblocked", builder));
|
||||
ForEachIndex(
|
||||
n, S32, body_fn,
|
||||
{a, l, Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims))},
|
||||
"unblocked", builder));
|
||||
|
||||
return std::make_pair(cholesky_while[1], cholesky_while[2]);
|
||||
}
|
||||
@ -133,23 +144,23 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
|
||||
ShapeUtil::HumanString(a_shape));
|
||||
}
|
||||
|
||||
if (primitive_util::IsComplexType(a_shape.element_type())) {
|
||||
return Unimplemented(
|
||||
"Complex types are not implemented in Cholesky; got shape %s",
|
||||
ShapeUtil::HumanString(a_shape));
|
||||
}
|
||||
|
||||
if (block_size < 1) {
|
||||
return InvalidArgument(
|
||||
"block_size argument to Cholesky must be >= 1; got %d", block_size);
|
||||
}
|
||||
|
||||
std::vector<int64> error_dims(a_shape.dimensions().begin(),
|
||||
a_shape.dimensions().end());
|
||||
error_dims.back() = error_dims.at(ndims - 2) = 1;
|
||||
std::vector<int64> error_dim_indices(ndims);
|
||||
absl::c_iota(error_dim_indices, 0);
|
||||
|
||||
// Blocked left-looking Cholesky factorization.
|
||||
// Algorithm 1 from
|
||||
// Haidar, Azzam, et al. "High-performance Cholesky factorization for
|
||||
// GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
|
||||
XlaOp l = ZerosLike(a);
|
||||
XlaOp seen_error = ConstantR0<bool>(builder, false);
|
||||
XlaOp seen_error = Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims));
|
||||
for (int64 i = 0; i < n; i += block_size) {
|
||||
int64 k = std::min(block_size, n - i);
|
||||
auto panel = SliceInMinorDims(a, {i, i}, {n, i + k});
|
||||
@ -159,7 +170,8 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
|
||||
// a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i]))
|
||||
auto lhs = SliceInMinorDims(l, {i, 0}, {n, i});
|
||||
auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i});
|
||||
auto delta = BatchDot(lhs, false, rhs, true, precision);
|
||||
auto delta =
|
||||
BatchDot(lhs, false, MaybeConjugate(rhs, true), true, precision);
|
||||
panel = panel - delta;
|
||||
}
|
||||
|
||||
@ -170,8 +182,14 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
|
||||
// other elements.
|
||||
XlaOp factorized_error;
|
||||
if (k == 1) {
|
||||
factorized = Sqrt(x);
|
||||
factorized_error = Any(IsNan(factorized));
|
||||
if (ShapeUtil::ElementIsComplex(a_shape)) {
|
||||
auto sqrt = Sqrt(Real(x));
|
||||
factorized = Complex(sqrt, ZerosLike(sqrt));
|
||||
factorized_error = IsNan(sqrt);
|
||||
} else {
|
||||
factorized = Sqrt(x);
|
||||
factorized_error = IsNan(factorized);
|
||||
}
|
||||
} else {
|
||||
TF_ASSIGN_OR_RETURN(auto tile_output, CholeskyUnblocked(x, precision));
|
||||
std::tie(factorized, factorized_error) = tile_output;
|
||||
@ -187,12 +205,13 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
|
||||
/*left_side=*/false,
|
||||
/*lower=*/true,
|
||||
/*unit_diagonal=*/false,
|
||||
/*transpose_a=*/TriangularSolveOptions::TRANSPOSE);
|
||||
/*transpose_a=*/TriangularSolveOptions::ADJOINT);
|
||||
l = UpdateSliceInMinorDims(l, update, {i + k, i});
|
||||
}
|
||||
}
|
||||
return Select(seen_error,
|
||||
FullLike(l, std::numeric_limits<float>::quiet_NaN()), l);
|
||||
return Select(
|
||||
BroadcastInDim(seen_error, a_shape.dimensions(), error_dim_indices),
|
||||
FullLike(l, std::numeric_limits<float>::quiet_NaN()), l);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -2676,6 +2676,7 @@ xla_test(
|
||||
xla_test(
|
||||
name = "cholesky_test",
|
||||
srcs = ["cholesky_test.cc"],
|
||||
real_hardware_only = True,
|
||||
tags = [
|
||||
"no_rocm",
|
||||
"optonly",
|
||||
|
@ -61,6 +61,44 @@ XLA_TEST_F(CholeskyTest, NonPSDInput) {
|
||||
ErrorSpec(1e-4, 1e-4));
|
||||
}
|
||||
|
||||
XLA_TEST_F(CholeskyTest, NonPSDBatched) {
|
||||
XlaBuilder builder(TestName());
|
||||
|
||||
Array3D<float> a_vals({
|
||||
{
|
||||
{10, 0, 0},
|
||||
{1, 20, 0},
|
||||
{1, 1, 30},
|
||||
},
|
||||
{
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
},
|
||||
});
|
||||
|
||||
XlaOp a;
|
||||
auto a_data = CreateR3Parameter<float>(a_vals, 0, "a", &builder, &a);
|
||||
Cholesky(a, /*lower=*/true);
|
||||
|
||||
float nan = std::numeric_limits<float>::quiet_NaN();
|
||||
Array3D<float> expected({
|
||||
{
|
||||
{3.16227766, 0., 0.},
|
||||
{0.31622777, 4.4609416, 0.},
|
||||
{0.31622777, 0.20175113, 5.46436606},
|
||||
},
|
||||
{
|
||||
{nan, nan, nan},
|
||||
{nan, nan, nan},
|
||||
{nan, nan, nan},
|
||||
},
|
||||
});
|
||||
|
||||
ComputeAndCompareR3<float>(&builder, expected, {a_data.get()},
|
||||
ErrorSpec(1e-4, 1e-4));
|
||||
}
|
||||
|
||||
XLA_TEST_F(CholeskyTest, Lower) {
|
||||
XlaBuilder builder(TestName());
|
||||
|
||||
@ -181,7 +219,7 @@ class RandomCholeskyTest
|
||||
: public ClientLibraryTestBase,
|
||||
public ::testing::WithParamInterface<CholeskyTestCase> {};
|
||||
|
||||
XLA_TEST_P(RandomCholeskyTest, Random) {
|
||||
XLA_TEST_P(RandomCholeskyTest, Real) {
|
||||
// Test fails with TensorFloat-32 enabled
|
||||
tensorflow::enable_tensor_float_32_execution(false);
|
||||
XlaBuilder builder(TestName());
|
||||
@ -220,14 +258,65 @@ XLA_TEST_P(RandomCholeskyTest, Random) {
|
||||
ErrorSpec(1e-4, 1e-4));
|
||||
}
|
||||
|
||||
XLA_TEST_P(RandomCholeskyTest, Complex) {
|
||||
// Test fails with TensorFloat-32 enabled
|
||||
tensorflow::enable_tensor_float_32_execution(false);
|
||||
XlaBuilder builder(TestName());
|
||||
|
||||
auto test_params = GetParam();
|
||||
std::vector<int64> dimensions = {std::get<0>(test_params),
|
||||
std::get<1>(test_params),
|
||||
std::get<1>(test_params)};
|
||||
bool lower = std::get<2>(test_params);
|
||||
Shape shape = ShapeUtil::MakeShape(F32, dimensions);
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto literal_real,
|
||||
LiteralUtil::CreateRandomLiteral<F32>(shape, 0.0, 1.0));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto literal_imag,
|
||||
LiteralUtil::CreateRandomLiteral<F32>(shape, 0.0, 1.0));
|
||||
|
||||
auto input_real = Parameter(&builder, 0, shape, "input_real");
|
||||
auto input_imag = Parameter(&builder, 1, shape, "input_imag");
|
||||
auto input = Complex(input_real, input_imag);
|
||||
// Form a random positive definite matrix.
|
||||
auto matrix = BatchDot(input, TransposeInMinorDims(Conj(input)),
|
||||
PrecisionConfig::HIGHEST);
|
||||
|
||||
auto cholesky = Triangle(Cholesky(matrix, lower), lower);
|
||||
|
||||
// Verify that ||matrix - cholesky * cholesky_t||_2 ~= 0
|
||||
XlaOp verification;
|
||||
if (lower) {
|
||||
verification = BatchDot(cholesky, TransposeInMinorDims(Conj(cholesky)),
|
||||
PrecisionConfig::HIGHEST);
|
||||
} else {
|
||||
verification = BatchDot(TransposeInMinorDims(Conj(cholesky)), cholesky,
|
||||
PrecisionConfig::HIGHEST);
|
||||
}
|
||||
auto delta = matrix - verification;
|
||||
Reduce(Abs(delta * Conj(delta)), ConstantR0<float>(&builder, 0.0),
|
||||
CreateScalarAddComputation(F32, &builder), {0, 1, 2});
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto input_data_real,
|
||||
client_->TransferToServer(literal_real));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto input_data_imag,
|
||||
client_->TransferToServer(literal_imag));
|
||||
ComputeAndCompareR0<float>(&builder, 0.0,
|
||||
{input_data_real.get(), input_data_imag.get()},
|
||||
ErrorSpec(1e-4, 1e-4));
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(RandomCholeskyTestInstance, RandomCholeskyTest,
|
||||
::testing::Values(CholeskyTestCase{1, 1, true},
|
||||
CholeskyTestCase{1, 2, true},
|
||||
CholeskyTestCase{1, 50, true},
|
||||
CholeskyTestCase{1, 50, false},
|
||||
CholeskyTestCase{1, 255, false},
|
||||
CholeskyTestCase{10, 5, true},
|
||||
CholeskyTestCase{5, 10, false},
|
||||
CholeskyTestCase{2, 20, true}));
|
||||
CholeskyTestCase{2, 20, true},
|
||||
CholeskyTestCase{2, 129, true}));
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
Loading…
Reference in New Issue
Block a user