[XLA] Make Cholesky into a first-class HLO operator.

Currently it is expanded into an HLO implementation on all backends.

PiperOrigin-RevId: 235814360
Peter Hawkins 2019-02-26 16:15:49 -08:00 committed by TensorFlower Gardener
parent 2ee3000734
commit cf0f741491
41 changed files with 535 additions and 162 deletions
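With this change, clients emit a single first-class kCholesky HLO instruction through the builder API instead of expanding the decomposition in the client library; backends that lack a native implementation lower it with the new CholeskyExpander pass. A minimal sketch of the new client-side usage, under the signatures shown in this diff (the wrapper function itself is illustrative, not part of the commit):

#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Sketch: enqueue a first-class Cholesky onto a computation.
xla::XlaOp EmitLowerCholesky(xla::XlaOp a) {
  // One HLO instruction; the block-size and precision knobs of the old
  // client-library routine now live inside the expander pass instead.
  xla::XlaOp l = xla::Cholesky(a, /*lower=*/true);
  // The opposite triangle of the result is implementation-defined, so mask
  // it off when an exactly triangular factor is needed.
  return xla::LowerTriangle(l);
}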

View File

@@ -136,7 +136,6 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:cholesky",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:loops",
"//tensorflow/compiler/xla/client/lib:math",

View File

@@ -15,7 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/cholesky.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
namespace tensorflow {
namespace {
@@ -24,7 +25,9 @@ class CholeskyOp : public XlaOpKernel {
public:
explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
ctx->SetOutput(0, xla::Cholesky(ctx->Input(0)));
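// The HLO Cholesky op leaves the opposite triangle of its result
// implementation-defined, so mask it with Triangle to return an exactly
// lower-triangular factor.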
ctx->SetOutput(0,
xla::Triangle(xla::Cholesky(ctx->Input(0), /*lower=*/true),
/*lower=*/true));
}
};

View File

@@ -49,47 +49,6 @@ xla_test(
],
)
cc_library(
name = "cholesky",
srcs = ["cholesky.cc"],
hdrs = ["cholesky.h"],
deps = [
":math",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:loops",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:slicing",
"//tensorflow/core:lib",
],
)
xla_test(
name = "cholesky_test",
srcs = ["cholesky_test.cc"],
tags = ["optonly"],
deps = [
":arithmetic",
":cholesky",
":matrix",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
cc_library(
name = "comparators",
srcs = ["comparators.cc"],

View File

@@ -1,39 +0,0 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
// Computes the Cholesky decompositions of a batch of symmetric positive
// definite matrices.
// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the
// two minor dimensions equal.
// The algorithm implements a blocked Cholesky decomposition; `block_size` is
// the block size to use.
// TODO(phawkins): check for negative values on the diagonal and return an
// error, instead of silently yielding NaNs.
// TODO(znado): handle the complex Hermitian case
xla::XlaOp Cholesky(
xla::XlaOp a, int64 block_size = 256,
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST);
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_

View File

@@ -3022,6 +3022,21 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
});
}
XlaOp Cholesky(XlaOp a, bool lower) {
XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& a_shape, builder->GetShape(a));
xla::CholeskyOptions& options = *instr.mutable_cholesky_options();
options.set_lower(lower);
TF_ASSIGN_OR_RETURN(Shape shape,
ShapeInference::InferCholeskyShape(a_shape));
*instr.mutable_shape() = shape.ToProto();
return builder->AddInstruction(std::move(instr), HloOpcode::kCholesky, {a});
});
}
XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config) {
return builder->Infeed(shape, config);
}

View File

@@ -801,6 +801,7 @@ class XlaBuilder {
friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
bool unit_diagonal,
TriangularSolveOptions::Transpose transpose_a);
friend XlaOp Cholesky(XlaOp a, bool lower);
friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
const string& config);
friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
@@ -1342,8 +1343,7 @@ XlaOp Fft(const XlaOp& operand, FftType fft_type,
// * `left_side` is a boolean, indicating whether to solve a system of the form
// op(a) * x = b (true) or x * op(a) = b (false).
// * `lower` is a boolean, indicating whether the argument `a` is
// lower-triangular
// (true) or upper-triangular (false).
// lower-triangular (true) or upper-triangular (false).
// * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be
// 1 and not accessed.
// * `transpose_a` indicates which function `op` we use to transform the tensor
@@ -1352,6 +1352,20 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
bool unit_diagonal,
TriangularSolveOptions::Transpose transpose_a);
// Computes the Cholesky decompositions of a batch of symmetric (Hermitian)
// positive definite matrices.
// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the
// two minor dimensions equal.
// If `lower` is true, the data from the lower triangle is used; if false, the
// upper triangle is used. The input data in the other triangle of the input
does not affect the output. Returns the output in the same lower/upper
// triangle. The data returned in the other output triangle is arbitrary and
// implementation-defined.
//
// The value returned if `a` is not Hermitian positive definite is
// implementation-defined.
XlaOp Cholesky(XlaOp a, bool lower);
// Enqueues an infeed instruction onto the computation, which writes data of
// the given shape to the infeed buffer of the device.
XlaOp Infeed(XlaBuilder* builder, const Shape& shape,

View File

@@ -322,6 +322,37 @@ Invokes a computation with the given arguments.
The arity and types of the `args` must match the parameters of the
`computation`. It is allowed to have no `args`.
## Cholesky
See also
[`XlaBuilder::Cholesky`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Computes the
[Cholesky decomposition](https://en.wikipedia.org/wiki/Cholesky_decomposition)
of a batch of symmetric (Hermitian) positive definite matrices.
<b> `Cholesky(a, lower)` </b>
Arguments | Type | Semantics
--------- | ------- | -----------------------------------------------------
`a` | `XlaOp` | a rank >= 2 array of a complex or floating-point type.
`lower` | `bool` | whether to use the upper or lower triangle of `a`.
If `lower` is `true`, computes lower-triangular matrices `l` such that $$ a = l
. l^T $$. If `lower` is `false`, computes upper-triangular matrices `u` such
that $$ a = u^T . u $$.
Input data is read only from the lower/upper triangle of `a`, depending on the
value of `lower`. Values from the other triangle are ignored. Output data is
returned in the same triangle; the values in the other triangle are
implementation-defined and may be anything.
If the rank of `a` is greater than 2, `a` is treated as a batch of matrices,
where all except the two minor dimensions are batch dimensions.
If `a` is not symmetric (Hermitian) positive definite, the result is
implementation-defined.
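As a small worked instance (illustrative, not part of the original page): with

$$ a = \begin{pmatrix} 4 & 2 \\ 2 & 5 \end{pmatrix}, \qquad l = \begin{pmatrix} 2 & 0 \\ 1 & 2 \end{pmatrix}, $$

one can verify that $$ a = l . l^T $$, so `Cholesky(a, lower=true)` returns `l` in the lower triangle of its output.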
## Clamp
See also
@@ -2439,6 +2470,46 @@ Permutes the operand dimensions with the given permutation, so
This is the same as Reshape(operand, permutation,
Permute(permutation, operand.shape.dimensions)).
## TriangularSolve
See also
[`XlaBuilder::TriangularSolve`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
Solves systems of linear equations with lower or upper triangular coefficient
matrices by forward- or back-substitution. Broadcasting along leading
dimensions, this routine solves one of the matrix systems `op(a) * x =
b`, or `x * op(a) = b`, for the variable `x`, given `a` and `b`, where `op(a)` is
either `op(a) = a`, or `op(a) = Transpose(a)`, or `op(a) = Conj(Transpose(a))`.
<b> `TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose_a)` </b>
| Arguments | Type | Semantics |
| --------------- | ----------- | -------------------------------------------- |
| `a` | `XlaOp` | a rank >= 2 array of a complex or |
: : : floating-point type with shape `[..., M, :
: : : M]`. :
| `b` | `XlaOp` | a rank >= 2 array of the same type with shape |
: : : `[..., M, K]` if `left_side` is true, `[..., :
: : : K, M]` otherwise. :
| `left_side` | `bool` | indicates whether to solve a system of the |
: : : form `op(a) * x = b` (`true`) or `x * :
: : : op(a) = b` (`false`). :
| `lower` | `bool` | whether to use the upper or lower triangle |
: : : of `a`. :
| `unit_diagonal` | `bool` | if `true`, the diagonal elements of `a` are |
: : : assumed to be `1` and not accessed. :
| `transpose_a` | `Transpose` | whether to use `a` as is, transpose it or |
: : : take its conjugate transpose. :
Input data is read only from the lower/upper triangle of `a`, depending on the
value of `lower`. Values from the other triangle are ignored. Output data is
returned in the same triangle; the values in the other triangle are
implementation-defined and may be anything.
If the ranks of `a` and `b` are greater than 2, they are treated as batches of
matrices, where all except the two minor dimensions are batch dimensions. `a`
and `b` must have equal batch dimensions.
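As a small worked instance (illustrative, not part of the original page): with `left_side=true`, `lower=true`, `unit_diagonal=false` and `transpose_a=NO_TRANSPOSE`, solving `op(a) * x = b` for

$$ a = \begin{pmatrix} 2 & 0 \\ 1 & 3 \end{pmatrix}, \qquad b = \begin{pmatrix} 4 \\ 8 \end{pmatrix}, $$

forward substitution gives $$ x_1 = 4 / 2 = 2 $$ and then $$ x_2 = (8 - 1 \cdot 2) / 3 = 2 $$.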
## Tuple
See also

View File

@@ -69,7 +69,6 @@ cc_library(
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:cholesky",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:qr",
"//tensorflow/compiler/xla/service:computation_placer",

View File

@@ -21,7 +21,6 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/lib/cholesky.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/qr.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -714,8 +713,8 @@ LocalOp ComputationBuilder::SortKeyVal(const LocalOp& keys,
return xla::Sort(keys.op(), {values.op()}, dimension);
}
LocalOp ComputationBuilder::Cholesky(const LocalOp& a) {
return xla::Cholesky(a.op());
LocalOp ComputationBuilder::Cholesky(const LocalOp& a, bool lower) {
return xla::Cholesky(a.op(), lower);
}
LocalOp ComputationBuilder::QR(const LocalOp& a, bool full_matrices) {

View File

@@ -358,7 +358,7 @@ class ComputationBuilder {
LocalOp QR(const LocalOp& a, bool full_matrices);
LocalOp Cholesky(const LocalOp& a);
LocalOp Cholesky(const LocalOp& a, bool lower);
// `transpose_a` is the integer value of a TriangularSolveOptions::Transpose
// enum. We use an integer here so we don't have to teach SWIG about the

View File

@@ -1735,9 +1735,9 @@ class ComputationBuilder(object):
"""Enqueues a key-value sort operation onto the computation."""
return self._client.SortKeyVal(keys, values, dimension)
def Cholesky(self, a):
def Cholesky(self, a, lower=True):
"""Enqueues a Cholesky decomposition onto the computation."""
return self._client.Cholesky(a)
return self._client.Cholesky(a, lower)
def QR(self, a, full_matrices=True):
"""Enqueues a QR decomposition onto the computation."""

View File

@@ -1581,6 +1581,29 @@ cc_library(
],
)
cc_library(
name = "cholesky_expander",
srcs = ["cholesky_expander.cc"],
hdrs = ["cholesky_expander.h"],
deps = [
":op_expander_pass",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:loops",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:slicing",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_map",
],
)
tf_cc_test(
name = "batchnorm_expander_test",
size = "small",

View File

@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/cholesky.h"
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
#include <memory>
#include <vector>
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
namespace xla {
@@ -134,10 +135,8 @@ XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) {
});
}
} // namespace
XlaOp Cholesky(XlaOp a, int64 block_size,
PrecisionConfig::Precision precision) {
XlaOp BuildCholesky(XlaOp a, int64 block_size,
PrecisionConfig::Precision precision) {
XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
@@ -206,4 +205,55 @@ XlaOp Cholesky(XlaOp a, int64 block_size,
});
}
} // namespace
bool CholeskyExpander::InstructionMatchesPattern(HloInstruction* instruction) {
return instruction->opcode() == HloOpcode::kCholesky;
}
StatusOr<HloInstruction*> CholeskyExpander::ExpandInstruction(
HloInstruction* instruction) {
const CholeskyOptions& options = instruction->cholesky_options();
const string name = absl::StrFormat(
"xla.cholesky_%s_%s", instruction->operand(0)->shape().ToString(),
options.lower() ? "lower" : "upper");
HloModule* module = instruction->parent()->parent();
HloComputation*& computation =
computation_cache_.emplace(name, nullptr).first->second;
if (!computation) {
// Builds a new expansion.
//
// TODO(b/62327888): We do something unusual here: we build the computation
// using the XlaBuilder API, which is nominally an XLA client API. We do
// this because the external APIs for building complicated computations
// (XlaBuilder) are much more ergonomic than the internal ones. As it turns
// out, XlaBuilder isn't really a client API; what it does is build an
// HloModuleProto protocol buffer, which we can then deserialize and clone
// into our HloModule. Ideally we would avoid the protocol buffer step;
// that is left as an exercise for future work.
XlaBuilder builder(name);
XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
XlaOp l = BuildCholesky(MaybeTransposeInMinorDims(a, !options.lower()),
/*block_size=*/128,
/*precision=*/PrecisionConfig::HIGHEST);
MaybeTransposeInMinorDims(l, !options.lower());
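// Note: the result of the MaybeTransposeInMinorDims call above is not
// reassigned. XlaBuilder::Build() uses the last instruction added to the
// builder as the computation root, so the transpose (when one is emitted)
// still becomes the root of the expansion.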
TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());
TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
xla_computation.GetProgramShape());
HloModuleConfig config(program_shape);
TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
xla_computation.proto(), config));
HloCloneContext context(module);
computation =
module->DeepCloneComputation(new_module->entry_computation(), &context);
}
return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
instruction->shape(), instruction->operands(), computation));
}
} // namespace xla

View File

@@ -0,0 +1,41 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CHOLESKY_EXPANDER_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CHOLESKY_EXPANDER_H_
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/op_expander_pass.h"
namespace xla {
class CholeskyExpander : public OpExpanderPass {
public:
absl::string_view name() const override { return "cholesky_expander"; }
protected:
bool InstructionMatchesPattern(HloInstruction* instruction) override;
StatusOr<HloInstruction*> ExpandInstruction(
HloInstruction* instruction) override;
private:
// Mapping from op signatures to existing computations.
absl::flat_hash_map<string, HloComputation*> computation_cache_;
};
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CHOLESKY_EXPANDER_H_

View File

@@ -111,6 +111,7 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:buffer_liveness",
"//tensorflow/compiler/xla/service:call_inliner",
"//tensorflow/compiler/xla/service:cholesky_expander",
"//tensorflow/compiler/xla/service:conditional_simplifier",
"//tensorflow/compiler/xla/service:convolution_group_converter",
"//tensorflow/compiler/xla/service:dot_decomposer",

View File

@@ -50,6 +50,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
#include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
@@ -257,6 +258,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
pipeline.AddPass<MapInliner>();
pipeline.AddPass<CholeskyExpander>();
pipeline.AddPass<TriangularSolveExpander>();
// TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner

View File

@@ -112,6 +112,7 @@ class DfsHloVisitorBase {
virtual Status HandleConvolution(HloInstructionPtr hlo) = 0;
virtual Status HandleFft(HloInstructionPtr fft) = 0;
virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0;
virtual Status HandleCholesky(HloInstructionPtr hlo) = 0;
virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0;
virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0;
virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0;

View File

@@ -94,6 +94,9 @@ class DfsHloVisitorWithDefaultBase
Status HandleTriangularSolve(HloInstructionPtr hlo) override {
return DefaultAction(hlo);
}
Status HandleCholesky(HloInstructionPtr hlo) override {
return DefaultAction(hlo);
}
Status HandleAllReduce(HloInstructionPtr crs) override {
return DefaultAction(crs);
}

View File

@@ -779,6 +779,7 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_assignment",
"//tensorflow/compiler/xla/service:buffer_liveness",
"//tensorflow/compiler/xla/service:call_inliner",
"//tensorflow/compiler/xla/service:cholesky_expander",
"//tensorflow/compiler/xla/service:conditional_simplifier",
"//tensorflow/compiler/xla/service:convolution_group_converter",
"//tensorflow/compiler/xla/service:dot_decomposer",

View File

@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
#include "tensorflow/compiler/xla/service/convolution_group_converter.h"
#include "tensorflow/compiler/xla/service/dot_decomposer.h"
@@ -187,6 +188,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
&pipeline, hlo_module->config().debug_options(),
ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);
pipeline.AddPass<CholeskyExpander>();
// TODO(b/64094172): make Call work on GPU instead of inlining.
pipeline.AddPass<CallInliner>();
auto cost_model = [](HloInstruction* conv) {

View File

@@ -29,12 +29,13 @@ limitations under the License.
syntax = "proto3";
package xla;
import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
// Next ID: 62
// Next ID: 63
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@@ -200,6 +201,9 @@ message HloInstructionProto {
// Options for TriangularSolve
xla.TriangularSolveOptions triangular_solve_options = 59;
// Options for Cholesky
xla.CholeskyOptions cholesky_options = 62;
// Describes how parameters behave with regards to replicas.
xla.ParameterReplication parameter_replication = 61;
}

View File

@@ -557,11 +557,22 @@ Status HloCostAnalysis::HandleTriangularSolve(const HloInstruction* hlo) {
// Estimate as batch * mn^2 / 2 flops.
int64 elems = a_shape.dimensions(a_shape.dimensions_size() - 1);
elems *= ShapeUtil::ElementsIn(b_shape);
// Each output element requires reduction_width FMA operations.
current_properties_[kFlopsKey] = kFmaFlops * elems;
return Status::OK();
}
Status HloCostAnalysis::HandleCholesky(const HloInstruction* hlo) {
float bytes_accessed = GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
current_properties_[kBytesAccessedKey] = bytes_accessed;
const Shape& a_shape = hlo->operand(0)->shape();
// Estimate as batch * n^3 / 3 flops.
int64 elems = a_shape.dimensions(a_shape.dimensions_size() - 1);
elems *= ShapeUtil::ElementsIn(a_shape);
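// ElementsIn(a) is batch * n * n, so elems is now batch * n^3; dividing by
// three below matches the classical n^3 / 3 Cholesky flop count.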
current_properties_[kFlopsKey] = elems / 3;
return Status::OK();
}
Status HloCostAnalysis::HandleAllReduce(const HloInstruction* crs) {
// We assume 2 replicas, so that each output element is the sum of two input
// elements.

View File

@@ -72,6 +72,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleConvolution(const HloInstruction* convolution) override;
Status HandleFft(const HloInstruction* fft) override;
Status HandleTriangularSolve(const HloInstruction* hlo) override;
Status HandleCholesky(const HloInstruction* hlo) override;
Status HandleAllReduce(const HloInstruction* crs) override;
Status HandleAllToAll(const HloInstruction* hlo) override;
Status HandleCollectivePermute(const HloInstruction* hlo) override;

View File

@@ -1017,6 +1017,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kDot:
case HloOpcode::kFft:
case HloOpcode::kTriangularSolve:
case HloOpcode::kCholesky:
return kDarkBlue;
case HloOpcode::kReducePrecision:
return kRed;

View File

@@ -132,6 +132,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
proto.triangular_solve_options());
break;
}
case HloOpcode::kCholesky: {
instruction =
CreateCholesky(shape, operands(0), proto.cholesky_options());
break;
}
case HloOpcode::kSend:
instruction = CreateSend(operands(0), operands(1), proto.channel_id(),
proto.is_host_transfer());
@@ -734,6 +739,11 @@ HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a,
return absl::make_unique<HloTriangularSolveInstruction>(shape, a, b, options);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCholesky(
const Shape& shape, HloInstruction* a, const CholeskyOptions& options) {
return absl::make_unique<HloCholeskyInstruction>(shape, a, options);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dimension_numbers,
@@ -1294,6 +1304,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kDomain:
case HloOpcode::kGetDimensionSize:
case HloOpcode::kTriangularSolve:
case HloOpcode::kCholesky:
clone = CloneWithNewOperandsImpl(shape, new_operands, context);
break;
// Unary ops.
@@ -1754,6 +1765,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kDomain:
case HloOpcode::kGetDimensionSize:
case HloOpcode::kTriangularSolve:
case HloOpcode::kCholesky:
LOG(FATAL) << "Base class impl called for opcode with subclass: "
<< opcode();
}
@@ -2553,6 +2565,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleGetDimensionSize(this);
case HloOpcode::kTriangularSolve:
return visitor->HandleTriangularSolve(this);
case HloOpcode::kCholesky:
return visitor->HandleCholesky(this);
// These opcodes are not handled here.
case HloOpcode::kTrace:
@@ -3437,4 +3451,9 @@ const DomainMetadata& HloInstruction::user_side_metadata() const {
const TriangularSolveOptions& HloInstruction::triangular_solve_options() const {
return Cast<HloTriangularSolveInstruction>(this)->triangular_solve_options();
}
const CholeskyOptions& HloInstruction::cholesky_options() const {
return Cast<HloCholeskyInstruction>(this)->cholesky_options();
}
} // namespace xla

View File

@@ -448,6 +448,9 @@ class HloInstruction {
const Shape& shape, HloInstruction* a, HloInstruction* b,
const TriangularSolveOptions& options);
static std::unique_ptr<HloInstruction> CreateCholesky(
const Shape& shape, HloInstruction* a, const CholeskyOptions& options);
// Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
// dimensions specified in 'dimension_numbers'.
static std::unique_ptr<HloInstruction> CreateDot(
@@ -1594,6 +1597,9 @@ class HloInstruction {
// Delegates to HloTriangularSolveInstruction::triangular_solve_options().
const TriangularSolveOptions& triangular_solve_options() const;
// Delegates to HloCholeskyInstruction::cholesky_options().
const CholeskyOptions& cholesky_options() const;
// Old methods kept for smooth subclassing transition END.
protected:

View File

@@ -202,21 +202,6 @@ std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
fft_length_);
}
HloTriangularSolveInstruction::HloTriangularSolveInstruction(
const Shape& shape, HloInstruction* a, HloInstruction* b,
const TriangularSolveOptions& options)
: HloInstruction(HloOpcode::kTriangularSolve, shape),
triangular_solve_options_(options) {
AppendOperand(a);
AppendOperand(b);
}
HloInstructionProto HloTriangularSolveInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
*proto.mutable_triangular_solve_options() = triangular_solve_options_;
return proto;
}
namespace {
// Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector
@@ -257,6 +242,21 @@ std::vector<string> AttributeProtoToStringVector(
} // namespace
HloTriangularSolveInstruction::HloTriangularSolveInstruction(
const Shape& shape, HloInstruction* a, HloInstruction* b,
const TriangularSolveOptions& options)
: HloInstruction(HloOpcode::kTriangularSolve, shape),
triangular_solve_options_(options) {
AppendOperand(a);
AppendOperand(b);
}
HloInstructionProto HloTriangularSolveInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
*proto.mutable_triangular_solve_options() = triangular_solve_options_;
return proto;
}
std::vector<string> HloTriangularSolveInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
return AttributeProtoToStringVector(triangular_solve_options_);
@@ -286,6 +286,44 @@ HloTriangularSolveInstruction::CloneWithNewOperandsImpl(
shape, new_operands[0], new_operands[1], triangular_solve_options());
}
HloCholeskyInstruction::HloCholeskyInstruction(const Shape& shape,
HloInstruction* a,
const CholeskyOptions& options)
: HloInstruction(HloOpcode::kCholesky, shape), cholesky_options_(options) {
AppendOperand(a);
}
HloInstructionProto HloCholeskyInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
*proto.mutable_cholesky_options() = cholesky_options_;
return proto;
}
std::vector<string> HloCholeskyInstruction::ExtraAttributesToStringImpl(
const HloPrintOptions& options) const {
return AttributeProtoToStringVector(cholesky_options_);
}
bool HloCholeskyInstruction::IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const {
const auto& casted_other = static_cast<const HloCholeskyInstruction&>(other);
const auto& options = cholesky_options();
const auto& other_options = casted_other.cholesky_options();
return options.lower() == other_options.lower();
}
std::unique_ptr<HloInstruction>
HloCholeskyInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
CHECK_EQ(new_operands.size(), 1);
return absl::make_unique<HloCholeskyInstruction>(shape, new_operands[0],
cholesky_options());
}
HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
const Shape& shape,
int64 channel_id,

View File

@@ -159,6 +159,31 @@ class HloTriangularSolveInstruction : public HloInstruction {
TriangularSolveOptions triangular_solve_options_;
};
class HloCholeskyInstruction : public HloInstruction {
public:
explicit HloCholeskyInstruction(const Shape& shape, HloInstruction* a,
const CholeskyOptions& options);
const CholeskyOptions& cholesky_options() const { return cholesky_options_; }
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
private:
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
bool IdenticalSlowPath(
const HloInstruction& other,
const std::function<bool(const HloComputation*, const HloComputation*)>&
eq_computations) const override;
// Implementation for non-common logic of CloneWithNewOperands.
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
CholeskyOptions cholesky_options_;
};
class HloSendRecvInstruction : public HloInstruction {
public:
// Returns the channel id associated with the instruction. The id is

View File

@@ -61,6 +61,7 @@ namespace xla {
V(kBroadcast, "broadcast", 1) \
V(kCall, "call", kHloOpcodeIsVariadic) \
V(kCeil, "ceil", 1) \
V(kCholesky, "cholesky", 1) \
V(kClamp, "clamp", 3) \
V(kCollectivePermute, "collective-permute", 1) \
V(kClz, "count-leading-zeros", 1) \

View File

@@ -1129,6 +1129,17 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
shape, operands[0], operands[1], options));
break;
}
case HloOpcode::kCholesky: {
CholeskyOptions options;
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributesAsProtoMessage(
/*required_attrs=*/std::unordered_set<string>(), &options)) {
return false;
}
instruction = builder->AddInstruction(
HloInstruction::CreateCholesky(shape, operands[0], options));
break;
}
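// For reference, the parser accepts a cholesky instruction written along
// these lines (hypothetical snippet; `lower` is parsed as a field of the
// CholeskyOptions proto message):
//   %l = f32[4,4] cholesky(f32[4,4] %a), lower=true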
case HloOpcode::kBroadcast: {
optional<std::vector<int64>> broadcast_dimensions;
attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,

View File

@@ -177,6 +177,13 @@ Status ShapeVerifier::HandleTriangularSolve(HloInstruction* hlo) {
return CheckShape(hlo, expected);
}
Status ShapeVerifier::HandleCholesky(HloInstruction* hlo) {
TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1));
TF_ASSIGN_OR_RETURN(const Shape expected, ShapeInference::InferCholeskyShape(
hlo->operand(0)->shape()));
return CheckShape(hlo, expected);
}
Status ShapeVerifier::HandleAllReduce(HloInstruction* crs) {
std::vector<const Shape*> operand_shapes;
for (const HloInstruction* operand : crs->operands()) {

View File

@@ -52,6 +52,7 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleDot(HloInstruction* dot) override;
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleFft(HloInstruction* fft) override;
Status HandleCholesky(HloInstruction* hlo) override;
Status HandleTriangularSolve(HloInstruction* hlo) override;
Status HandleAllReduce(HloInstruction* crs) override;
Status HandleAllToAll(HloInstruction* hlo) override;

View File

@@ -126,6 +126,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kBatchNormInference:
case HloOpcode::kBatchNormTraining:
case HloOpcode::kCall:
case HloOpcode::kCholesky:
case HloOpcode::kConditional:
case HloOpcode::kConvolution:
case HloOpcode::kAllReduce:

View File

@@ -32,6 +32,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla/service:algebraic_simplifier",
"//tensorflow/compiler/xla/service:cholesky_expander",
"//tensorflow/compiler/xla/service:compiler",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:dynamic_index_splitter",

View File

@@ -20,6 +20,7 @@ limitations under the License.
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/cholesky_expander.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
#include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
@@ -80,6 +81,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
HloPassPipeline pipeline("Interpreter");
pipeline.AddPass<DynamicIndexSplitter>();
pipeline.AddPass<CholeskyExpander>();
pipeline.AddPass<TriangularSolveExpander>();
pipeline.AddPass<LayoutAssignment>(
hlo_module->mutable_entry_computation_layout(),

View File

@@ -2073,6 +2073,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kSubtract:
case HloOpcode::kTanh:
case HloOpcode::kTriangularSolve:
case HloOpcode::kCholesky:
case HloOpcode::kTupleSelect:
case HloOpcode::kWhile:
return false;

View File

@@ -1911,6 +1911,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
/* static */ StatusOr<Shape> ShapeInference::InferTriangularSolveShape(
const Shape& a, const Shape& b, const TriangularSolveOptions& options) {
if ((!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) ||
a.element_type() != b.element_type()) {
return InvalidArgument(
"Expected element types in shape to be floating or complex and "
"identical for TriangularSolve; got %s and %s.",
PrimitiveType_Name(a.element_type()),
PrimitiveType_Name(b.element_type()));
}
if (a.rank() < 2) {
return InvalidArgument(
"The 'a' argument to TriangularSolve must have rank >= 2, got shape %s",
@@ -1952,6 +1960,27 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
return b;
}
/* static */ StatusOr<Shape> ShapeInference::InferCholeskyShape(
const Shape& a) {
if (!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) {
return InvalidArgument(
"Expected element type in shape to be floating or complex for "
"Cholesky; got %s.",
PrimitiveType_Name(a.element_type()));
}
if (a.rank() < 2) {
return InvalidArgument(
"The 'a' argument to Cholesky must have rank >= 2, got shape %s",
a.ToString());
}
if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
return InvalidArgument(
"The two minor dimensions of 'a' must have equal size, got %s.",
a.ToString());
}
return a;
}
/* static */ StatusOr<Shape> ShapeInference::InferAllReduceShape(
absl::Span<const Shape* const> operand_shapes) {
for (const Shape* operand_shape : operand_shapes) {

View File

@@ -120,6 +120,9 @@ class ShapeInference {
static StatusOr<Shape> InferTriangularSolveShape(
const Shape& a, const Shape& b, const TriangularSolveOptions& options);
// Infers the shape produced by the given Cholesky operation.
static StatusOr<Shape> InferCholeskyShape(const Shape& a);
// Infers the shape produced by a cross replica sum with the given operand
// shapes.
static StatusOr<Shape> InferAllReduceShape(

View File

@@ -2217,3 +2217,23 @@ xla_test(
"//tensorflow/core:test",
],
)
xla_test(
name = "cholesky_test",
srcs = ["cholesky_test.cc"],
tags = ["optonly"],
deps = [
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)

View File

@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/cholesky.h"
#include <limits>
#include <memory>
#include <numeric>
#include <vector>
@@ -32,27 +31,27 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace xla {
namespace {
using xla::int64;
using CholeskyTest = ClientLibraryTestBase;
using CholeskyTest = xla::ClientLibraryTestBase;
XLA_TEST_F(CholeskyTest, Lower) {
XlaBuilder builder(TestName());
XLA_TEST_F(CholeskyTest, Simple) {
xla::XlaBuilder builder(TestName());
xla::Array2D<float> a_vals({
{4, 6, 8, 10},
{6, 45, 54, 63},
{8, 54, 146, 166},
float nan = std::numeric_limits<float>::quiet_NaN();
Array2D<float> a_vals({
{4, nan, nan, nan},
{6, 45, nan, nan},
{8, 54, 146, nan},
{10, 63, 166, 310},
});
xla::XlaOp a;
XlaOp a;
auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
xla::Cholesky(a, /*block_size=*/2);
LowerTriangle(Cholesky(a, /*lower=*/true));
xla::Array2D<float> expected({
Array2D<float> expected({
{2, 0, 0, 0},
{3, 6, 0, 0},
{4, 7, 9, 0},
@@ -60,34 +59,62 @@ XLA_TEST_F(CholeskyTest, Simple) {
});
ComputeAndCompareR2<float>(&builder, expected, {a_data.get()},
xla::ErrorSpec(1e-4, 1e-4));
ErrorSpec(1e-4, 1e-4));
}
XLA_TEST_F(CholeskyTest, Upper) {
XlaBuilder builder(TestName());
float nan = std::numeric_limits<float>::quiet_NaN();
Array2D<float> a_vals({
{4, 6, 8, 10},
{nan, 45, 54, 63},
{nan, nan, 146, 166},
{nan, nan, nan, 310},
});
XlaOp a;
auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
UpperTriangle(Cholesky(a, /*lower=*/false));
Array2D<float> expected({
{2, 3, 4, 5},
{0, 6, 7, 8},
{0, 0, 9, 10},
{0, 0, 0, 11},
});
ComputeAndCompareR2<float>(&builder, expected, {a_data.get()},
ErrorSpec(1e-4, 1e-4));
}
XLA_TEST_F(CholeskyTest, Simple2) {
xla::XlaBuilder builder(TestName());
XlaBuilder builder(TestName());
xla::Array2D<float> a_vals({
Array2D<float> a_vals({
{16, 24, 8, 12},
{24, 61, 82, 48},
{8, 82, 456, 106},
{12, 48, 106, 62},
});
xla::XlaOp a;
XlaOp a;
auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
xla::Cholesky(a);
LowerTriangle(Cholesky(a, /*lower=*/true));
xla::Array2D<float> expected(
{{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}});
Array2D<float> expected({{4, 0, 0, 0}, //
{6, 5, 0, 0}, //
{2, 14, 16, 0}, //
{3, 6, 1, 4}});
ComputeAndCompareR2<float>(&builder, expected, {a_data.get()},
xla::ErrorSpec(1e-4, 1e-4));
ErrorSpec(1e-4, 1e-4));
}
XLA_TEST_F(CholeskyTest, SimpleBatched) {
xla::XlaBuilder builder(TestName());
XlaBuilder builder(TestName());
xla::Array3D<float> a_vals({
Array3D<float> a_vals({
{
{4, 6, 8, 10},
{6, 45, 54, 63},
@@ -102,65 +129,78 @@ XLA_TEST_F(CholeskyTest, SimpleBatched) {
},
});
xla::XlaOp a;
XlaOp a;
auto a_data = CreateR3Parameter<float>(a_vals, 0, "a", &builder, &a);
xla::Cholesky(a);
LowerTriangle(Cholesky(a, /*lower=*/true));
xla::Array3D<float> expected({
Array3D<float> expected({
{
{2, 0, 0, 0},
{3, 6, 0, 0},
{4, 7, 9, 0},
{5, 8, 10, 11},
},
{{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}},
{{4, 0, 0, 0}, //
{6, 5, 0, 0}, //
{2, 14, 16, 0}, //
{3, 6, 1, 4}},
});
ComputeAndCompareR3<float>(&builder, expected, {a_data.get()},
xla::ErrorSpec(1e-4, 1e-4));
ErrorSpec(1e-4, 1e-4));
}
using CholeskyTestCase = std::tuple<int64, int64>;
using CholeskyTestCase = std::tuple<int64, int64, bool>;
class RandomCholeskyTest
: public xla::ClientLibraryTestBase,
: public ClientLibraryTestBase,
public ::testing::WithParamInterface<CholeskyTestCase> {};
XLA_TEST_P(RandomCholeskyTest, Random) {
xla::XlaBuilder builder(TestName());
XlaBuilder builder(TestName());
auto test_params = GetParam();
std::vector<int64> dimensions = {std::get<0>(test_params),
std::get<1>(test_params),
std::get<1>(test_params)};
xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, dimensions);
bool lower = std::get<2>(test_params);
Shape shape = ShapeUtil::MakeShape(F32, dimensions);
TF_ASSERT_OK_AND_ASSIGN(
auto literal,
xla::LiteralUtil::CreateRandomLiteral<xla::F32>(shape, 0.0, 1.0));
auto literal, LiteralUtil::CreateRandomLiteral<F32>(shape, 0.0, 1.0));
auto input = xla::Parameter(&builder, 0, shape, "input");
auto input = Parameter(&builder, 0, shape, "input");
// Form a random positive definite matrix.
auto matrix = xla::BatchDot(input, TransposeInMinorDims(input),
xla::PrecisionConfig::HIGHEST);
auto matrix =
BatchDot(input, TransposeInMinorDims(input), PrecisionConfig::HIGHEST);
auto cholesky = xla::Cholesky(matrix, /*block_size=*/4);
auto cholesky = Triangle(Cholesky(matrix, lower), lower);
// Verify that ||matrix - cholesky * cholesky_t||_2 ~= 0
auto verification = xla::BatchDot(cholesky, TransposeInMinorDims(cholesky),
xla::PrecisionConfig::HIGHEST);
XlaOp verification;
if (lower) {
verification = BatchDot(cholesky, TransposeInMinorDims(cholesky),
PrecisionConfig::HIGHEST);
} else {
verification = BatchDot(TransposeInMinorDims(cholesky), cholesky,
PrecisionConfig::HIGHEST);
}
auto delta = matrix - verification;
xla::Reduce(delta * delta, xla::ConstantR0<float>(&builder, 0.0),
CreateScalarAddComputation(xla::F32, &builder), {0, 1, 2});
Reduce(delta * delta, ConstantR0<float>(&builder, 0.0),
CreateScalarAddComputation(F32, &builder), {0, 1, 2});
TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal));
ComputeAndCompareR0<float>(&builder, 0.0, {input_data.get()},
xla::ErrorSpec(1e-4, 1e-4));
ErrorSpec(1e-4, 1e-4));
}
INSTANTIATE_TEST_SUITE_P(RandomCholeskyTestInstance, RandomCholeskyTest,
::testing::Values(CholeskyTestCase{1, 1},
CholeskyTestCase{1, 2},
CholeskyTestCase{10, 5},
CholeskyTestCase{2, 20}));
::testing::Values(CholeskyTestCase{1, 1, true},
CholeskyTestCase{1, 2, true},
CholeskyTestCase{1, 50, true},
CholeskyTestCase{1, 50, false},
CholeskyTestCase{10, 5, true},
CholeskyTestCase{5, 10, false},
CholeskyTestCase{2, 20, true}));
} // namespace
} // namespace xla

View File

@@ -566,6 +566,12 @@ message TriangularSolveOptions {
Transpose transpose_a = 4;
}
message CholeskyOptions {
// If true, uses the lower triangle of `a`. If false, uses the upper triangle
// of `a`.
bool lower = 1;
}
message OpSharding {
enum Type {
// This sharding is replicated across all devices (implies maximal,