From 85404d1a8555948a49432c5a40be39c293026e56 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Fri, 18 Sep 2020 07:01:55 -0700 Subject: [PATCH] [XLA] Add complex number support to HLO cholesky decomposition implementation. Fix bug where errors in one batch element would cause other batch elements to fail. PiperOrigin-RevId: 332443397 Change-Id: I868accebbad9df2fa759525f6f0b0b3df6a481c1 --- .../compiler/xla/service/cholesky_expander.cc | 69 +++++++++----- tensorflow/compiler/xla/tests/BUILD | 1 + .../compiler/xla/tests/cholesky_test.cc | 93 ++++++++++++++++++- 3 files changed, 136 insertions(+), 27 deletions(-) diff --git a/tensorflow/compiler/xla/service/cholesky_expander.cc b/tensorflow/compiler/xla/service/cholesky_expander.cc index ffb0fb4e6ef..4abfe1b018e 100644 --- a/tensorflow/compiler/xla/service/cholesky_expander.cc +++ b/tensorflow/compiler/xla/service/cholesky_expander.cc @@ -56,17 +56,21 @@ StatusOr> CholeskyExpander::CholeskyUnblocked( XlaOp a, PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); - const int n_dims = a_shape.rank(); + const int ndims = a_shape.rank(); const int64 n = ShapeUtil::GetDimension(a_shape, -1); + std::vector error_dims(a_shape.dimensions().begin(), + a_shape.dimensions().end()); + error_dims.back() = error_dims.at(ndims - 2) = 1; + auto major_dims = AsInt64Slice(a_shape.dimensions()) .subspan( /*pos=*/0, - /*len=*/n_dims - 2); + /*len=*/ndims - 2); auto matrix_dims = AsInt64Slice(a_shape.dimensions()) .subspan( /*pos=*/0, - /*len=*/n_dims); + /*len=*/ndims); XlaOp l = ZerosLike(a); @@ -79,9 +83,9 @@ StatusOr> CholeskyExpander::CholeskyUnblocked( auto body_l = loop_vars[1]; auto seen_error = loop_vars[2]; auto iota_row = - Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), n_dims - 1); + Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 1); auto iota_col = - Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), n_dims - 2); + Iota(body_builder, ShapeUtil::MakeShape(S32, matrix_dims), ndims - 2); auto mask_pred = Ge(iota_col, iota_row); mask_pred = And(mask_pred, Eq(iota_row, i)); @@ -91,25 +95,32 @@ StatusOr> CholeskyExpander::CholeskyUnblocked( // L * L.T, This matrix has of a lot of multiplying with zero // (namely, L[:, j:] = 0) and redundant computation, but it is faster // than slice. - auto l_square = BatchDot(body_l, false, body_l, true, precision); + auto l_square = + BatchDot(body_l, false, MaybeConjugate(body_l, true), true, precision); // A - L*L.T l_square = body_a - l_square; auto l_ii = DynamicSliceInMinorDims(l_square, {i, i}, {1, 1}); - l_ii = Sqrt(l_ii); + if (ShapeUtil::ElementIsComplex(a_shape)) { + auto sqrt = Sqrt(Real(l_ii)); + l_ii = Complex(sqrt, ZerosLike(sqrt)); + seen_error = Or(seen_error, IsNan(sqrt)); + } else { + l_ii = Sqrt(l_ii); + seen_error = Or(seen_error, IsNan(l_ii)); + } // L = (A - L*L.T) / l_ii * mask + L body_l = Select(mask_pred, l_square / l_ii, mask_zeros) + body_l; - seen_error = - Or(seen_error, Any(Or(Le(l_ii, ZerosLike(l_ii)), IsNan(l_ii)))); - return std::vector{body_a, body_l, seen_error}; }; TF_ASSIGN_OR_RETURN( auto cholesky_while, - ForEachIndex(n, S32, body_fn, {a, l, ConstantR0(builder, false)}, - "unblocked", builder)); + ForEachIndex( + n, S32, body_fn, + {a, l, Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims))}, + "unblocked", builder)); return std::make_pair(cholesky_while[1], cholesky_while[2]); } @@ -133,23 +144,23 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size, ShapeUtil::HumanString(a_shape)); } - if (primitive_util::IsComplexType(a_shape.element_type())) { - return Unimplemented( - "Complex types are not implemented in Cholesky; got shape %s", - ShapeUtil::HumanString(a_shape)); - } - if (block_size < 1) { return InvalidArgument( "block_size argument to Cholesky must be >= 1; got %d", block_size); } + std::vector error_dims(a_shape.dimensions().begin(), + a_shape.dimensions().end()); + error_dims.back() = error_dims.at(ndims - 2) = 1; + std::vector error_dim_indices(ndims); + absl::c_iota(error_dim_indices, 0); + // Blocked left-looking Cholesky factorization. // Algorithm 1 from // Haidar, Azzam, et al. "High-performance Cholesky factorization for // GPU-only execution." Proceedings of General Purpose GPUs. ACM, 2017. XlaOp l = ZerosLike(a); - XlaOp seen_error = ConstantR0(builder, false); + XlaOp seen_error = Zeros(builder, ShapeUtil::MakeShape(PRED, error_dims)); for (int64 i = 0; i < n; i += block_size) { int64 k = std::min(block_size, n - i); auto panel = SliceInMinorDims(a, {i, i}, {n, i + k}); @@ -159,7 +170,8 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size, // a[i:, i:i+k] -= np.dot(l[i:, :i], np.transpose(l[i:i+k, :i])) auto lhs = SliceInMinorDims(l, {i, 0}, {n, i}); auto rhs = SliceInMinorDims(l, {i, 0}, {i + k, i}); - auto delta = BatchDot(lhs, false, rhs, true, precision); + auto delta = + BatchDot(lhs, false, MaybeConjugate(rhs, true), true, precision); panel = panel - delta; } @@ -170,8 +182,14 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size, // other elements. XlaOp factorized_error; if (k == 1) { - factorized = Sqrt(x); - factorized_error = Any(IsNan(factorized)); + if (ShapeUtil::ElementIsComplex(a_shape)) { + auto sqrt = Sqrt(Real(x)); + factorized = Complex(sqrt, ZerosLike(sqrt)); + factorized_error = IsNan(sqrt); + } else { + factorized = Sqrt(x); + factorized_error = IsNan(factorized); + } } else { TF_ASSIGN_OR_RETURN(auto tile_output, CholeskyUnblocked(x, precision)); std::tie(factorized, factorized_error) = tile_output; @@ -187,12 +205,13 @@ XlaOp CholeskyExpander::BuildCholesky(XlaOp a, int64 block_size, /*left_side=*/false, /*lower=*/true, /*unit_diagonal=*/false, - /*transpose_a=*/TriangularSolveOptions::TRANSPOSE); + /*transpose_a=*/TriangularSolveOptions::ADJOINT); l = UpdateSliceInMinorDims(l, update, {i + k, i}); } } - return Select(seen_error, - FullLike(l, std::numeric_limits::quiet_NaN()), l); + return Select( + BroadcastInDim(seen_error, a_shape.dimensions(), error_dim_indices), + FullLike(l, std::numeric_limits::quiet_NaN()), l); }); } diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index d9110ed1f35..361d2065a00 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -2676,6 +2676,7 @@ xla_test( xla_test( name = "cholesky_test", srcs = ["cholesky_test.cc"], + real_hardware_only = True, tags = [ "no_rocm", "optonly", diff --git a/tensorflow/compiler/xla/tests/cholesky_test.cc b/tensorflow/compiler/xla/tests/cholesky_test.cc index 616b404b425..4fa28736d4d 100644 --- a/tensorflow/compiler/xla/tests/cholesky_test.cc +++ b/tensorflow/compiler/xla/tests/cholesky_test.cc @@ -61,6 +61,44 @@ XLA_TEST_F(CholeskyTest, NonPSDInput) { ErrorSpec(1e-4, 1e-4)); } +XLA_TEST_F(CholeskyTest, NonPSDBatched) { + XlaBuilder builder(TestName()); + + Array3D a_vals({ + { + {10, 0, 0}, + {1, 20, 0}, + {1, 1, 30}, + }, + { + {1, 1, 1}, + {1, 1, 1}, + {1, 1, 1}, + }, + }); + + XlaOp a; + auto a_data = CreateR3Parameter(a_vals, 0, "a", &builder, &a); + Cholesky(a, /*lower=*/true); + + float nan = std::numeric_limits::quiet_NaN(); + Array3D expected({ + { + {3.16227766, 0., 0.}, + {0.31622777, 4.4609416, 0.}, + {0.31622777, 0.20175113, 5.46436606}, + }, + { + {nan, nan, nan}, + {nan, nan, nan}, + {nan, nan, nan}, + }, + }); + + ComputeAndCompareR3(&builder, expected, {a_data.get()}, + ErrorSpec(1e-4, 1e-4)); +} + XLA_TEST_F(CholeskyTest, Lower) { XlaBuilder builder(TestName()); @@ -181,7 +219,7 @@ class RandomCholeskyTest : public ClientLibraryTestBase, public ::testing::WithParamInterface {}; -XLA_TEST_P(RandomCholeskyTest, Random) { +XLA_TEST_P(RandomCholeskyTest, Real) { // Test fails with TensorFloat-32 enabled tensorflow::enable_tensor_float_32_execution(false); XlaBuilder builder(TestName()); @@ -220,14 +258,65 @@ XLA_TEST_P(RandomCholeskyTest, Random) { ErrorSpec(1e-4, 1e-4)); } +XLA_TEST_P(RandomCholeskyTest, Complex) { + // Test fails with TensorFloat-32 enabled + tensorflow::enable_tensor_float_32_execution(false); + XlaBuilder builder(TestName()); + + auto test_params = GetParam(); + std::vector dimensions = {std::get<0>(test_params), + std::get<1>(test_params), + std::get<1>(test_params)}; + bool lower = std::get<2>(test_params); + Shape shape = ShapeUtil::MakeShape(F32, dimensions); + TF_ASSERT_OK_AND_ASSIGN( + auto literal_real, + LiteralUtil::CreateRandomLiteral(shape, 0.0, 1.0)); + TF_ASSERT_OK_AND_ASSIGN( + auto literal_imag, + LiteralUtil::CreateRandomLiteral(shape, 0.0, 1.0)); + + auto input_real = Parameter(&builder, 0, shape, "input_real"); + auto input_imag = Parameter(&builder, 1, shape, "input_imag"); + auto input = Complex(input_real, input_imag); + // Form a random positive definite matrix. + auto matrix = BatchDot(input, TransposeInMinorDims(Conj(input)), + PrecisionConfig::HIGHEST); + + auto cholesky = Triangle(Cholesky(matrix, lower), lower); + + // Verify that ||matrix - cholesky * cholesky_t||_2 ~= 0 + XlaOp verification; + if (lower) { + verification = BatchDot(cholesky, TransposeInMinorDims(Conj(cholesky)), + PrecisionConfig::HIGHEST); + } else { + verification = BatchDot(TransposeInMinorDims(Conj(cholesky)), cholesky, + PrecisionConfig::HIGHEST); + } + auto delta = matrix - verification; + Reduce(Abs(delta * Conj(delta)), ConstantR0(&builder, 0.0), + CreateScalarAddComputation(F32, &builder), {0, 1, 2}); + + TF_ASSERT_OK_AND_ASSIGN(auto input_data_real, + client_->TransferToServer(literal_real)); + TF_ASSERT_OK_AND_ASSIGN(auto input_data_imag, + client_->TransferToServer(literal_imag)); + ComputeAndCompareR0(&builder, 0.0, + {input_data_real.get(), input_data_imag.get()}, + ErrorSpec(1e-4, 1e-4)); +} + INSTANTIATE_TEST_SUITE_P(RandomCholeskyTestInstance, RandomCholeskyTest, ::testing::Values(CholeskyTestCase{1, 1, true}, CholeskyTestCase{1, 2, true}, CholeskyTestCase{1, 50, true}, CholeskyTestCase{1, 50, false}, + CholeskyTestCase{1, 255, false}, CholeskyTestCase{10, 5, true}, CholeskyTestCase{5, 10, false}, - CholeskyTestCase{2, 20, true})); + CholeskyTestCase{2, 20, true}, + CholeskyTestCase{2, 129, true})); } // namespace } // namespace xla