[XLA] Add complex number support to HLO cholesky decomposition implementation.
Fix bug where errors in one batch element would cause other batch elements to fail.
PiperOrigin-RevId: 332443397
Change-Id: I868accebbad9df2fa759525f6f0b0b3df6a481c1
This commit is contained in:
parent
ef1567488a
commit
85404d1a85
@ -56,17 +56,21 @@ StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
|
|||||||
XlaOp a, PrecisionConfig::Precision precision) {
|
XlaOp a, PrecisionConfig::Precision precision) {
|
||||||
XlaBuilder* builder = a.builder();
|
XlaBuilder* builder = a.builder();
|
||||||
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
|
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
|
||||||
const int n_dims = a_shape.rank();
|
const int ndims = a_shape.rank();
|
||||||
const int64 n = ShapeUtil::GetDimension(a_shape, -1);
|
const int64 n = ShapeUtil::GetDimension(a_shape, -1);
|
||||||
|
std::vector<int64> error_dims(a_shape.dimensions().begin(),
|
||||||
|
a_shape.dimensions().end());
|
||||||
|
error_dims.back() = error_dims.at(ndims - 2) = 1;
|
||||||
|
|
||||||
auto major_dims = AsInt64Slice(a_shape.dimensions())
|
auto major_dims = AsInt64Slice(a_shape.dimensions())
|
||||||
.subspan(
|
.subspan(
|
||||||
/*pos=*/0,
|
/*pos=*/0,
|
||||||
/*len=*/n_dims - 2);
|
/*len=*/ndims - 2);
|
||||||
|
|
||||||
auto matrix_dims = AsInt64Slice(a_shape.dimensions())
|
auto matrix_dims = AsInt64Slice(a_shape.dimensions())
|
||||||
.subspan(
|
.subspan(
|
||||||
/*pos=*/0,
|
/*pos=*/0,
|
||||||
/*len=*/n_dims);
|
/*len=*/ndims);
|
||||||
|
|
||||||
XlaOp l = ZerosLike(a);
|
XlaOp l = ZerosLike(a);
|
||||||
|
|
||||||
@ -79,9 +83,9 @@ StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
|
|||||||
auto body_l = loop_vars[1];
|
auto body_l = loop_vars[1];
|
||||||
auto seen_error = loop_vars[2];
|
auto seen_error = loop_vars[2];
|
||||||
auto iota_row =
|
auto iota_row =
|
||||||
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), n_dims - 1);
|
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 1);
|
||||||
auto iota_col =
|
auto iota_col =
|
||||||
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), n_dims - 2);
|
Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 2);
|
||||||
|
|
||||||
auto mask_pred = Ge(iota_col, iota_row);
|
auto mask_pred = Ge(iota_col, iota_row);
|
||||||
mask_pred = And(mask_pred, Eq(iota_row, i));
|
mask_pred = And(mask_pred, Eq(iota_row, i));
|
||||||
@ -91,25 +95,32 @@ StatusOr<std::pair<XlaOp, XlaOp>> CholeskyExpander::CholeskyUnblocked(
|
|||||||
// L * L.T. This matrix has a lot of multiplying with zero
|
// L * L.T. This matrix has a lot of multiplying with zero
|
||||||
// (namely, L[:, j:] = 0) and redundant computation, but it is faster
|
// (namely, L[:, j:] = 0) and redundant computation, but it is faster
|
||||||
// than slice.
|
// than slice.
|
||||||
auto l_square = BatchDot(body_l, false, body_l, true, precision);
|
auto l_square =
|
||||||
|
BatchDot(body_l, false, MaybeConjugate(body_l, true), true, precision);
|
||||||
|
|
||||||
// A - L*L.T
|
// A - L*L.T
|
||||||
l_square = body_a - l_square;
|
l_square = body_a - l_square;
|
||||||
auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1});
|
auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1});
|
||||||
l_ii = Sqrt(l_ii);
|
if (ShapeUtil::ElementIsComplex(a_shape)) {
|
||||||
|
auto sqrt = Sqrt(Real(l_ii));
|
||||||
|
l_ii = Complex(sqrt, ZerosLike(sqrt));
|
||||||
|
seen_error = Or(seen_error, IsNan(sqrt));
|
||||||
|
} else {
|
||||||
|
l_ii = Sqrt(l_ii);
|
||||||
|
seen_error = Or(seen_error, IsNan(l_ii));
|
||||||
|
}
|
||||||
// L = (A - L*L.T) / l_ii * mask + L
|
// L = (A - L*L.T) / l_ii * mask + L
|
||||||
body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l;
|
body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l;
|
||||||
|
|
||||||
seen_error =
|
|
||||||
Or(seen_error, Any(Or(Le(l_ii, ZerosLike(l_ii)), IsNan(l_ii))));
|
|
||||||
|
|
||||||
return std::vector<XlaOp>{body_a, body_l, seen_error};
|
return std::vector<XlaOp>{body_a, body_l, seen_error};
|
||||||
};
|
};
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
auto cholesky_while,
|
auto cholesky_while,
|
||||||
ForEachIndex(n, S32, body_fn, {a, l, ConstantR0<bool>(builder, false)},
|
ForEachIndex(
|
||||||
"unblocked", builder));
|
n, S32, body_fn,
|
||||||
|
{a, l, Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims))},
|
||||||
|
"unblocked", builder));
|
||||||
|
|
||||||
return std::make_pair(cholesky_while[1], cholesky_while[2]);
|
return std::make_pair(cholesky_while[1], cholesky_while[2]);
|
||||||
}
|
}
|
||||||
@ -133,23 +144,23 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
|
|||||||
ShapeUtil::HumanString(a_shape));
|
ShapeUtil::HumanString(a_shape));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (primitive_util::IsComplexType(a_shape.element_type())) {
|
|
||||||
return Unimplemented(
|
|
||||||
"Complex types are not implemented in Cholesky; got shape %s",
|
|
||||||
ShapeUtil::HumanString(a_shape));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (block_size < 1) {
|
if (block_size < 1) {
|
||||||
return InvalidArgument(
|
return InvalidArgument(
|
||||||
"block_size argument to Cholesky must be >= 1; got %d", block_size);
|
"block_size argument to Cholesky must be >= 1; got %d", block_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<int64> error_dims(a_shape.dimensions().begin(),
|
||||||
|
a_shape.dimensions().end());
|
||||||
|
error_dims.back() = error_dims.at(ndims - 2) = 1;
|
||||||
|
std::vector<int64> error_dim_indices(ndims);
|
||||||
|
absl::c_iota(error_dim_indices, 0);
|
||||||
|
|
||||||
// Blocked left-looking Cholesky factorization.
|
// Blocked left-looking Cholesky factorization.
|
||||||
// Algorithm 1 from
|
// Algorithm 1 from
|
||||||
// Haidar, Azzam, et al. "High-performance Cholesky factorization for
|
// Haidar, Azzam, et al. "High-performance Cholesky factorization for
|
||||||
// GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
|
// GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017.
|
||||||
XlaOp l = ZerosLike(a);
|
XlaOp l = ZerosLike(a);
|
||||||
XlaOp seen_error = ConstantR0<bool>(builder, false);
|
XlaOp seen_error = Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims));
|
||||||
for (int64 i = 0; i < n; i += block_size) {
|
for (int64 i = 0; i < n; i += block_size) {
|
||||||
int64 k = std::min(block_size, n - i);
|
int64 k = std::min(block_size, n - i);
|
||||||
auto panel = SliceInMinorDims(a, {i, i}, {n, i + k});
|
auto panel = SliceInMinorDims(a, {i, i}, {n, i + k});
|
||||||
@ -159,7 +170,8 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
|
|||||||
// a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i]))
|
// a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i]))
|
||||||
auto lhs = SliceInMinorDims(l, {i, 0}, {n, i});
|
auto lhs = SliceInMinorDims(l, {i, 0}, {n, i});
|
||||||
auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i});
|
auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i});
|
||||||
auto delta = BatchDot(lhs, false, rhs, true, precision);
|
auto delta =
|
||||||
|
BatchDot(lhs, false, MaybeConjugate(rhs, true), true, precision);
|
||||||
panel = panel - delta;
|
panel = panel - delta;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -170,8 +182,14 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
|
|||||||
// other elements.
|
// other elements.
|
||||||
XlaOp factorized_error;
|
XlaOp factorized_error;
|
||||||
if (k == 1) {
|
if (k == 1) {
|
||||||
factorized = Sqrt(x);
|
if (ShapeUtil::ElementIsComplex(a_shape)) {
|
||||||
factorized_error = Any(IsNan(factorized));
|
auto sqrt = Sqrt(Real(x));
|
||||||
|
factorized = Complex(sqrt, ZerosLike(sqrt));
|
||||||
|
factorized_error = IsNan(sqrt);
|
||||||
|
} else {
|
||||||
|
factorized = Sqrt(x);
|
||||||
|
factorized_error = IsNan(factorized);
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
TF_ASSIGN_OR_RETURN(auto tile_output, CholeskyUnblocked(x, precision));
|
TF_ASSIGN_OR_RETURN(auto tile_output, CholeskyUnblocked(x, precision));
|
||||||
std::tie(factorized, factorized_error) = tile_output;
|
std::tie(factorized, factorized_error) = tile_output;
|
||||||
@ -187,12 +205,13 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size,
|
|||||||
/*left_side=*/false,
|
/*left_side=*/false,
|
||||||
/*lower=*/true,
|
/*lower=*/true,
|
||||||
/*unit_diagonal=*/false,
|
/*unit_diagonal=*/false,
|
||||||
/*transpose_a=*/TriangularSolveOptions::TRANSPOSE);
|
/*transpose_a=*/TriangularSolveOptions::ADJOINT);
|
||||||
l = UpdateSliceInMinorDims(l, update, {i + k, i});
|
l = UpdateSliceInMinorDims(l, update, {i + k, i});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Select(seen_error,
|
return Select(
|
||||||
FullLike(l, std::numeric_limits<float>::quiet_NaN()), l);
|
BroadcastInDim(seen_error, a_shape.dimensions(), error_dim_indices),
|
||||||
|
FullLike(l, std::numeric_limits<float>::quiet_NaN()), l);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2676,6 +2676,7 @@ xla_test(
|
|||||||
xla_test(
|
xla_test(
|
||||||
name = "cholesky_test",
|
name = "cholesky_test",
|
||||||
srcs = ["cholesky_test.cc"],
|
srcs = ["cholesky_test.cc"],
|
||||||
|
real_hardware_only = True,
|
||||||
tags = [
|
tags = [
|
||||||
"no_rocm",
|
"no_rocm",
|
||||||
"optonly",
|
"optonly",
|
||||||
|
@ -61,6 +61,44 @@ XLA_TEST_F(CholeskyTest, NonPSDInput) {
|
|||||||
ErrorSpec(1e-4, 1e-4));
|
ErrorSpec(1e-4, 1e-4));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XLA_TEST_F(CholeskyTest, NonPSDBatched) {
|
||||||
|
XlaBuilder builder(TestName());
|
||||||
|
|
||||||
|
Array3D<float> a_vals({
|
||||||
|
{
|
||||||
|
{10, 0, 0},
|
||||||
|
{1, 20, 0},
|
||||||
|
{1, 1, 30},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{1, 1, 1},
|
||||||
|
{1, 1, 1},
|
||||||
|
{1, 1, 1},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
XlaOp a;
|
||||||
|
auto a_data = CreateR3Parameter<float>(a_vals, 0, "a", &builder, &a);
|
||||||
|
Cholesky(a, /*lower=*/true);
|
||||||
|
|
||||||
|
float nan = std::numeric_limits<float>::quiet_NaN();
|
||||||
|
Array3D<float> expected({
|
||||||
|
{
|
||||||
|
{3.16227766, 0., 0.},
|
||||||
|
{0.31622777, 4.4609416, 0.},
|
||||||
|
{0.31622777, 0.20175113, 5.46436606},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
{nan, nan, nan},
|
||||||
|
{nan, nan, nan},
|
||||||
|
{nan, nan, nan},
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
ComputeAndCompareR3<float>(&builder, expected, {a_data.get()},
|
||||||
|
ErrorSpec(1e-4, 1e-4));
|
||||||
|
}
|
||||||
|
|
||||||
XLA_TEST_F(CholeskyTest, Lower) {
|
XLA_TEST_F(CholeskyTest, Lower) {
|
||||||
XlaBuilder builder(TestName());
|
XlaBuilder builder(TestName());
|
||||||
|
|
||||||
@ -181,7 +219,7 @@ class RandomCholeskyTest
|
|||||||
: public ClientLibraryTestBase,
|
: public ClientLibraryTestBase,
|
||||||
public ::testing::WithParamInterface<CholeskyTestCase> {};
|
public ::testing::WithParamInterface<CholeskyTestCase> {};
|
||||||
|
|
||||||
XLA_TEST_P(RandomCholeskyTest, Random) {
|
XLA_TEST_P(RandomCholeskyTest, Real) {
|
||||||
// Test fails with TensorFloat-32 enabled
|
// Test fails with TensorFloat-32 enabled
|
||||||
tensorflow::enable_tensor_float_32_execution(false);
|
tensorflow::enable_tensor_float_32_execution(false);
|
||||||
XlaBuilder builder(TestName());
|
XlaBuilder builder(TestName());
|
||||||
@ -220,14 +258,65 @@ XLA_TEST_P(RandomCholeskyTest, Random) {
|
|||||||
ErrorSpec(1e-4, 1e-4));
|
ErrorSpec(1e-4, 1e-4));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
XLA_TEST_P(RandomCholeskyTest, Complex) {
|
||||||
|
// Test fails with TensorFloat-32 enabled
|
||||||
|
tensorflow::enable_tensor_float_32_execution(false);
|
||||||
|
XlaBuilder builder(TestName());
|
||||||
|
|
||||||
|
auto test_params = GetParam();
|
||||||
|
std::vector<int64> dimensions = {std::get<0>(test_params),
|
||||||
|
std::get<1>(test_params),
|
||||||
|
std::get<1>(test_params)};
|
||||||
|
bool lower = std::get<2>(test_params);
|
||||||
|
Shape shape = ShapeUtil::MakeShape(F32, dimensions);
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto literal_real,
|
||||||
|
LiteralUtil::CreateRandomLiteral<F32>(shape, 0.0, 1.0));
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(
|
||||||
|
auto literal_imag,
|
||||||
|
LiteralUtil::CreateRandomLiteral<F32>(shape, 0.0, 1.0));
|
||||||
|
|
||||||
|
auto input_real = Parameter(&builder, 0, shape, "input_real");
|
||||||
|
auto input_imag = Parameter(&builder, 1, shape, "input_imag");
|
||||||
|
auto input = Complex(input_real, input_imag);
|
||||||
|
// Form a random positive definite matrix.
|
||||||
|
auto matrix = BatchDot(input, TransposeInMinorDims(Conj(input)),
|
||||||
|
PrecisionConfig::HIGHEST);
|
||||||
|
|
||||||
|
auto cholesky = Triangle(Cholesky(matrix, lower), lower);
|
||||||
|
|
||||||
|
// Verify that ||matrix - cholesky * cholesky_t||_2 ~= 0
|
||||||
|
XlaOp verification;
|
||||||
|
if (lower) {
|
||||||
|
verification = BatchDot(cholesky, TransposeInMinorDims(Conj(cholesky)),
|
||||||
|
PrecisionConfig::HIGHEST);
|
||||||
|
} else {
|
||||||
|
verification = BatchDot(TransposeInMinorDims(Conj(cholesky)), cholesky,
|
||||||
|
PrecisionConfig::HIGHEST);
|
||||||
|
}
|
||||||
|
auto delta = matrix - verification;
|
||||||
|
Reduce(Abs(delta * Conj(delta)), ConstantR0<float>(&builder, 0.0),
|
||||||
|
CreateScalarAddComputation(F32, &builder), {0, 1, 2});
|
||||||
|
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto input_data_real,
|
||||||
|
client_->TransferToServer(literal_real));
|
||||||
|
TF_ASSERT_OK_AND_ASSIGN(auto input_data_imag,
|
||||||
|
client_->TransferToServer(literal_imag));
|
||||||
|
ComputeAndCompareR0<float>(&builder, 0.0,
|
||||||
|
{input_data_real.get(), input_data_imag.get()},
|
||||||
|
ErrorSpec(1e-4, 1e-4));
|
||||||
|
}
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(RandomCholeskyTestInstance, RandomCholeskyTest,
|
INSTANTIATE_TEST_SUITE_P(RandomCholeskyTestInstance, RandomCholeskyTest,
|
||||||
::testing::Values(CholeskyTestCase{1, 1, true},
|
::testing::Values(CholeskyTestCase{1, 1, true},
|
||||||
CholeskyTestCase{1, 2, true},
|
CholeskyTestCase{1, 2, true},
|
||||||
CholeskyTestCase{1, 50, true},
|
CholeskyTestCase{1, 50, true},
|
||||||
CholeskyTestCase{1, 50, false},
|
CholeskyTestCase{1, 50, false},
|
||||||
|
CholeskyTestCase{1, 255, false},
|
||||||
CholeskyTestCase{10, 5, true},
|
CholeskyTestCase{10, 5, true},
|
||||||
CholeskyTestCase{5, 10, false},
|
CholeskyTestCase{5, 10, false},
|
||||||
CholeskyTestCase{2, 20, true}));
|
CholeskyTestCase{2, 20, true},
|
||||||
|
CholeskyTestCase{2, 129, true}));
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
Loading…
x
Reference in New Issue
Block a user