From cf0f741491fd4cb3da3d71aea82619e26ff8f20b Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Tue, 26 Feb 2019 16:15:49 -0800 Subject: [PATCH] [XLA] Make Cholesky into a first-class HLO operator. Currently it is expanded into an HLO implementation on all backends. PiperOrigin-RevId: 235814360 --- tensorflow/compiler/tf2xla/kernels/BUILD | 1 - .../compiler/tf2xla/kernels/cholesky_op.cc | 7 +- tensorflow/compiler/xla/client/lib/BUILD | 41 ------ tensorflow/compiler/xla/client/lib/cholesky.h | 39 ----- tensorflow/compiler/xla/client/xla_builder.cc | 15 ++ tensorflow/compiler/xla/client/xla_builder.h | 18 ++- .../compiler/xla/g3doc/operation_semantics.md | 71 +++++++++ tensorflow/compiler/xla/python/BUILD | 1 - .../xla/python/local_computation_builder.cc | 5 +- .../xla/python/local_computation_builder.h | 2 +- tensorflow/compiler/xla/python/xla_client.py | 4 +- tensorflow/compiler/xla/service/BUILD | 23 +++ .../cholesky_expander.cc} | 60 +++++++- .../compiler/xla/service/cholesky_expander.h | 41 ++++++ tensorflow/compiler/xla/service/cpu/BUILD | 1 + .../compiler/xla/service/cpu/cpu_compiler.cc | 2 + .../compiler/xla/service/dfs_hlo_visitor.h | 1 + .../service/dfs_hlo_visitor_with_default.h | 3 + tensorflow/compiler/xla/service/gpu/BUILD | 1 + .../xla/service/gpu/nvptx_compiler.cc | 3 + tensorflow/compiler/xla/service/hlo.proto | 6 +- .../compiler/xla/service/hlo_cost_analysis.cc | 13 +- .../compiler/xla/service/hlo_cost_analysis.h | 1 + .../compiler/xla/service/hlo_graph_dumper.cc | 1 + .../compiler/xla/service/hlo_instruction.cc | 19 +++ .../compiler/xla/service/hlo_instruction.h | 6 + .../compiler/xla/service/hlo_instructions.cc | 68 +++++++-- .../compiler/xla/service/hlo_instructions.h | 25 ++++ tensorflow/compiler/xla/service/hlo_opcode.h | 1 + tensorflow/compiler/xla/service/hlo_parser.cc | 11 ++ .../compiler/xla/service/hlo_verifier.cc | 7 + .../compiler/xla/service/hlo_verifier.h | 1 + .../xla/service/instruction_fusion.cc | 1 + .../compiler/xla/service/interpreter/BUILD | 1 + .../xla/service/interpreter/compiler.cc | 2 + .../compiler/xla/service/layout_assignment.cc | 1 + .../compiler/xla/service/shape_inference.cc | 29 ++++ .../compiler/xla/service/shape_inference.h | 3 + tensorflow/compiler/xla/tests/BUILD | 20 +++ .../{client/lib => tests}/cholesky_test.cc | 136 +++++++++++------- tensorflow/compiler/xla/xla_data.proto | 6 + 41 files changed, 535 insertions(+), 162 deletions(-) delete mode 100644 tensorflow/compiler/xla/client/lib/cholesky.h rename tensorflow/compiler/xla/{client/lib/cholesky.cc => service/cholesky_expander.cc} (76%) create mode 100644 tensorflow/compiler/xla/service/cholesky_expander.h rename tensorflow/compiler/xla/{client/lib => tests}/cholesky_test.cc (54%) diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index ffdc5a2b40b..eb0aeb8dcc0 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -136,7 +136,6 @@ tf_kernel_library( "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/compiler/xla/client/lib:cholesky", "//tensorflow/compiler/xla/client/lib:constants", "//tensorflow/compiler/xla/client/lib:loops", "//tensorflow/compiler/xla/client/lib:math", diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index 0ed3044efa5..e6b30a38e03 100644 --- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -15,7 +15,8 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/lib/cholesky.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" namespace tensorflow { namespace { @@ -24,7 +25,9 @@ class CholeskyOp : public XlaOpKernel { public: explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - ctx->SetOutput(0, xla::Cholesky(ctx->Input(0))); + ctx->SetOutput(0, + xla::Triangle(xla::Cholesky(ctx->Input(0), /*lower=*/true), + /*lower=*/true)); } }; diff --git a/tensorflow/compiler/xla/client/lib/BUILD b/tensorflow/compiler/xla/client/lib/BUILD index f264ec50a26..d24a0d652a4 100644 --- a/tensorflow/compiler/xla/client/lib/BUILD +++ b/tensorflow/compiler/xla/client/lib/BUILD @@ -49,47 +49,6 @@ xla_test( ], ) -cc_library( - name = "cholesky", - srcs = ["cholesky.cc"], - hdrs = ["cholesky.h"], - deps = [ - ":math", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/client/lib:constants", - "//tensorflow/compiler/xla/client/lib:loops", - "//tensorflow/compiler/xla/client/lib:matrix", - "//tensorflow/compiler/xla/client/lib:slicing", - "//tensorflow/core:lib", - ], -) - -xla_test( - name = "cholesky_test", - srcs = ["cholesky_test.cc"], - tags = ["optonly"], - deps = [ - ":arithmetic", - ":cholesky", - ":matrix", - "//tensorflow/compiler/xla:array2d", - "//tensorflow/compiler/xla:literal", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:test", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla/client:xla_builder", - "//tensorflow/compiler/xla/tests:client_library_test_base", - "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", - "//tensorflow/core:test", - ], -) - cc_library( name = "comparators", srcs = ["comparators.cc"], diff --git a/tensorflow/compiler/xla/client/lib/cholesky.h b/tensorflow/compiler/xla/client/lib/cholesky.h deleted file mode 100644 index 0bae26837c0..00000000000 --- a/tensorflow/compiler/xla/client/lib/cholesky.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ -#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ - -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { - -// Computes the Cholesky decompositions of a batch of symmetric positive -// definite matrices. -// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the -// two minor dimensions equal. -// The algorithm implements a blocked Cholesky decomposition; `block_size` is -// the block size to use. -// TODO(phawkins): check for negative values on the diagonal and return an -// error, instead of silently yielding NaNs. -// TODO(znado): handle the complex Hermitian case -xla::XlaOp Cholesky( - xla::XlaOp a, int64 block_size = 256, - xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); - -} // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_ diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 16381155c3f..7d42f910b82 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -3022,6 +3022,21 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, }); } +XlaOp Cholesky(XlaOp a, bool lower) { + XlaBuilder* builder = a.builder(); + return builder->ReportErrorOrReturn([&]() -> StatusOr { + HloInstructionProto instr; + TF_ASSIGN_OR_RETURN(const Shape& a_shape, builder->GetShape(a)); + xla::CholeskyOptions& options = *instr.mutable_cholesky_options(); + options.set_lower(lower); + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferCholeskyShape(a_shape)); + *instr.mutable_shape() = shape.ToProto(); + + return builder->AddInstruction(std::move(instr), HloOpcode::kCholesky, {a}); + }); +} + XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config) { return builder->Infeed(shape, config); } diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index 129e5167429..9f0a9adf81f 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -801,6 +801,7 @@ class XlaBuilder { friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, bool unit_diagonal, TriangularSolveOptions::Transpose transpose_a); + friend XlaOp Cholesky(XlaOp a, bool lower); friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config); friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout, @@ -1342,8 +1343,7 @@ XlaOp Fft(const XlaOp& operand, FftType fft_type, // * `left_side` is a boolean, indicating whether to solve a system of the form // op(a) * x = b (true) or x * op(a) = b (false). // * `lower` is a boolean, indicating whether the argument `a` is -// lower-triangular -// (true) or upper-triangular (false). +// lower-triangular (true) or upper-triangular (false). // * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be // 1 and not accessed. // * `transpose_a` indicates which function `op` we use to transform the tensor @@ -1352,6 +1352,20 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower, bool unit_diagonal, TriangularSolveOptions::Transpose transpose_a); +// Computes the Cholesky decompositions of a batch of symmetric (Hermitian) +// positive definite matrices. +// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the +// two minor dimensions equal. +// If `lower` is true, the data from the lower triangle is used; if false, the +// upper triangle is used. The input data in the other triangle of the input +// does not affect the output. Returns the output in the same lower/uppper +// triangle. The data returned in the other output triangle is arbitrary and +// implementation-defined. +// +// The value returned if `a` is not Hermitian positive definite is +// implementation-defined. +XlaOp Cholesky(XlaOp a, bool lower); + // Enqueues an infeed instruction onto the computation, which writes data of // the given shape to the infeed buffer of the device. XlaOp Infeed(XlaBuilder* builder, const Shape& shape, diff --git a/tensorflow/compiler/xla/g3doc/operation_semantics.md b/tensorflow/compiler/xla/g3doc/operation_semantics.md index db90d184b52..5c384580d58 100644 --- a/tensorflow/compiler/xla/g3doc/operation_semantics.md +++ b/tensorflow/compiler/xla/g3doc/operation_semantics.md @@ -322,6 +322,37 @@ Invokes a computation with the given arguments. The arity and types of the `args` must match the parameters of the `computation`. It is allowed to have no `args`. +## Cholesky + +See also +[`XlaBuilder::Cholesky`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +Computes the +[Cholesky decomposition](https://en.wikipedia.org/wiki/Cholesky_decomposition) +of a batch of symmetric (Hermitian) positive definite matrices. + + `Cholesky(a, lower)` + +Arguments | Type | Semantics +--------- | ------- | ----------------------------------------------------- +`a` | `XlaOp` | a rank > 2 array of a complex or floating-point type. +`lower` | `bool` | whether to use the upper or lower triangle of `a`. + +If `lower` is `true`, computes lower-triangular matrices `l` such that $$ a = l +. l^T $$. If `lower` is `false`, computes upper-triangular matrices `u` such +that $$ a = u^T . u $$. + +Input data is read only from the lower/upper triangle of `a`, depending on the +value of `lower`. Values from the other triangle are ignored. Output data is +returned in the same triangle; the values in the other triangle are +implementation-defined and may be anything. + +If the rank of `a` is greater than 2, `a` is treated as a batch of matrices, +where all except the minor 2 dimensions are batch dimensions. + +If `a` is not symmetric (Hermitian) positive definite, the result is +implementation-defined. + ## Clamp See also @@ -2439,6 +2470,46 @@ Permutes the operand dimensions with the given permutation, so This is the same as Reshape(operand, permutation, Permute(permutation, operand.shape.dimensions)). +## TriangularSolve + +See also +[`XlaBuilder::TriangularSolve`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h). + +Solves systems of linear equations with lower or upper triangular coefficient +matrices by forward- or back-substitution. Broadcasting along leading +dimensions, this routine solves one of the matrix systems `op(a) * x = +b`, or `x * op(a) = b`, for the variable `x`, given `a` and `b`, where `op(a)` is +either `op(a) = a`, or `op(a) = Transpose(a)`, or `op(a) = Conj(Transpose(a))`. + + `TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose_a)` + +| Arguments | Type | Semantics | +| --------------- | ----------- | -------------------------------------------- | +| `a` | `XlaOp` | a rank > 2 array of a complex or | +: : : floating-point type with shape `[..., M, : +: : : M]`. : +| `b` | `XlaOp` | a rank > 2 array of the same type with shape | +: : : `[..., M, K]` if `left_side` is true, `[..., : +: : : K, M]` otherwise. : +| `left_side` | `bool` | indicates whether to solve a system of the | +: : : form `op(a) * x = b` (`true`) or `x * : +: : : op(a) = b` (`false`). : +| `lower` | `bool` | whether to use the upper or lower triangle | +: : : of `a`. : +| `unit_diagonal` | `bool` | if `true`, the diagonal elements of `a` are | +: : : assumed to be `1` and not accessed. : +| `transpose_a` | `Transpose` | whether to use `a` as is, transpose it or | +: : : take its conjugate transpose. : + +Input data is read only from the lower/upper triangle of `a`, depending on the +value of `lower`. Values from the other triangle are ignored. Output data is +returned in the same triangle; the values in the other triangle are +implementation-defined and may be anything. + +If the rank of `a` and `b` are greater than 2, they are treated as batches of +matrices, where all except the minor 2 dimensions are batch dimensions. `a` and +`b` must have equal batch dimensions. + ## Tuple See also diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD index 92ace7877ed..68128a3097f 100644 --- a/tensorflow/compiler/xla/python/BUILD +++ b/tensorflow/compiler/xla/python/BUILD @@ -69,7 +69,6 @@ cc_library( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client:xla_computation", - "//tensorflow/compiler/xla/client/lib:cholesky", "//tensorflow/compiler/xla/client/lib:math", "//tensorflow/compiler/xla/client/lib:qr", "//tensorflow/compiler/xla/service:computation_placer", diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 8f7b105eefe..5cfbb2c20df 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -21,7 +21,6 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/lib/cholesky.h" #include "tensorflow/compiler/xla/client/lib/math.h" #include "tensorflow/compiler/xla/client/lib/qr.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -714,8 +713,8 @@ LocalOp ComputationBuilder::SortKeyVal(const LocalOp& keys, return xla::Sort(keys.op(), {values.op()}, dimension); } -LocalOp ComputationBuilder::Cholesky(const LocalOp& a) { - return xla::Cholesky(a.op()); +LocalOp ComputationBuilder::Cholesky(const LocalOp& a, bool lower) { + return xla::Cholesky(a.op(), lower); } LocalOp ComputationBuilder::QR(const LocalOp& a, bool full_matrices) { diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h index 41ae4f73708..fa878501aba 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.h +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -358,7 +358,7 @@ class ComputationBuilder { LocalOp QR(const LocalOp& a, bool full_matrices); - LocalOp Cholesky(const LocalOp& a); + LocalOp Cholesky(const LocalOp& a, bool lower); // `transpose_a` is the integer value of a TriangularSolveOptions::Transpose // enum. We use an integer here so we don't have to teach SWIG about the diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py index ccea1348944..38458d7f090 100644 --- a/tensorflow/compiler/xla/python/xla_client.py +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -1735,9 +1735,9 @@ class ComputationBuilder(object): """Enqueues a key-value sort operation onto the computation.""" return self._client.SortKeyVal(keys, values, dimension) - def Cholesky(self, a): + def Cholesky(self, a, lower=True): """Enqueues a Cholesky decomposition onto the computation.""" - return self._client.Cholesky(a) + return self._client.Cholesky(a, lower) def QR(self, a, full_matrices=True): """Enqueues a QR decomposition onto the computation.""" diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 4a668d0c28a..3b5a91c0463 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -1581,6 +1581,29 @@ cc_library( ], ) +cc_library( + name = "cholesky_expander", + srcs = ["cholesky_expander.cc"], + hdrs = ["cholesky_expander.h"], + deps = [ + ":op_expander_pass", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/client/lib:loops", + "//tensorflow/compiler/xla/client/lib:math", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/client/lib:slicing", + "//tensorflow/core:lib", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + tf_cc_test( name = "batchnorm_expander_test", size = "small", diff --git a/tensorflow/compiler/xla/client/lib/cholesky.cc b/tensorflow/compiler/xla/service/cholesky_expander.cc similarity index 76% rename from tensorflow/compiler/xla/client/lib/cholesky.cc rename to tensorflow/compiler/xla/service/cholesky_expander.cc index bb41f9932d1..1c39cf9bc0a 100644 --- a/tensorflow/compiler/xla/client/lib/cholesky.cc +++ b/tensorflow/compiler/xla/service/cholesky_expander.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/lib/cholesky.h" +#include "tensorflow/compiler/xla/service/cholesky_expander.h" #include #include @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" namespace xla { @@ -134,10 +135,8 @@ XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) { }); } -} // namespace - -XlaOp Cholesky(XlaOp a, int64 block_size, - PrecisionConfig::Precision precision) { +XlaOp BuildCholesky(XlaOp a, int64 block_size, + PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder(); return builder->ReportErrorOrReturn([&]() -> StatusOr { TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a)); @@ -206,4 +205,55 @@ XlaOp Cholesky(XlaOp a, int64 block_size, }); } +} // namespace + +bool CholeskyExpander::InstructionMatchesPattern(HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kCholesky; +} + +StatusOr CholeskyExpander::ExpandInstruction( + HloInstruction* instruction) { + const CholeskyOptions& options = instruction->cholesky_options(); + const string name = absl::StrFormat( + "xla.cholesky_%s_%s", instruction->operand(0)->shape().ToString(), + options.lower() ? "lower" : "upper"); + + HloModule* module = instruction->parent()->parent(); + + HloComputation*& computation = + computation_cache_.emplace(name, nullptr).first->second; + if (!computation) { + // Builds a new expansion. + // + // TODO(b/62327888): We do something unusual here: we build the computation + // using the XlaBuilder API, which is nominally an XLA client API. We do + // this because the external APIs for building complicated computations + // (XlaBuilder) are much more ergonomic than the internal ones. As it turns + // out, XlaBuilder isn't really a client API—what it does is build a + // HloModuleProto protocol buffer, that we can then deserialize and clone + // into our HloModule. Ideally we would avoid the protocol buffer step; + // that is left as an exercise for future work. + XlaBuilder builder(name); + XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a"); + XlaOp l = BuildCholesky(MaybeTransposeInMinorDims(a, !options.lower()), + /*block_size=*/128, + /*precision=*/PrecisionConfig::HIGHEST); + MaybeTransposeInMinorDims(l, !options.lower()); + + TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build()); + + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + xla_computation.GetProgramShape()); + HloModuleConfig config(program_shape); + TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto( + xla_computation.proto(), config)); + HloCloneContext context(module); + computation = + module->DeepCloneComputation(new_module->entry_computation(), &context); + } + + return instruction->parent()->AddInstruction(HloInstruction::CreateCall( + instruction->shape(), instruction->operands(), computation)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/cholesky_expander.h b/tensorflow/compiler/xla/service/cholesky_expander.h new file mode 100644 index 00000000000..d2958db1b8c --- /dev/null +++ b/tensorflow/compiler/xla/service/cholesky_expander.h @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CHOLESKY_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CHOLESKY_EXPANDER_H_ + +#include "absl/container/flat_hash_map.h" +#include "tensorflow/compiler/xla/service/op_expander_pass.h" + +namespace xla { + +class CholeskyExpander : public OpExpanderPass { + public: + absl::string_view name() const override { return "cholesky_expander"; } + + protected: + bool InstructionMatchesPattern(HloInstruction* instruction) override; + + StatusOr ExpandInstruction( + HloInstruction* instruction) override; + + private: + // Mapping from op signatures to existing computations. + absl::flat_hash_map computation_cache_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CHOLESKY_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 42672bc3875..c8fef147b85 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -111,6 +111,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", + "//tensorflow/compiler/xla/service:cholesky_expander", "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:convolution_group_converter", "//tensorflow/compiler/xla/service:dot_decomposer", diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 19ab3bddb56..eb5d843fe8b 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -50,6 +50,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/cholesky_expander.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/convolution_group_converter.h" #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h" @@ -257,6 +258,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn( pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 2f7fddb96da..246f2af09b5 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -112,6 +112,7 @@ class DfsHloVisitorBase { virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; virtual Status HandleFft(HloInstructionPtr fft) = 0; virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0; + virtual Status HandleCholesky(HloInstructionPtr hlo) = 0; virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0; virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0; virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 341bb37b835..79ce3f82e8c 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -94,6 +94,9 @@ class DfsHloVisitorWithDefaultBase Status HandleTriangularSolve(HloInstructionPtr hlo) override { return DefaultAction(hlo); } + Status HandleCholesky(HloInstructionPtr hlo) override { + return DefaultAction(hlo); + } Status HandleAllReduce(HloInstructionPtr crs) override { return DefaultAction(crs); } diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 5b5ad63ec94..90bb317e3aa 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -779,6 +779,7 @@ cc_library( "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", + "//tensorflow/compiler/xla/service:cholesky_expander", "//tensorflow/compiler/xla/service:conditional_simplifier", "//tensorflow/compiler/xla/service:convolution_group_converter", "//tensorflow/compiler/xla/service:dot_decomposer", diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index 27f2cb3bdf4..03ddb8266b8 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -35,6 +35,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/cholesky_expander.h" #include "tensorflow/compiler/xla/service/conditional_simplifier.h" #include "tensorflow/compiler/xla/service/convolution_group_converter.h" #include "tensorflow/compiler/xla/service/dot_decomposer.h" @@ -187,6 +188,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, &pipeline, hlo_module->config().debug_options(), ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); + pipeline.AddPass(); + // TODO(b/64094172): make Call work on GPU instead of inlining. pipeline.AddPass(); auto cost_model = [](HloInstruction* conv) { diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index ae9e3169fd9..1413ce3062d 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -29,12 +29,13 @@ limitations under the License. syntax = "proto3"; package xla; + import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 62 +// Next ID: 63 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -200,6 +201,9 @@ message HloInstructionProto { // Options for TriangularSolve xla.TriangularSolveOptions triangular_solve_options = 59; + // Options for Cholesky + xla.CholeskyOptions cholesky_options = 62; + // Describes how parameters behave with regards to replicas. xla.ParameterReplication parameter_replication = 61; } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 6d9e01e3a77..3a95629e93e 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -557,11 +557,22 @@ Status HloCostAnalysis::HandleTriangularSolve(const HloInstruction* hlo) { // Estimate as batch * mn^2 / 2 flops. int64 elems = a_shape.dimensions(a_shape.dimensions_size() - 1); elems *= ShapeUtil::ElementsIn(b_shape); - // Each output elment requires reduction_widht FMA operations. current_properties_[kFlopsKey] = kFmaFlops * elems; return Status::OK(); } +Status HloCostAnalysis::HandleCholesky(const HloInstruction* hlo) { + float bytes_accessed = GetShapeSize(hlo->operand(0)->shape()) / 2.0f; + current_properties_[kBytesAccessedKey] = bytes_accessed; + + const Shape& a_shape = hlo->operand(0)->shape(); + // Estimate as batch * n^3 / 3 flops. + int64 elems = a_shape.dimensions(a_shape.dimensions_size() - 1); + elems *= ShapeUtil::ElementsIn(a_shape); + current_properties_[kFlopsKey] = elems / 3; + return Status::OK(); +} + Status HloCostAnalysis::HandleAllReduce(const HloInstruction* crs) { // We assume 2 replicas, so that each output element is the sum of two input // elements. diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index 96357dec68e..4480554de50 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -72,6 +72,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleConvolution(const HloInstruction* convolution) override; Status HandleFft(const HloInstruction* fft) override; Status HandleTriangularSolve(const HloInstruction* hlo) override; + Status HandleCholesky(const HloInstruction* hlo) override; Status HandleAllReduce(const HloInstruction* crs) override; Status HandleAllToAll(const HloInstruction* hlo) override; Status HandleCollectivePermute(const HloInstruction* hlo) override; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 1bb4195b7e6..9623edcf5eb 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -1017,6 +1017,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kDot: case HloOpcode::kFft: case HloOpcode::kTriangularSolve: + case HloOpcode::kCholesky: return kDarkBlue; case HloOpcode::kReducePrecision: return kRed; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 6c47bb8935a..b41dc79d8de 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -132,6 +132,11 @@ StatusOr> HloInstruction::CreateFromProto( proto.triangular_solve_options()); break; } + case HloOpcode::kCholesky: { + instruction = + CreateCholesky(shape, operands(0), proto.cholesky_options()); + break; + } case HloOpcode::kSend: instruction = CreateSend(operands(0), operands(1), proto.channel_id(), proto.is_host_transfer()); @@ -734,6 +739,11 @@ HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a, return absl::make_unique(shape, a, b, options); } +/* static */ std::unique_ptr HloInstruction::CreateCholesky( + const Shape& shape, HloInstruction* a, const CholeskyOptions& options) { + return absl::make_unique(shape, a, options); +} + /* static */ std::unique_ptr HloInstruction::CreateDot( const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, const DotDimensionNumbers& dimension_numbers, @@ -1294,6 +1304,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kDomain: case HloOpcode::kGetDimensionSize: case HloOpcode::kTriangularSolve: + case HloOpcode::kCholesky: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. @@ -1754,6 +1765,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kDomain: case HloOpcode::kGetDimensionSize: case HloOpcode::kTriangularSolve: + case HloOpcode::kCholesky: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } @@ -2553,6 +2565,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleGetDimensionSize(this); case HloOpcode::kTriangularSolve: return visitor->HandleTriangularSolve(this); + case HloOpcode::kCholesky: + return visitor->HandleCholesky(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -3437,4 +3451,9 @@ const DomainMetadata& HloInstruction::user_side_metadata() const { const TriangularSolveOptions& HloInstruction::triangular_solve_options() const { return Cast(this)->triangular_solve_options(); } + +const CholeskyOptions& HloInstruction::cholesky_options() const { + return Cast(this)->cholesky_options(); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 33cbb9a41ba..f868a9e9018 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -448,6 +448,9 @@ class HloInstruction { const Shape& shape, HloInstruction* a, HloInstruction* b, const TriangularSolveOptions& options); + static std::unique_ptr CreateCholesky( + const Shape& shape, HloInstruction* a, const CholeskyOptions& options); + // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch // dimensions specified in 'dimension_numbers'. static std::unique_ptr CreateDot( @@ -1594,6 +1597,9 @@ class HloInstruction { // Delegates to HloTriangularSolveInstruction::triangular_solve_options(). const TriangularSolveOptions& triangular_solve_options() const; + // Delegates to HloCholeskyInstruction::cholesky_options(). + const CholeskyOptions& cholesky_options() const; + // Old methods kept for smooth subclassing transition END. protected: diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index f8eef78531d..7d18b35c2bb 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -202,21 +202,6 @@ std::unique_ptr HloFftInstruction::CloneWithNewOperandsImpl( fft_length_); } -HloTriangularSolveInstruction::HloTriangularSolveInstruction( - const Shape& shape, HloInstruction* a, HloInstruction* b, - const TriangularSolveOptions& options) - : HloInstruction(HloOpcode::kTriangularSolve, shape), - triangular_solve_options_(options) { - AppendOperand(a); - AppendOperand(b); -} - -HloInstructionProto HloTriangularSolveInstruction::ToProto() const { - HloInstructionProto proto = HloInstruction::ToProto(); - *proto.mutable_triangular_solve_options() = triangular_solve_options_; - return proto; -} - namespace { // Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector @@ -257,6 +242,21 @@ std::vector AttributeProtoToStringVector( } // namespace +HloTriangularSolveInstruction::HloTriangularSolveInstruction( + const Shape& shape, HloInstruction* a, HloInstruction* b, + const TriangularSolveOptions& options) + : HloInstruction(HloOpcode::kTriangularSolve, shape), + triangular_solve_options_(options) { + AppendOperand(a); + AppendOperand(b); +} + +HloInstructionProto HloTriangularSolveInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_triangular_solve_options() = triangular_solve_options_; + return proto; +} + std::vector HloTriangularSolveInstruction::ExtraAttributesToStringImpl( const HloPrintOptions& options) const { return AttributeProtoToStringVector(triangular_solve_options_); @@ -286,6 +286,44 @@ HloTriangularSolveInstruction::CloneWithNewOperandsImpl( shape, new_operands[0], new_operands[1], triangular_solve_options()); } +HloCholeskyInstruction::HloCholeskyInstruction(const Shape& shape, + HloInstruction* a, + const CholeskyOptions& options) + : HloInstruction(HloOpcode::kCholesky, shape), cholesky_options_(options) { + AppendOperand(a); +} + +HloInstructionProto HloCholeskyInstruction::ToProto() const { + HloInstructionProto proto = HloInstruction::ToProto(); + *proto.mutable_cholesky_options() = cholesky_options_; + return proto; +} + +std::vector HloCholeskyInstruction::ExtraAttributesToStringImpl( + const HloPrintOptions& options) const { + return AttributeProtoToStringVector(cholesky_options_); +} + +bool HloCholeskyInstruction::IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const { + const auto& casted_other = static_cast(other); + const auto& options = cholesky_options(); + const auto& other_options = casted_other.cholesky_options(); + + return options.lower() == other_options.lower(); +} + +std::unique_ptr +HloCholeskyInstruction::CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const { + CHECK_EQ(new_operands.size(), 1); + return absl::make_unique(shape, new_operands[0], + cholesky_options()); +} + HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode, const Shape& shape, int64 channel_id, diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 4d23cb671f2..43aa12c10f2 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -159,6 +159,31 @@ class HloTriangularSolveInstruction : public HloInstruction { TriangularSolveOptions triangular_solve_options_; }; +class HloCholeskyInstruction : public HloInstruction { + public: + explicit HloCholeskyInstruction(const Shape& shape, HloInstruction* a, + const CholeskyOptions& options); + const CholeskyOptions& cholesky_options() const { return cholesky_options_; } + + // Returns a serialized representation of this instruction. + HloInstructionProto ToProto() const override; + + private: + std::vector ExtraAttributesToStringImpl( + const HloPrintOptions& options) const override; + bool IdenticalSlowPath( + const HloInstruction& other, + const std::function& + eq_computations) const override; + + // Implementation for non-common logic of CloneWithNewOperands. + std::unique_ptr CloneWithNewOperandsImpl( + const Shape& shape, absl::Span new_operands, + HloCloneContext* context) const override; + + CholeskyOptions cholesky_options_; +}; + class HloSendRecvInstruction : public HloInstruction { public: // Returns the channel id associated with the instruction. The id is diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index c571664c812..973f62731cb 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -61,6 +61,7 @@ namespace xla { V(kBroadcast, "broadcast", 1) \ V(kCall, "call", kHloOpcodeIsVariadic) \ V(kCeil, "ceil", 1) \ + V(kCholesky, "cholesky", 1) \ V(kClamp, "clamp", 3) \ V(kCollectivePermute, "collective-permute", 1) \ V(kClz, "count-leading-zeros", 1) \ diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 601748e54b6..52007d72cc8 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -1129,6 +1129,17 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder, shape, operands[0], operands[1], options)); break; } + case HloOpcode::kCholesky: { + CholeskyOptions options; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributesAsProtoMessage( + /*required_attrs=*/std::unordered_set(), &options)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateCholesky(shape, operands[0], options)); + break; + } case HloOpcode::kBroadcast: { optional> broadcast_dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 56a06a182a2..cec07d5c09f 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -177,6 +177,13 @@ Status ShapeVerifier::HandleTriangularSolve(HloInstruction* hlo) { return CheckShape(hlo, expected); } +Status ShapeVerifier::HandleCholesky(HloInstruction* hlo) { + TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1)); + TF_ASSIGN_OR_RETURN(const Shape expected, ShapeInference::InferCholeskyShape( + hlo->operand(0)->shape())); + return CheckShape(hlo, expected); +} + Status ShapeVerifier::HandleAllReduce(HloInstruction* crs) { std::vector operand_shapes; for (const HloInstruction* operand : crs->operands()) { diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index a9b5e9a3e6e..d427a1586c3 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -52,6 +52,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; Status HandleFft(HloInstruction* fft) override; + Status HandleCholesky(HloInstruction* hlo) override; Status HandleTriangularSolve(HloInstruction* hlo) override; Status HandleAllReduce(HloInstruction* crs) override; Status HandleAllToAll(HloInstruction* hlo) override; diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index f5770eee225..e4a78af7c72 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -126,6 +126,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) { case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormTraining: case HloOpcode::kCall: + case HloOpcode::kCholesky: case HloOpcode::kConditional: case HloOpcode::kConvolution: case HloOpcode::kAllReduce: diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 8cd93626899..599489b3785 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -32,6 +32,7 @@ cc_library( "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/service:algebraic_simplifier", + "//tensorflow/compiler/xla/service:cholesky_expander", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_placer", "//tensorflow/compiler/xla/service:dynamic_index_splitter", diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index 792773c6769..a8f8ab4f725 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/memory/memory.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" +#include "tensorflow/compiler/xla/service/cholesky_expander.h" #include "tensorflow/compiler/xla/service/computation_placer.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h" @@ -80,6 +81,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass( hlo_module->mutable_entry_computation_layout(), diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index aa791ea195e..984ce892519 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -2073,6 +2073,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kSubtract: case HloOpcode::kTanh: case HloOpcode::kTriangularSolve: + case HloOpcode::kCholesky: case HloOpcode::kTupleSelect: case HloOpcode::kWhile: return false; diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 431c2e3a5e0..cd65071b8ba 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -1911,6 +1911,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, /* static */ StatusOr ShapeInference::InferTriangularSolveShape( const Shape& a, const Shape& b, const TriangularSolveOptions& options) { + if ((!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) || + a.element_type() != b.element_type()) { + return InvalidArgument( + "Expected element types in shape to be floating or complex and " + "identical for TriangularSolve; got %s and %s.", + PrimitiveType_Name(a.element_type()), + PrimitiveType_Name(b.element_type())); + } if (a.rank() < 2) { return InvalidArgument( "The 'a' argument to TriangularSolve must have rank >= 2, got shape %s", @@ -1952,6 +1960,27 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation, return b; } +/* static */ StatusOr ShapeInference::InferCholeskyShape( + const Shape& a) { + if (!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) { + return InvalidArgument( + "Expected element type in shape to be floating or complex for " + "Cholesky; got %s.", + PrimitiveType_Name(a.element_type())); + } + if (a.rank() < 2) { + return InvalidArgument( + "The 'a' argument to Cholesky must have rank >= 2, got shape %s", + a.ToString()); + } + if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) { + return InvalidArgument( + "The two minor dimensions of 'a' must have equal size, got %s.", + a.ToString()); + } + return a; +} + /* static */ StatusOr ShapeInference::InferAllReduceShape( absl::Span operand_shapes) { for (const Shape* operand_shape : operand_shapes) { diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index acb071ab188..f2fec695060 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -120,6 +120,9 @@ class ShapeInference { static StatusOr InferTriangularSolveShape( const Shape& a, const Shape& b, const TriangularSolveOptions& options); + // Infers the shape produced by the given triangular solve operation. + static StatusOr InferCholeskyShape(const Shape& a); + // Infers the shape produced by a cross replica sum with the given operand // shapes. static StatusOr InferAllReduceShape( diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index 79a5b7539db..bce1210b389 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -2217,3 +2217,23 @@ xla_test( "//tensorflow/core:test", ], ) + +xla_test( + name = "cholesky_test", + srcs = ["cholesky_test.cc"], + tags = ["optonly"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:matrix", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) diff --git a/tensorflow/compiler/xla/client/lib/cholesky_test.cc b/tensorflow/compiler/xla/tests/cholesky_test.cc similarity index 54% rename from tensorflow/compiler/xla/client/lib/cholesky_test.cc rename to tensorflow/compiler/xla/tests/cholesky_test.cc index 095dd4fbf8b..272d5784362 100644 --- a/tensorflow/compiler/xla/client/lib/cholesky_test.cc +++ b/tensorflow/compiler/xla/tests/cholesky_test.cc @@ -13,8 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/client/lib/cholesky.h" - +#include #include #include #include @@ -32,27 +31,27 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" +namespace xla { namespace { -using xla::int64; +using CholeskyTest = ClientLibraryTestBase; -using CholeskyTest = xla::ClientLibraryTestBase; +XLA_TEST_F(CholeskyTest, Lower) { + XlaBuilder builder(TestName()); -XLA_TEST_F(CholeskyTest, Simple) { - xla::XlaBuilder builder(TestName()); - - xla::Array2D a_vals({ - {4, 6, 8, 10}, - {6, 45, 54, 63}, - {8, 54, 146, 166}, + float nan = std::numeric_limits::quiet_NaN(); + Array2D a_vals({ + {4, nan, nan, nan}, + {6, 45, nan, nan}, + {8, 54, 146, nan}, {10, 63, 166, 310}, }); - xla::XlaOp a; + XlaOp a; auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); - xla::Cholesky(a, /*block_size=*/2); + LowerTriangle(Cholesky(a, /*lower=*/true)); - xla::Array2D expected({ + Array2D expected({ {2, 0, 0, 0}, {3, 6, 0, 0}, {4, 7, 9, 0}, @@ -60,34 +59,62 @@ XLA_TEST_F(CholeskyTest, Simple) { }); ComputeAndCompareR2(&builder, expected, {a_data.get()}, - xla::ErrorSpec(1e-4, 1e-4)); + ErrorSpec(1e-4, 1e-4)); +} + +XLA_TEST_F(CholeskyTest, Upper) { + XlaBuilder builder(TestName()); + + float nan = std::numeric_limits::quiet_NaN(); + Array2D a_vals({ + {4, 6, 8, 10}, + {nan, 45, 54, 63}, + {nan, nan, 146, 166}, + {nan, nan, nan, 310}, + }); + + XlaOp a; + auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); + UpperTriangle(Cholesky(a, /*lower=*/false)); + + Array2D expected({ + {2, 3, 4, 5}, + {0, 6, 7, 8}, + {0, 0, 9, 10}, + {0, 0, 0, 11}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get()}, + ErrorSpec(1e-4, 1e-4)); } XLA_TEST_F(CholeskyTest, Simple2) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::Array2D a_vals({ + Array2D a_vals({ {16, 24, 8, 12}, {24, 61, 82, 48}, {8, 82, 456, 106}, {12, 48, 106, 62}, }); - xla::XlaOp a; + XlaOp a; auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); - xla::Cholesky(a); + LowerTriangle(Cholesky(a, /*lower=*/true)); - xla::Array2D expected( - {{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}}); + Array2D expected({{4, 0, 0, 0}, // + {6, 5, 0, 0}, // + {2, 14, 16, 0}, // + {3, 6, 1, 4}}); ComputeAndCompareR2(&builder, expected, {a_data.get()}, - xla::ErrorSpec(1e-4, 1e-4)); + ErrorSpec(1e-4, 1e-4)); } XLA_TEST_F(CholeskyTest, SimpleBatched) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); - xla::Array3D a_vals({ + Array3D a_vals({ { {4, 6, 8, 10}, {6, 45, 54, 63}, @@ -102,65 +129,78 @@ XLA_TEST_F(CholeskyTest, SimpleBatched) { }, }); - xla::XlaOp a; + XlaOp a; auto a_data = CreateR3Parameter(a_vals, 0, "a", &builder, &a); - xla::Cholesky(a); + LowerTriangle(Cholesky(a, /*lower=*/true)); - xla::Array3D expected({ + Array3D expected({ { {2, 0, 0, 0}, {3, 6, 0, 0}, {4, 7, 9, 0}, {5, 8, 10, 11}, }, - {{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}}, + {{4, 0, 0, 0}, // + {6, 5, 0, 0}, // + {2, 14, 16, 0}, // + {3, 6, 1, 4}}, }); ComputeAndCompareR3(&builder, expected, {a_data.get()}, - xla::ErrorSpec(1e-4, 1e-4)); + ErrorSpec(1e-4, 1e-4)); } -using CholeskyTestCase = std::tuple; +using CholeskyTestCase = std::tuple; class RandomCholeskyTest - : public xla::ClientLibraryTestBase, + : public ClientLibraryTestBase, public ::testing::WithParamInterface {}; XLA_TEST_P(RandomCholeskyTest, Random) { - xla::XlaBuilder builder(TestName()); + XlaBuilder builder(TestName()); auto test_params = GetParam(); std::vector dimensions = {std::get<0>(test_params), std::get<1>(test_params), std::get<1>(test_params)}; - xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, dimensions); + bool lower = std::get<2>(test_params); + Shape shape = ShapeUtil::MakeShape(F32, dimensions); TF_ASSERT_OK_AND_ASSIGN( - auto literal, - xla::LiteralUtil::CreateRandomLiteral(shape, 0.0, 1.0)); + auto literal, LiteralUtil::CreateRandomLiteral(shape, 0.0, 1.0)); - auto input = xla::Parameter(&builder, 0, shape, "input"); + auto input = Parameter(&builder, 0, shape, "input"); // Form a random positive definite matrix. - auto matrix = xla::BatchDot(input, TransposeInMinorDims(input), - xla::PrecisionConfig::HIGHEST); + auto matrix = + BatchDot(input, TransposeInMinorDims(input), PrecisionConfig::HIGHEST); - auto cholesky = xla::Cholesky(matrix, /*block_size=*/4); + auto cholesky = Triangle(Cholesky(matrix, lower), lower); // Verify that ||matrix - cholesky * cholesky_t||_2 ~= 0 - auto verification = xla::BatchDot(cholesky, TransposeInMinorDims(cholesky), - xla::PrecisionConfig::HIGHEST); + XlaOp verification; + if (lower) { + verification = BatchDot(cholesky, TransposeInMinorDims(cholesky), + PrecisionConfig::HIGHEST); + } else { + verification = BatchDot(TransposeInMinorDims(cholesky), cholesky, + PrecisionConfig::HIGHEST); + } auto delta = matrix - verification; - xla::Reduce(delta * delta, xla::ConstantR0(&builder, 0.0), - CreateScalarAddComputation(xla::F32, &builder), {0, 1, 2}); + Reduce(delta * delta, ConstantR0(&builder, 0.0), + CreateScalarAddComputation(F32, &builder), {0, 1, 2}); TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal)); ComputeAndCompareR0(&builder, 0.0, {input_data.get()}, - xla::ErrorSpec(1e-4, 1e-4)); + ErrorSpec(1e-4, 1e-4)); } INSTANTIATE_TEST_SUITE_P(RandomCholeskyTestInstance, RandomCholeskyTest, - ::testing::Values(CholeskyTestCase{1, 1}, - CholeskyTestCase{1, 2}, - CholeskyTestCase{10, 5}, - CholeskyTestCase{2, 20})); + ::testing::Values(CholeskyTestCase{1, 1, true}, + CholeskyTestCase{1, 2, true}, + CholeskyTestCase{1, 50, true}, + CholeskyTestCase{1, 50, false}, + CholeskyTestCase{10, 5, true}, + CholeskyTestCase{5, 10, false}, + CholeskyTestCase{2, 20, true})); } // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 93611740864..6e5772a7396 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -566,6 +566,12 @@ message TriangularSolveOptions { Transpose transpose_a = 4; } +message CholeskyOptions { + // If true, uses the lower triangle of `a`. If false, uses the upper triangle + // of `a`. + bool lower = 1; +} + message OpSharding { enum Type { // This sharding is replicated across all devices (implies maximal,