[XLA:TPU] Add a TPU implementation of LU decomposition using a CustomCall.
PiperOrigin-RevId: 331740238 Change-Id: Ib6ca3a9388b5ce21dafd15afcf1bde6a99dc2209
This commit is contained in:
parent
80a15264e0
commit
d59e3e53cf
@ -309,6 +309,19 @@ xla_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "lu_decomposition",
|
||||
srcs = ["lu_decomposition.cc"],
|
||||
hdrs = ["lu_decomposition.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "slicing",
|
||||
srcs = ["slicing.cc"],
|
||||
|
57
tensorflow/compiler/xla/client/lib/lu_decomposition.cc
Normal file
57
tensorflow/compiler/xla/client/lib/lu_decomposition.cc
Normal file
@ -0,0 +1,57 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/xla/client/lib/lu_decomposition.h"
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
LuDecompositionResult LuDecomposition(XlaOp a) {
|
||||
XlaBuilder* builder = a.builder();
|
||||
XlaOp result = builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
|
||||
const int ndims = a_shape.rank();
|
||||
TF_RET_CHECK(ndims >= 2);
|
||||
const int64 m = ShapeUtil::GetDimension(a_shape, -2);
|
||||
const int64 n = ShapeUtil::GetDimension(a_shape, -1);
|
||||
const int num_batch_dims = a_shape.dimensions().size() - 2;
|
||||
const std::vector<int64> batch_dims(
|
||||
a_shape.dimensions().begin(),
|
||||
a_shape.dimensions().begin() + num_batch_dims);
|
||||
|
||||
std::vector<int64> pivot_dims = batch_dims;
|
||||
pivot_dims.push_back(std::min(m, n));
|
||||
std::vector<int64> perm_dims = batch_dims;
|
||||
perm_dims.push_back(m);
|
||||
Shape lu_shape = ShapeUtil::MakeTupleShape(
|
||||
{a_shape, ShapeUtil::MakeShape(S32, pivot_dims),
|
||||
ShapeUtil::MakeShape(S32, perm_dims)});
|
||||
// The TPU compiler has a rewrite pass that lowers an LuDecomposition
|
||||
// CustomCall.
|
||||
// TODO(phawkins): upgrade LU decomposition to a first-class HLO operator
|
||||
// and implement it on other backends.
|
||||
return CustomCall(a.builder(), "LuDecomposition", {a}, lu_shape);
|
||||
});
|
||||
return LuDecompositionResult{GetTupleElement(result, 0),
|
||||
GetTupleElement(result, 1),
|
||||
GetTupleElement(result, 2)};
|
||||
}
|
||||
|
||||
} // namespace xla
|
61
tensorflow/compiler/xla/client/lib/lu_decomposition.h
Normal file
61
tensorflow/compiler/xla/client/lib/lu_decomposition.h
Normal file
@ -0,0 +1,61 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LU_DECOMPOSITION_H_
|
||||
#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LU_DECOMPOSITION_H_
|
||||
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
|
||||
namespace xla {
|
||||
|
||||
// Computes the LU decomposition with partial pivoting of a batch of matrices.
|
||||
//
|
||||
// Given a (batched) matrix a with shape [..., m, n], computes the matrix
|
||||
// decomposition A = P @ L @ U where P is a permutation matrix, L is a
|
||||
// lower-triangular matrix with unit diagonal entries, and U is an
|
||||
// upper-triangular matrix.
|
||||
//
|
||||
// L and U are returned as a single matrix [..., m, n] containing both L and U
|
||||
// packed in the same array. The unit diagonal of L is not represented
|
||||
// explicitly.
|
||||
//
|
||||
// The permutation matrix P is returned in two forms, both as `pivots`, which is
|
||||
// an s32[..., min(m, n)] array that describes a sequence of row-swaps in the
|
||||
// style of LAPACK's xGETRF API, and `permutation`, which is a s32[..., m] array
|
||||
// which gives the permutation to apply to the rows. We return both
|
||||
// representations because they are each useful for different purposes; `pivots`
|
||||
// is useful for computing the sign of a determinant, whereas `permutation` can
|
||||
// be used via a Gather operation to permute the rows of a matrix.
|
||||
//
|
||||
// This method is only implemented on TPU at the moment.
|
||||
// TODO(b/168208200): the implementation only supports F32 arrays. Handle the
|
||||
// complex case.
|
||||
struct LuDecompositionResult {
|
||||
// The LU decomposition, with both L and U packed into an array with shape
|
||||
// [..., m, n].
|
||||
XlaOp lu;
|
||||
// An array of shape s32[..., min(m, n)] containing the pivot rows.
|
||||
XlaOp pivots;
|
||||
// An array of shape s32[..., m], containing an another representation of the
|
||||
// pivots as a permutation.
|
||||
XlaOp permutation;
|
||||
};
|
||||
|
||||
LuDecompositionResult LuDecomposition(XlaOp a);
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_LU_DECOMPOSITION_H_
|
@ -288,6 +288,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/client/lib:comparators",
|
||||
"//tensorflow/compiler/xla/client/lib:lu_decomposition",
|
||||
"//tensorflow/compiler/xla/client/lib:math",
|
||||
"//tensorflow/compiler/xla/client/lib:qr",
|
||||
"//tensorflow/compiler/xla/client/lib:self_adjoint_eig",
|
||||
|
@ -23,6 +23,7 @@ limitations under the License.
|
||||
#include "pybind11/attr.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/comparators.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/lu_decomposition.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/qr.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/self_adjoint_eig.h"
|
||||
@ -186,6 +187,13 @@ void BuildOpsSubmodule(py::module* m) {
|
||||
return std::make_pair(qr.q, qr.r);
|
||||
},
|
||||
py::arg("operand"), py::arg("full_matrices"));
|
||||
ops.def(
|
||||
"LU",
|
||||
[](XlaOp a) -> StatusOr<std::tuple<XlaOp, XlaOp, XlaOp>> {
|
||||
LuDecompositionResult lu = LuDecomposition(a);
|
||||
return std::make_tuple(lu.lu, lu.pivots, lu.permutation);
|
||||
},
|
||||
py::arg("operand"));
|
||||
ops.def(
|
||||
"Eigh",
|
||||
[](XlaOp a, bool lower, int64 max_iter,
|
||||
|
Loading…
x
Reference in New Issue
Block a user