[XLA] Make Cholesky into a first-class HLO operator.

Currently it is expanded into an HLO implementation on all backends.

PiperOrigin-RevId: 235814360
parent 2ee3000734
commit cf0f741491
@@ -136,7 +136,6 @@ tf_kernel_library(
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/client:xla_computation",
         "//tensorflow/compiler/xla/client/lib:arithmetic",
-        "//tensorflow/compiler/xla/client/lib:cholesky",
         "//tensorflow/compiler/xla/client/lib:constants",
         "//tensorflow/compiler/xla/client/lib:loops",
         "//tensorflow/compiler/xla/client/lib:math",
@@ -15,7 +15,8 @@ limitations under the License.
 
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
-#include "tensorflow/compiler/xla/client/lib/cholesky.h"
+#include "tensorflow/compiler/xla/client/lib/matrix.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
 
 namespace tensorflow {
 namespace {
@@ -24,7 +25,9 @@ class CholeskyOp : public XlaOpKernel {
  public:
   explicit CholeskyOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
   void Compile(XlaOpKernelContext* ctx) override {
-    ctx->SetOutput(0, xla::Cholesky(ctx->Input(0)));
+    ctx->SetOutput(0,
+                   xla::Triangle(xla::Cholesky(ctx->Input(0), /*lower=*/true),
+                                 /*lower=*/true));
   }
 };
@@ -49,47 +49,6 @@ xla_test(
     ],
 )
 
-cc_library(
-    name = "cholesky",
-    srcs = ["cholesky.cc"],
-    hdrs = ["cholesky.h"],
-    deps = [
-        ":math",
-        "//tensorflow/compiler/xla:literal",
-        "//tensorflow/compiler/xla:shape_util",
-        "//tensorflow/compiler/xla:status_macros",
-        "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla:xla_data_proto",
-        "//tensorflow/compiler/xla/client:xla_builder",
-        "//tensorflow/compiler/xla/client/lib:constants",
-        "//tensorflow/compiler/xla/client/lib:loops",
-        "//tensorflow/compiler/xla/client/lib:matrix",
-        "//tensorflow/compiler/xla/client/lib:slicing",
-        "//tensorflow/core:lib",
-    ],
-)
-
-xla_test(
-    name = "cholesky_test",
-    srcs = ["cholesky_test.cc"],
-    tags = ["optonly"],
-    deps = [
-        ":arithmetic",
-        ":cholesky",
-        ":matrix",
-        "//tensorflow/compiler/xla:array2d",
-        "//tensorflow/compiler/xla:literal",
-        "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla:test",
-        "//tensorflow/compiler/xla:types",
-        "//tensorflow/compiler/xla/client:xla_builder",
-        "//tensorflow/compiler/xla/tests:client_library_test_base",
-        "//tensorflow/compiler/xla/tests:literal_test_util",
-        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
-        "//tensorflow/core:test",
-    ],
-)
-
 cc_library(
     name = "comparators",
     srcs = ["comparators.cc"],
@@ -1,39 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_
-#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_
-
-#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-
-namespace xla {
-
-// Computes the Cholesky decompositions of a batch of symmetric positive
-// definite matrices.
-// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the
-// two minor dimensions equal.
-// The algorithm implements a blocked Cholesky decomposition; `block_size` is
-// the block size to use.
-// TODO(phawkins): check for negative values on the diagonal and return an
-// error, instead of silently yielding NaNs.
-// TODO(znado): handle the complex Hermitian case
-xla::XlaOp Cholesky(
-    xla::XlaOp a, int64 block_size = 256,
-    xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST);
-
-}  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CHOLESKY_H_
@@ -3022,6 +3022,21 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
   });
 }
 
+XlaOp Cholesky(XlaOp a, bool lower) {
+  XlaBuilder* builder = a.builder();
+  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    HloInstructionProto instr;
+    TF_ASSIGN_OR_RETURN(const Shape& a_shape, builder->GetShape(a));
+    xla::CholeskyOptions& options = *instr.mutable_cholesky_options();
+    options.set_lower(lower);
+    TF_ASSIGN_OR_RETURN(Shape shape,
+                        ShapeInference::InferCholeskyShape(a_shape));
+    *instr.mutable_shape() = shape.ToProto();
+
+    return builder->AddInstruction(std::move(instr), HloOpcode::kCholesky, {a});
+  });
+}
+
 XlaOp Infeed(XlaBuilder* builder, const Shape& shape, const string& config) {
   return builder->Infeed(shape, config);
 }
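For orientation, a minimal client-side sketch of emitting the new op (not part of this commit; the helper name EmitLowerCholesky is invented for illustration):

#include "tensorflow/compiler/xla/client/xla_builder.h"

// Emits a single kCholesky HLO instruction on `builder`. Backends with no
// native lowering rewrite it later via the CholeskyExpander pass added below.
xla::XlaOp EmitLowerCholesky(xla::XlaBuilder* builder, const xla::Shape& shape) {
  // `a` is assumed to be symmetric positive definite at runtime.
  xla::XlaOp a = xla::Parameter(builder, 0, shape, "a");
  return xla::Cholesky(a, /*lower=*/true);
}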
@@ -801,6 +801,7 @@ class XlaBuilder {
   friend XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                                bool unit_diagonal,
                                TriangularSolveOptions::Transpose transpose_a);
+  friend XlaOp Cholesky(XlaOp a, bool lower);
   friend XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
                       const string& config);
   friend void Outfeed(const XlaOp& operand, const Shape& shape_with_layout,
@@ -1342,8 +1343,7 @@ XlaOp Fft(const XlaOp& operand, FftType fft_type,
 // * `left_side` is a boolean, indicating whether to solve a system of the form
 //   op(a) * x = b (true) or x * op(a) = b (false).
 // * `lower` is a boolean, indicating whether the argument `a` is
-//   lower-triangular
-//   (true) or upper-triangular (false).
+//   lower-triangular (true) or upper-triangular (false).
 // * If `unit_diagonal` is true, the diagonal elements of `a` are assumed to be
 //   1 and not accessed.
 // * `transpose_a` indicates which function `op` we use to transform the tensor
@@ -1352,6 +1352,20 @@ XlaOp TriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                       bool unit_diagonal,
                       TriangularSolveOptions::Transpose transpose_a);
 
+// Computes the Cholesky decompositions of a batch of symmetric (Hermitian)
+// positive definite matrices.
+// `a` must be a (batched) square matrix; i.e., it must have rank >= 2 with the
+// two minor dimensions equal.
+// If `lower` is true, the data from the lower triangle is used; if false, the
+// upper triangle is used. The input data in the other triangle of the input
+// does not affect the output. Returns the output in the same lower/upper
+// triangle. The data returned in the other output triangle is arbitrary and
+// implementation-defined.
+//
+// The value returned if `a` is not Hermitian positive definite is
+// implementation-defined.
+XlaOp Cholesky(XlaOp a, bool lower);
+
 // Enqueues an infeed instruction onto the computation, which writes data of
 // the given shape to the infeed buffer of the device.
 XlaOp Infeed(XlaBuilder* builder, const Shape& shape,
@@ -322,6 +322,37 @@ Invokes a computation with the given arguments.
 The arity and types of the `args` must match the parameters of the
 `computation`. It is allowed to have no `args`.
 
+## Cholesky
+
+See also
+[`XlaBuilder::Cholesky`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Computes the
+[Cholesky decomposition](https://en.wikipedia.org/wiki/Cholesky_decomposition)
+of a batch of symmetric (Hermitian) positive definite matrices.
+
+<b> `Cholesky(a, lower)` </b>
+
+Arguments | Type    | Semantics
+--------- | ------- | -------------------------------------------------------
+`a`       | `XlaOp` | a rank >= 2 array of a complex or floating-point type.
+`lower`   | `bool`  | whether to use the upper or lower triangle of `a`.
+
+If `lower` is `true`, computes lower-triangular matrices `l` such that $$ a = l
+. l^T $$. If `lower` is `false`, computes upper-triangular matrices `u` such
+that $$ a = u^T . u $$.
+
+Input data is read only from the lower/upper triangle of `a`, depending on the
+value of `lower`. Values from the other triangle are ignored. Output data is
+returned in the same triangle; the values in the other triangle are
+implementation-defined and may be anything.
+
+If the rank of `a` is greater than 2, `a` is treated as a batch of matrices,
+where all except the minor 2 dimensions are batch dimensions.
+
+If `a` is not symmetric (Hermitian) positive definite, the result is
+implementation-defined.
+
 ## Clamp
 
 See also
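As a concrete instance of the `lower=true` identity above (illustrative numbers; they match the leading 2x2 block of the matrices exercised by the tests in this change):

$$
a = \begin{pmatrix} 4 & 6 \\ 6 & 45 \end{pmatrix}
\quad\Rightarrow\quad
l = \begin{pmatrix} 2 & 0 \\ 3 & 6 \end{pmatrix},
\qquad
l \, l^T = \begin{pmatrix} 4 & 6 \\ 6 & 45 \end{pmatrix} = a.
$$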
@@ -2439,6 +2470,46 @@ Permutes the operand dimensions with the given permutation, so
 This is the same as Reshape(operand, permutation,
 Permute(permutation, operand.shape.dimensions)).
 
+## TriangularSolve
+
+See also
+[`XlaBuilder::TriangularSolve`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Solves systems of linear equations with lower or upper triangular coefficient
+matrices by forward- or back-substitution. Broadcasting along leading
+dimensions, this routine solves one of the matrix systems `op(a) * x = b`, or
+`x * op(a) = b`, for the variable `x`, given `a` and `b`, where `op(a)` is
+either `op(a) = a`, or `op(a) = Transpose(a)`, or `op(a) = Conj(Transpose(a))`.
+
+<b> `TriangularSolve(a, b, left_side, lower, unit_diagonal, transpose_a)` </b>
+
+| Arguments       | Type        | Semantics                                    |
+| --------------- | ----------- | -------------------------------------------- |
+| `a`             | `XlaOp`     | a rank >= 2 array of a complex or            |
+:                 :             : floating-point type with shape `[..., M,     :
+:                 :             : M]`.                                         :
+| `b`             | `XlaOp`     | a rank >= 2 array of the same type with      |
+:                 :             : shape `[..., M, K]` if `left_side` is true,  :
+:                 :             : `[..., K, M]` otherwise.                     :
+| `left_side`     | `bool`      | indicates whether to solve a system of the   |
+:                 :             : form `op(a) * x = b` (`true`) or `x *        :
+:                 :             : op(a) = b` (`false`).                        :
+| `lower`         | `bool`      | whether to use the upper or lower triangle   |
+:                 :             : of `a`.                                      :
+| `unit_diagonal` | `bool`      | if `true`, the diagonal elements of `a` are  |
+:                 :             : assumed to be `1` and not accessed.          :
+| `transpose_a`   | `Transpose` | whether to use `a` as is, transpose it or    |
+:                 :             : take its conjugate transpose.                :
+
+Input data is read only from the lower/upper triangle of `a`, depending on the
+value of `lower`. Values from the other triangle are ignored. Output data is
+returned in the same triangle; the values in the other triangle are
+implementation-defined and may be anything.
+
+If the rank of `a` and `b` are greater than 2, they are treated as batches of
+matrices, where all except the minor 2 dimensions are batch dimensions. `a` and
+`b` must have equal batch dimensions.
+
 ## Tuple
 
 See also
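A worked instance of the `left_side=true`, `lower=true` case (illustrative numbers; `x` follows by forward substitution):

$$
a = \begin{pmatrix} 2 & 0 \\ 3 & 6 \end{pmatrix},\qquad
b = \begin{pmatrix} 2 \\ 9 \end{pmatrix},\qquad
a\,x = b \;\Rightarrow\;
x_1 = \tfrac{2}{2} = 1,\quad
x_2 = \tfrac{9 - 3 \cdot 1}{6} = 1.
$$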
@@ -69,7 +69,6 @@ cc_library(
         "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/client:xla_computation",
-        "//tensorflow/compiler/xla/client/lib:cholesky",
         "//tensorflow/compiler/xla/client/lib:math",
         "//tensorflow/compiler/xla/client/lib:qr",
         "//tensorflow/compiler/xla/service:computation_placer",
@@ -21,7 +21,6 @@ limitations under the License.
 
 #include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/lib/cholesky.h"
 #include "tensorflow/compiler/xla/client/lib/math.h"
 #include "tensorflow/compiler/xla/client/lib/qr.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -714,8 +713,8 @@ LocalOp ComputationBuilder::SortKeyVal(const LocalOp& keys,
   return xla::Sort(keys.op(), {values.op()}, dimension);
 }
 
-LocalOp ComputationBuilder::Cholesky(const LocalOp& a) {
-  return xla::Cholesky(a.op());
+LocalOp ComputationBuilder::Cholesky(const LocalOp& a, bool lower) {
+  return xla::Cholesky(a.op(), lower);
 }
 
 LocalOp ComputationBuilder::QR(const LocalOp& a, bool full_matrices) {
@@ -358,7 +358,7 @@ class ComputationBuilder {
 
   LocalOp QR(const LocalOp& a, bool full_matrices);
 
-  LocalOp Cholesky(const LocalOp& a);
+  LocalOp Cholesky(const LocalOp& a, bool lower);
 
   // `transpose_a` is the integer value of a TriangularSolveOptions::Transpose
   // enum. We use an integer here so we don't have to teach SWIG about the
@@ -1735,9 +1735,9 @@ class ComputationBuilder(object):
     """Enqueues a key-value sort operation onto the computation."""
     return self._client.SortKeyVal(keys, values, dimension)
 
-  def Cholesky(self, a):
+  def Cholesky(self, a, lower=True):
     """Enqueues a Cholesky decomposition onto the computation."""
-    return self._client.Cholesky(a)
+    return self._client.Cholesky(a, lower)
 
   def QR(self, a, full_matrices=True):
     """Enqueues a QR decomposition onto the computation."""
@@ -1581,6 +1581,29 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "cholesky_expander",
+    srcs = ["cholesky_expander.cc"],
+    hdrs = ["cholesky_expander.h"],
+    deps = [
+        ":op_expander_pass",
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/client:xla_computation",
+        "//tensorflow/compiler/xla/client/lib:constants",
+        "//tensorflow/compiler/xla/client/lib:loops",
+        "//tensorflow/compiler/xla/client/lib:math",
+        "//tensorflow/compiler/xla/client/lib:matrix",
+        "//tensorflow/compiler/xla/client/lib:slicing",
+        "//tensorflow/core:lib",
+        "@com_google_absl//absl/container:flat_hash_map",
+    ],
+)
+
 tf_cc_test(
     name = "batchnorm_expander_test",
     size = "small",
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/client/lib/cholesky.h"
+#include "tensorflow/compiler/xla/service/cholesky_expander.h"
 
 #include <memory>
 #include <vector>
@@ -29,6 +29,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/core/lib/core/errors.h"
 
 namespace xla {
@@ -134,10 +135,8 @@ XlaOp CholeskyUnblocked(XlaOp a, PrecisionConfig::Precision precision) {
   });
 }
 
-}  // namespace
-
-XlaOp Cholesky(XlaOp a, int64 block_size,
-               PrecisionConfig::Precision precision) {
+XlaOp BuildCholesky(XlaOp a, int64 block_size,
+                    PrecisionConfig::Precision precision) {
   XlaBuilder* builder = a.builder();
   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
@@ -206,4 +205,55 @@ XlaOp Cholesky(XlaOp a, int64 block_size,
   });
 }
 
+}  // namespace
+
+bool CholeskyExpander::InstructionMatchesPattern(HloInstruction* instruction) {
+  return instruction->opcode() == HloOpcode::kCholesky;
+}
+
+StatusOr<HloInstruction*> CholeskyExpander::ExpandInstruction(
+    HloInstruction* instruction) {
+  const CholeskyOptions& options = instruction->cholesky_options();
+  const string name = absl::StrFormat(
+      "xla.cholesky_%s_%s", instruction->operand(0)->shape().ToString(),
+      options.lower() ? "lower" : "upper");
+
+  HloModule* module = instruction->parent()->parent();
+
+  HloComputation*& computation =
+      computation_cache_.emplace(name, nullptr).first->second;
+  if (!computation) {
+    // Builds a new expansion.
+    //
+    // TODO(b/62327888): We do something unusual here: we build the computation
+    // using the XlaBuilder API, which is nominally an XLA client API. We do
+    // this because the external APIs for building complicated computations
+    // (XlaBuilder) are much more ergonomic than the internal ones. As it turns
+    // out, XlaBuilder isn't really a client API—what it does is build a
+    // HloModuleProto protocol buffer, that we can then deserialize and clone
+    // into our HloModule. Ideally we would avoid the protocol buffer step;
+    // that is left as an exercise for future work.
+    XlaBuilder builder(name);
+    XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
+    XlaOp l = BuildCholesky(MaybeTransposeInMinorDims(a, !options.lower()),
+                            /*block_size=*/128,
+                            /*precision=*/PrecisionConfig::HIGHEST);
+    MaybeTransposeInMinorDims(l, !options.lower());
+
+    TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());
+
+    TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
+                        xla_computation.GetProgramShape());
+    HloModuleConfig config(program_shape);
+    TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
+                                             xla_computation.proto(), config));
+    HloCloneContext context(module);
+    computation =
+        module->DeepCloneComputation(new_module->entry_computation(), &context);
+  }
+
+  return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
+      instruction->shape(), instruction->operands(), computation));
+}
+
 }  // namespace xla
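For context, a sketch of how a backend wires this pass into its pipeline, mirroring the CPU, GPU, and interpreter compiler changes elsewhere in this commit (the wrapper function is hypothetical):

#include "tensorflow/compiler/xla/service/cholesky_expander.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"

// After the pipeline runs, no kCholesky instructions remain in `module`;
// each has been replaced by a kCall to a cached expansion computation.
xla::StatusOr<bool> ExpandCholeskyOps(xla::HloModule* module) {
  xla::HloPassPipeline pipeline("cholesky-expansion");
  pipeline.AddPass<xla::CholeskyExpander>();
  return pipeline.Run(module);
}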
tensorflow/compiler/xla/service/cholesky_expander.h (new file, 41 lines)
@@ -0,0 +1,41 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CHOLESKY_EXPANDER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_CHOLESKY_EXPANDER_H_
+
+#include "absl/container/flat_hash_map.h"
+#include "tensorflow/compiler/xla/service/op_expander_pass.h"
+
+namespace xla {
+
+class CholeskyExpander : public OpExpanderPass {
+ public:
+  absl::string_view name() const override { return "cholesky_expander"; }
+
+ protected:
+  bool InstructionMatchesPattern(HloInstruction* instruction) override;
+
+  StatusOr<HloInstruction*> ExpandInstruction(
+      HloInstruction* instruction) override;
+
+ private:
+  // Mapping from op signatures to existing computations.
+  absl::flat_hash_map<string, HloComputation*> computation_cache_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_CHOLESKY_EXPANDER_H_
@@ -111,6 +111,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:buffer_assignment",
         "//tensorflow/compiler/xla/service:buffer_liveness",
         "//tensorflow/compiler/xla/service:call_inliner",
+        "//tensorflow/compiler/xla/service:cholesky_expander",
        "//tensorflow/compiler/xla/service:conditional_simplifier",
         "//tensorflow/compiler/xla/service:convolution_group_converter",
         "//tensorflow/compiler/xla/service:dot_decomposer",
@@ -50,6 +50,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
 #include "tensorflow/compiler/xla/service/call_inliner.h"
+#include "tensorflow/compiler/xla/service/cholesky_expander.h"
 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
 #include "tensorflow/compiler/xla/service/convolution_group_converter.h"
 #include "tensorflow/compiler/xla/service/cpu/buffer_info_util.h"
@@ -257,6 +258,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
 
   pipeline.AddPass<MapInliner>();
 
+  pipeline.AddPass<CholeskyExpander>();
   pipeline.AddPass<TriangularSolveExpander>();
 
   // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner
@@ -112,6 +112,7 @@ class DfsHloVisitorBase {
   virtual Status HandleConvolution(HloInstructionPtr hlo) = 0;
   virtual Status HandleFft(HloInstructionPtr fft) = 0;
   virtual Status HandleTriangularSolve(HloInstructionPtr hlo) = 0;
+  virtual Status HandleCholesky(HloInstructionPtr hlo) = 0;
   virtual Status HandleAllReduce(HloInstructionPtr hlo) = 0;
   virtual Status HandleAllToAll(HloInstructionPtr hlo) = 0;
   virtual Status HandleCollectivePermute(HloInstructionPtr hlo) = 0;
@@ -94,6 +94,9 @@ class DfsHloVisitorWithDefaultBase
   Status HandleTriangularSolve(HloInstructionPtr hlo) override {
     return DefaultAction(hlo);
   }
+  Status HandleCholesky(HloInstructionPtr hlo) override {
+    return DefaultAction(hlo);
+  }
   Status HandleAllReduce(HloInstructionPtr crs) override {
     return DefaultAction(crs);
   }
@@ -779,6 +779,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:buffer_assignment",
         "//tensorflow/compiler/xla/service:buffer_liveness",
         "//tensorflow/compiler/xla/service:call_inliner",
+        "//tensorflow/compiler/xla/service:cholesky_expander",
         "//tensorflow/compiler/xla/service:conditional_simplifier",
         "//tensorflow/compiler/xla/service:convolution_group_converter",
         "//tensorflow/compiler/xla/service:dot_decomposer",
@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
 #include "tensorflow/compiler/xla/service/call_inliner.h"
+#include "tensorflow/compiler/xla/service/cholesky_expander.h"
 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
 #include "tensorflow/compiler/xla/service/convolution_group_converter.h"
 #include "tensorflow/compiler/xla/service/dot_decomposer.h"
@@ -187,6 +188,8 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
         &pipeline, hlo_module->config().debug_options(),
         ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION);
 
+    pipeline.AddPass<CholeskyExpander>();
+
     // TODO(b/64094172): make Call work on GPU instead of inlining.
     pipeline.AddPass<CallInliner>();
     auto cost_model = [](HloInstruction* conv) {
@@ -29,12 +29,13 @@ limitations under the License.
 syntax = "proto3";
 
 package xla;
 
 import "tensorflow/compiler/xla/xla_data.proto";
 
 option cc_enable_arenas = true;
 
 // Serialization of HloInstruction.
-// Next ID: 62
+// Next ID: 63
 message HloInstructionProto {
   reserved 10;
   reserved "parameter_name";
@@ -200,6 +201,9 @@ message HloInstructionProto {
   // Options for TriangularSolve
   xla.TriangularSolveOptions triangular_solve_options = 59;
 
+  // Options for Cholesky
+  xla.CholeskyOptions cholesky_options = 62;
+
   // Describes how parameters behave with regards to replicas.
   xla.ParameterReplication parameter_replication = 61;
 }
@@ -557,11 +557,22 @@ Status HloCostAnalysis::HandleTriangularSolve(const HloInstruction* hlo) {
   // Estimate as batch * mn^2 / 2 flops.
   int64 elems = a_shape.dimensions(a_shape.dimensions_size() - 1);
   elems *= ShapeUtil::ElementsIn(b_shape);
-  // Each output elment requires reduction_widht FMA operations.
   current_properties_[kFlopsKey] = kFmaFlops * elems;
   return Status::OK();
 }
 
+Status HloCostAnalysis::HandleCholesky(const HloInstruction* hlo) {
+  float bytes_accessed = GetShapeSize(hlo->operand(0)->shape()) / 2.0f;
+  current_properties_[kBytesAccessedKey] = bytes_accessed;
+
+  const Shape& a_shape = hlo->operand(0)->shape();
+  // Estimate as batch * n^3 / 3 flops.
+  int64 elems = a_shape.dimensions(a_shape.dimensions_size() - 1);
+  elems *= ShapeUtil::ElementsIn(a_shape);
+  current_properties_[kFlopsKey] = elems / 3;
+  return Status::OK();
+}
+
 Status HloCostAnalysis::HandleAllReduce(const HloInstruction* crs) {
   // We assume 2 replicas, so that each output element is the sum of two input
   // elements.
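To sanity-check the new estimate with illustrative numbers: for a single f32[256,256] operand, `elems` is $n \cdot |a| = 256 \cdot 256^2 = n^3$, so the handler reports

$$ \text{flops} = \frac{n^3}{3} = \frac{256^3}{3} \approx 5.6 \times 10^6, $$

which matches the classical $n^3/3$ operation count for an unblocked Cholesky factorization; leading batch dimensions scale the same expression by the batch size.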
@@ -72,6 +72,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
   Status HandleConvolution(const HloInstruction* convolution) override;
   Status HandleFft(const HloInstruction* fft) override;
   Status HandleTriangularSolve(const HloInstruction* hlo) override;
+  Status HandleCholesky(const HloInstruction* hlo) override;
   Status HandleAllReduce(const HloInstruction* crs) override;
   Status HandleAllToAll(const HloInstruction* hlo) override;
   Status HandleCollectivePermute(const HloInstruction* hlo) override;
@@ -1017,6 +1017,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
     case HloOpcode::kDot:
     case HloOpcode::kFft:
     case HloOpcode::kTriangularSolve:
+    case HloOpcode::kCholesky:
       return kDarkBlue;
     case HloOpcode::kReducePrecision:
       return kRed;
@@ -132,6 +132,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
                                          proto.triangular_solve_options());
       break;
     }
+    case HloOpcode::kCholesky: {
+      instruction =
+          CreateCholesky(shape, operands(0), proto.cholesky_options());
+      break;
+    }
     case HloOpcode::kSend:
       instruction = CreateSend(operands(0), operands(1), proto.channel_id(),
                                proto.is_host_transfer());
@@ -734,6 +739,11 @@ HloInstruction::CreateTriangularSolve(const Shape& shape, HloInstruction* a,
   return absl::make_unique<HloTriangularSolveInstruction>(shape, a, b, options);
 }
 
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCholesky(
+    const Shape& shape, HloInstruction* a, const CholeskyOptions& options) {
+  return absl::make_unique<HloCholeskyInstruction>(shape, a, options);
+}
+
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
     const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
     const DotDimensionNumbers& dimension_numbers,
@@ -1294,6 +1304,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
     case HloOpcode::kDomain:
     case HloOpcode::kGetDimensionSize:
     case HloOpcode::kTriangularSolve:
+    case HloOpcode::kCholesky:
       clone = CloneWithNewOperandsImpl(shape, new_operands, context);
       break;
     // Unary ops.
@@ -1754,6 +1765,7 @@ bool HloInstruction::IdenticalSlowPath(
     case HloOpcode::kDomain:
     case HloOpcode::kGetDimensionSize:
     case HloOpcode::kTriangularSolve:
+    case HloOpcode::kCholesky:
       LOG(FATAL) << "Base class impl called for opcode with subclass: "
                  << opcode();
   }
@@ -2553,6 +2565,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
       return visitor->HandleGetDimensionSize(this);
     case HloOpcode::kTriangularSolve:
       return visitor->HandleTriangularSolve(this);
+    case HloOpcode::kCholesky:
+      return visitor->HandleCholesky(this);
 
     // These opcodes are not handled here.
     case HloOpcode::kTrace:
@@ -3437,4 +3451,9 @@ const DomainMetadata& HloInstruction::user_side_metadata() const {
 const TriangularSolveOptions& HloInstruction::triangular_solve_options() const {
   return Cast<HloTriangularSolveInstruction>(this)->triangular_solve_options();
 }
 
+const CholeskyOptions& HloInstruction::cholesky_options() const {
+  return Cast<HloCholeskyInstruction>(this)->cholesky_options();
+}
+
 }  // namespace xla
@@ -448,6 +448,9 @@ class HloInstruction {
       const Shape& shape, HloInstruction* a, HloInstruction* b,
       const TriangularSolveOptions& options);
 
+  static std::unique_ptr<HloInstruction> CreateCholesky(
+      const Shape& shape, HloInstruction* a, const CholeskyOptions& options);
+
   // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
   // dimensions specified in 'dimension_numbers'.
   static std::unique_ptr<HloInstruction> CreateDot(
@@ -1594,6 +1597,9 @@ class HloInstruction {
   // Delegates to HloTriangularSolveInstruction::triangular_solve_options().
   const TriangularSolveOptions& triangular_solve_options() const;
 
+  // Delegates to HloCholeskyInstruction::cholesky_options().
+  const CholeskyOptions& cholesky_options() const;
+
   // Old methods kept for smooth subclassing transition END.
 
  protected:
@@ -202,21 +202,6 @@ std::unique_ptr<HloInstruction> HloFftInstruction::CloneWithNewOperandsImpl(
                                            fft_length_);
 }
 
-HloTriangularSolveInstruction::HloTriangularSolveInstruction(
-    const Shape& shape, HloInstruction* a, HloInstruction* b,
-    const TriangularSolveOptions& options)
-    : HloInstruction(HloOpcode::kTriangularSolve, shape),
-      triangular_solve_options_(options) {
-  AppendOperand(a);
-  AppendOperand(b);
-}
-
-HloInstructionProto HloTriangularSolveInstruction::ToProto() const {
-  HloInstructionProto proto = HloInstruction::ToProto();
-  *proto.mutable_triangular_solve_options() = triangular_solve_options_;
-  return proto;
-}
-
 namespace {
 
 // Converts a protocol buffer message (e.g., TriangularSolveOptions) to a vector
@@ -257,6 +242,21 @@ std::vector<string> AttributeProtoToStringVector(
 
 }  // namespace
 
+HloTriangularSolveInstruction::HloTriangularSolveInstruction(
+    const Shape& shape, HloInstruction* a, HloInstruction* b,
+    const TriangularSolveOptions& options)
+    : HloInstruction(HloOpcode::kTriangularSolve, shape),
+      triangular_solve_options_(options) {
+  AppendOperand(a);
+  AppendOperand(b);
+}
+
+HloInstructionProto HloTriangularSolveInstruction::ToProto() const {
+  HloInstructionProto proto = HloInstruction::ToProto();
+  *proto.mutable_triangular_solve_options() = triangular_solve_options_;
+  return proto;
+}
+
 std::vector<string> HloTriangularSolveInstruction::ExtraAttributesToStringImpl(
     const HloPrintOptions& options) const {
   return AttributeProtoToStringVector(triangular_solve_options_);
@@ -286,6 +286,44 @@ HloTriangularSolveInstruction::CloneWithNewOperandsImpl(
       shape, new_operands[0], new_operands[1], triangular_solve_options());
 }
 
+HloCholeskyInstruction::HloCholeskyInstruction(const Shape& shape,
+                                               HloInstruction* a,
+                                               const CholeskyOptions& options)
+    : HloInstruction(HloOpcode::kCholesky, shape), cholesky_options_(options) {
+  AppendOperand(a);
+}
+
+HloInstructionProto HloCholeskyInstruction::ToProto() const {
+  HloInstructionProto proto = HloInstruction::ToProto();
+  *proto.mutable_cholesky_options() = cholesky_options_;
+  return proto;
+}
+
+std::vector<string> HloCholeskyInstruction::ExtraAttributesToStringImpl(
+    const HloPrintOptions& options) const {
+  return AttributeProtoToStringVector(cholesky_options_);
+}
+
+bool HloCholeskyInstruction::IdenticalSlowPath(
+    const HloInstruction& other,
+    const std::function<bool(const HloComputation*, const HloComputation*)>&
+        eq_computations) const {
+  const auto& casted_other = static_cast<const HloCholeskyInstruction&>(other);
+  const auto& options = cholesky_options();
+  const auto& other_options = casted_other.cholesky_options();
+
+  return options.lower() == other_options.lower();
+}
+
+std::unique_ptr<HloInstruction>
+HloCholeskyInstruction::CloneWithNewOperandsImpl(
+    const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+    HloCloneContext* context) const {
+  CHECK_EQ(new_operands.size(), 1);
+  return absl::make_unique<HloCholeskyInstruction>(shape, new_operands[0],
+                                                   cholesky_options());
+}
+
 HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
                                                const Shape& shape,
                                                int64 channel_id,
@@ -159,6 +159,31 @@ class HloTriangularSolveInstruction : public HloInstruction {
   TriangularSolveOptions triangular_solve_options_;
 };
 
+class HloCholeskyInstruction : public HloInstruction {
+ public:
+  explicit HloCholeskyInstruction(const Shape& shape, HloInstruction* a,
+                                  const CholeskyOptions& options);
+  const CholeskyOptions& cholesky_options() const { return cholesky_options_; }
+
+  // Returns a serialized representation of this instruction.
+  HloInstructionProto ToProto() const override;
+
+ private:
+  std::vector<string> ExtraAttributesToStringImpl(
+      const HloPrintOptions& options) const override;
+  bool IdenticalSlowPath(
+      const HloInstruction& other,
+      const std::function<bool(const HloComputation*, const HloComputation*)>&
+          eq_computations) const override;
+
+  // Implementation for non-common logic of CloneWithNewOperands.
+  std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
+      const Shape& shape, absl::Span<HloInstruction* const> new_operands,
+      HloCloneContext* context) const override;
+
+  CholeskyOptions cholesky_options_;
+};
+
 class HloSendRecvInstruction : public HloInstruction {
  public:
   // Returns the channel id associated with the instruction. The id is
@@ -61,6 +61,7 @@ namespace xla {
   V(kBroadcast, "broadcast", 1)                                              \
   V(kCall, "call", kHloOpcodeIsVariadic)                                     \
   V(kCeil, "ceil", 1)                                                        \
+  V(kCholesky, "cholesky", 1)                                                \
   V(kClamp, "clamp", 3)                                                      \
   V(kCollectivePermute, "collective-permute", 1)                             \
   V(kClz, "count-leading-zeros", 1)                                          \
@@ -1129,6 +1129,17 @@ bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
           shape, operands[0], operands[1], options));
       break;
     }
+    case HloOpcode::kCholesky: {
+      CholeskyOptions options;
+      if (!ParseOperands(&operands, /*expected_size=*/1) ||
+          !ParseAttributesAsProtoMessage(
+              /*required_attrs=*/std::unordered_set<string>(), &options)) {
+        return false;
+      }
+      instruction = builder->AddInstruction(
+          HloInstruction::CreateCholesky(shape, operands[0], options));
+      break;
+    }
     case HloOpcode::kBroadcast: {
       optional<std::vector<int64>> broadcast_dimensions;
       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
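For illustration, HLO text along these lines should round-trip through the new parser case (a hypothetical module; the `lower=true` attribute syntax is inferred from how other option protos print, and exact layout annotations may differ):

HloModule cholesky_example

ENTRY %main (a: f32[4,4]) -> f32[4,4] {
  %a = f32[4,4] parameter(0)
  ROOT %chol = f32[4,4] cholesky(f32[4,4] %a), lower=true
}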
@@ -177,6 +177,13 @@ Status ShapeVerifier::HandleTriangularSolve(HloInstruction* hlo) {
   return CheckShape(hlo, expected);
 }
 
+Status ShapeVerifier::HandleCholesky(HloInstruction* hlo) {
+  TF_RETURN_IF_ERROR(CheckOperandCount(hlo, 1));
+  TF_ASSIGN_OR_RETURN(const Shape expected, ShapeInference::InferCholeskyShape(
+                                                hlo->operand(0)->shape()));
+  return CheckShape(hlo, expected);
+}
+
 Status ShapeVerifier::HandleAllReduce(HloInstruction* crs) {
   std::vector<const Shape*> operand_shapes;
   for (const HloInstruction* operand : crs->operands()) {
@@ -52,6 +52,7 @@ class ShapeVerifier : public DfsHloVisitor {
   Status HandleDot(HloInstruction* dot) override;
   Status HandleConvolution(HloInstruction* convolution) override;
   Status HandleFft(HloInstruction* fft) override;
+  Status HandleCholesky(HloInstruction* hlo) override;
   Status HandleTriangularSolve(HloInstruction* hlo) override;
   Status HandleAllReduce(HloInstruction* crs) override;
   Status HandleAllToAll(HloInstruction* hlo) override;
@@ -126,6 +126,7 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
     case HloOpcode::kBatchNormInference:
     case HloOpcode::kBatchNormTraining:
     case HloOpcode::kCall:
+    case HloOpcode::kCholesky:
     case HloOpcode::kConditional:
     case HloOpcode::kConvolution:
     case HloOpcode::kAllReduce:
@@ -32,6 +32,7 @@ cc_library(
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla/service:algebraic_simplifier",
+        "//tensorflow/compiler/xla/service:cholesky_expander",
         "//tensorflow/compiler/xla/service:compiler",
         "//tensorflow/compiler/xla/service:computation_placer",
         "//tensorflow/compiler/xla/service:dynamic_index_splitter",
@@ -20,6 +20,7 @@ limitations under the License.
 
 #include "absl/memory/memory.h"
 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+#include "tensorflow/compiler/xla/service/cholesky_expander.h"
 #include "tensorflow/compiler/xla/service/computation_placer.h"
 #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h"
 #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
@@ -80,6 +81,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) {
   HloPassPipeline pipeline("Interpreter");
 
   pipeline.AddPass<DynamicIndexSplitter>();
+  pipeline.AddPass<CholeskyExpander>();
   pipeline.AddPass<TriangularSolveExpander>();
   pipeline.AddPass<LayoutAssignment>(
       hlo_module->mutable_entry_computation_layout(),
@@ -2073,6 +2073,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
     case HloOpcode::kSubtract:
     case HloOpcode::kTanh:
     case HloOpcode::kTriangularSolve:
+    case HloOpcode::kCholesky:
     case HloOpcode::kTupleSelect:
     case HloOpcode::kWhile:
       return false;
@@ -1911,6 +1911,14 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
 
 /* static */ StatusOr<Shape> ShapeInference::InferTriangularSolveShape(
     const Shape& a, const Shape& b, const TriangularSolveOptions& options) {
+  if ((!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) ||
+      a.element_type() != b.element_type()) {
+    return InvalidArgument(
+        "Expected element types in shape to be floating or complex and "
+        "identical for TriangularSolve; got %s and %s.",
+        PrimitiveType_Name(a.element_type()),
+        PrimitiveType_Name(b.element_type()));
+  }
   if (a.rank() < 2) {
     return InvalidArgument(
         "The 'a' argument to TriangularSolve must have rank >= 2, got shape %s",
@@ -1952,6 +1960,27 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
   return b;
 }
 
+/* static */ StatusOr<Shape> ShapeInference::InferCholeskyShape(
+    const Shape& a) {
+  if (!ShapeUtil::ElementIsFloating(a) && !ShapeUtil::ElementIsComplex(a)) {
+    return InvalidArgument(
+        "Expected element type in shape to be floating or complex for "
+        "Cholesky; got %s.",
+        PrimitiveType_Name(a.element_type()));
+  }
+  if (a.rank() < 2) {
+    return InvalidArgument(
+        "The 'a' argument to Cholesky must have rank >= 2, got shape %s",
+        a.ToString());
+  }
+  if (a.dimensions(a.rank() - 2) != a.dimensions(a.rank() - 1)) {
+    return InvalidArgument(
+        "The two minor dimensions of 'a' must have equal size, got %s.",
+        a.ToString());
+  }
+  return a;
+}
+
 /* static */ StatusOr<Shape> ShapeInference::InferAllReduceShape(
     absl::Span<const Shape* const> operand_shapes) {
   for (const Shape* operand_shape : operand_shapes) {
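A small sketch of what the new inference accepts and rejects (illustrative shapes, not from the commit):

#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/shape_util.h"

void CholeskyShapeExamples() {
  // Batched square f32 input: accepted; the inferred shape equals the input.
  xla::Shape ok = xla::ShapeUtil::MakeShape(xla::F32, {8, 4, 4});
  auto same = xla::ShapeInference::InferCholeskyShape(ok);    // f32[8,4,4]

  // Non-square minor dimensions: rejected with InvalidArgument.
  xla::Shape bad = xla::ShapeUtil::MakeShape(xla::F32, {4, 3});
  auto err = xla::ShapeInference::InferCholeskyShape(bad);    // error status

  // Integer element type: also rejected.
  xla::Shape ints = xla::ShapeUtil::MakeShape(xla::S32, {4, 4});
  auto err2 = xla::ShapeInference::InferCholeskyShape(ints);  // error status
  (void)same; (void)err; (void)err2;
}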
@@ -120,6 +120,9 @@ class ShapeInference {
   static StatusOr<Shape> InferTriangularSolveShape(
       const Shape& a, const Shape& b, const TriangularSolveOptions& options);
 
+  // Infers the shape produced by the given Cholesky operation.
+  static StatusOr<Shape> InferCholeskyShape(const Shape& a);
+
   // Infers the shape produced by a cross replica sum with the given operand
   // shapes.
   static StatusOr<Shape> InferAllReduceShape(
@@ -2217,3 +2217,23 @@ xla_test(
         "//tensorflow/core:test",
     ],
 )
+
+xla_test(
+    name = "cholesky_test",
+    srcs = ["cholesky_test.cc"],
+    tags = ["optonly"],
+    deps = [
+        "//tensorflow/compiler/xla:array2d",
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/client/lib:arithmetic",
+        "//tensorflow/compiler/xla/client/lib:matrix",
+        "//tensorflow/compiler/xla/tests:client_library_test_base",
+        "//tensorflow/compiler/xla/tests:literal_test_util",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:test",
+    ],
+)
@@ -13,8 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/client/lib/cholesky.h"
-
+#include <limits>
 #include <memory>
 #include <numeric>
 #include <vector>
@@ -32,27 +31,27 @@ limitations under the License.
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/core/lib/core/status_test_util.h"
 
+namespace xla {
 namespace {
 
-using xla::int64;
-using CholeskyTest = xla::ClientLibraryTestBase;
+using CholeskyTest = ClientLibraryTestBase;
 
-XLA_TEST_F(CholeskyTest, Simple) {
-  xla::XlaBuilder builder(TestName());
+XLA_TEST_F(CholeskyTest, Lower) {
+  XlaBuilder builder(TestName());
 
-  xla::Array2D<float> a_vals({
-      {4, 6, 8, 10},
-      {6, 45, 54, 63},
-      {8, 54, 146, 166},
+  float nan = std::numeric_limits<float>::quiet_NaN();
+  Array2D<float> a_vals({
+      {4, nan, nan, nan},
+      {6, 45, nan, nan},
+      {8, 54, 146, nan},
       {10, 63, 166, 310},
   });
 
-  xla::XlaOp a;
+  XlaOp a;
   auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
-  xla::Cholesky(a, /*block_size=*/2);
+  LowerTriangle(Cholesky(a, /*lower=*/true));
 
-  xla::Array2D<float> expected({
+  Array2D<float> expected({
       {2, 0, 0, 0},
       {3, 6, 0, 0},
       {4, 7, 9, 0},
@@ -60,34 +59,62 @@ XLA_TEST_F(CholeskyTest, Simple) {
   });
 
   ComputeAndCompareR2<float>(&builder, expected, {a_data.get()},
-                             xla::ErrorSpec(1e-4, 1e-4));
+                             ErrorSpec(1e-4, 1e-4));
+}
+
+XLA_TEST_F(CholeskyTest, Upper) {
+  XlaBuilder builder(TestName());
+
+  float nan = std::numeric_limits<float>::quiet_NaN();
+  Array2D<float> a_vals({
+      {4, 6, 8, 10},
+      {nan, 45, 54, 63},
+      {nan, nan, 146, 166},
+      {nan, nan, nan, 310},
+  });
+
+  XlaOp a;
+  auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
+  UpperTriangle(Cholesky(a, /*lower=*/false));
+
+  Array2D<float> expected({
+      {2, 3, 4, 5},
+      {0, 6, 7, 8},
+      {0, 0, 9, 10},
+      {0, 0, 0, 11},
+  });
+
+  ComputeAndCompareR2<float>(&builder, expected, {a_data.get()},
+                             ErrorSpec(1e-4, 1e-4));
 }
 
 XLA_TEST_F(CholeskyTest, Simple2) {
-  xla::XlaBuilder builder(TestName());
+  XlaBuilder builder(TestName());
 
-  xla::Array2D<float> a_vals({
+  Array2D<float> a_vals({
       {16, 24, 8, 12},
       {24, 61, 82, 48},
       {8, 82, 456, 106},
      {12, 48, 106, 62},
   });
 
-  xla::XlaOp a;
+  XlaOp a;
   auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
-  xla::Cholesky(a);
+  LowerTriangle(Cholesky(a, /*lower=*/true));
 
-  xla::Array2D<float> expected(
-      {{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}});
+  Array2D<float> expected({{4, 0, 0, 0},    //
+                           {6, 5, 0, 0},    //
+                           {2, 14, 16, 0},  //
+                           {3, 6, 1, 4}});
 
   ComputeAndCompareR2<float>(&builder, expected, {a_data.get()},
-                             xla::ErrorSpec(1e-4, 1e-4));
+                             ErrorSpec(1e-4, 1e-4));
 }
 
 XLA_TEST_F(CholeskyTest, SimpleBatched) {
-  xla::XlaBuilder builder(TestName());
+  XlaBuilder builder(TestName());
 
-  xla::Array3D<float> a_vals({
+  Array3D<float> a_vals({
       {
           {4, 6, 8, 10},
           {6, 45, 54, 63},
@@ -102,65 +129,78 @@ XLA_TEST_F(CholeskyTest, SimpleBatched) {
       },
   });
 
-  xla::XlaOp a;
+  XlaOp a;
   auto a_data = CreateR3Parameter<float>(a_vals, 0, "a", &builder, &a);
-  xla::Cholesky(a);
+  LowerTriangle(Cholesky(a, /*lower=*/true));
 
-  xla::Array3D<float> expected({
+  Array3D<float> expected({
       {
           {2, 0, 0, 0},
          {3, 6, 0, 0},
          {4, 7, 9, 0},
          {5, 8, 10, 11},
       },
-      {{4, 0, 0, 0}, {6, 5, 0, 0}, {2, 14, 16, 0}, {3, 6, 1, 4}},
+      {{4, 0, 0, 0},    //
+       {6, 5, 0, 0},    //
+       {2, 14, 16, 0},  //
+       {3, 6, 1, 4}},
   });
 
   ComputeAndCompareR3<float>(&builder, expected, {a_data.get()},
-                             xla::ErrorSpec(1e-4, 1e-4));
+                             ErrorSpec(1e-4, 1e-4));
 }
 
-using CholeskyTestCase = std::tuple<int64, int64>;
+using CholeskyTestCase = std::tuple<int64, int64, bool>;
 
 class RandomCholeskyTest
-    : public xla::ClientLibraryTestBase,
+    : public ClientLibraryTestBase,
       public ::testing::WithParamInterface<CholeskyTestCase> {};
 
 XLA_TEST_P(RandomCholeskyTest, Random) {
-  xla::XlaBuilder builder(TestName());
+  XlaBuilder builder(TestName());
 
   auto test_params = GetParam();
   std::vector<int64> dimensions = {std::get<0>(test_params),
                                    std::get<1>(test_params),
                                    std::get<1>(test_params)};
-  xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, dimensions);
+  bool lower = std::get<2>(test_params);
+  Shape shape = ShapeUtil::MakeShape(F32, dimensions);
   TF_ASSERT_OK_AND_ASSIGN(
-      auto literal,
-      xla::LiteralUtil::CreateRandomLiteral<xla::F32>(shape, 0.0, 1.0));
+      auto literal, LiteralUtil::CreateRandomLiteral<F32>(shape, 0.0, 1.0));
 
-  auto input = xla::Parameter(&builder, 0, shape, "input");
+  auto input = Parameter(&builder, 0, shape, "input");
   // Form a random positive definite matrix.
-  auto matrix = xla::BatchDot(input, TransposeInMinorDims(input),
-                              xla::PrecisionConfig::HIGHEST);
+  auto matrix =
+      BatchDot(input, TransposeInMinorDims(input), PrecisionConfig::HIGHEST);
 
-  auto cholesky = xla::Cholesky(matrix, /*block_size=*/4);
+  auto cholesky = Triangle(Cholesky(matrix, lower), lower);
 
   // Verify that ||matrix - cholesky * cholesky_t||_2 ~= 0
-  auto verification = xla::BatchDot(cholesky, TransposeInMinorDims(cholesky),
-                                    xla::PrecisionConfig::HIGHEST);
+  XlaOp verification;
+  if (lower) {
+    verification = BatchDot(cholesky, TransposeInMinorDims(cholesky),
+                            PrecisionConfig::HIGHEST);
+  } else {
+    verification = BatchDot(TransposeInMinorDims(cholesky), cholesky,
+                            PrecisionConfig::HIGHEST);
+  }
   auto delta = matrix - verification;
-  xla::Reduce(delta * delta, xla::ConstantR0<float>(&builder, 0.0),
-              CreateScalarAddComputation(xla::F32, &builder), {0, 1, 2});
+  Reduce(delta * delta, ConstantR0<float>(&builder, 0.0),
+         CreateScalarAddComputation(F32, &builder), {0, 1, 2});
 
   TF_ASSERT_OK_AND_ASSIGN(auto input_data, client_->TransferToServer(literal));
   ComputeAndCompareR0<float>(&builder, 0.0, {input_data.get()},
-                             xla::ErrorSpec(1e-4, 1e-4));
+                             ErrorSpec(1e-4, 1e-4));
 }
 
 INSTANTIATE_TEST_SUITE_P(RandomCholeskyTestInstance, RandomCholeskyTest,
-                         ::testing::Values(CholeskyTestCase{1, 1},
-                                           CholeskyTestCase{1, 2},
-                                           CholeskyTestCase{10, 5},
-                                           CholeskyTestCase{2, 20}));
+                         ::testing::Values(CholeskyTestCase{1, 1, true},
+                                           CholeskyTestCase{1, 2, true},
+                                           CholeskyTestCase{1, 50, true},
+                                           CholeskyTestCase{1, 50, false},
+                                           CholeskyTestCase{10, 5, true},
+                                           CholeskyTestCase{5, 10, false},
+                                           CholeskyTestCase{2, 20, true}));
 
 }  // namespace
+}  // namespace xla
@@ -566,6 +566,12 @@ message TriangularSolveOptions {
   Transpose transpose_a = 4;
 }
 
+message CholeskyOptions {
+  // If true, uses the lower triangle of `a`. If false, uses the upper triangle
+  // of `a`.
+  bool lower = 1;
+}
+
 message OpSharding {
   enum Type {
     // This sharding is replicated across all devices (implies maximal,