[XLA] Move QR decomposition out of TF2XLA and into xla/client/lib.
Add a couple of simple C++ tests.

PiperOrigin-RevId: 225044584
commit 1390ba8f78 (parent a3ad14bbd2)
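With this change, QR decomposition is provided by xla::QRDecomposition in the XLA client library, so XLA clients no longer need a TF2XLA dependency to use it. A minimal sketch of a call site against the relocated header (the BuildQr wrapper and parameter wiring are illustrative assumptions; only xla::QRDecomposition and its signature come from qr.h in this diff):

#include "tensorflow/compiler/xla/client/lib/qr.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"

// Builds a computation returning (Q, R) for a parameter matrix `a`.
// block_size and precision use the defaults (128, HIGHEST) from qr.h.
xla::StatusOr<xla::XlaComputation> BuildQr(const xla::Shape& a_shape) {
  xla::XlaBuilder builder("qr");
  xla::XlaOp a = xla::Parameter(&builder, 0, a_shape, "a");
  TF_ASSIGN_OR_RETURN(auto qr,
                      xla::QRDecomposition(a, /*full_matrices=*/true));
  xla::Tuple(&builder, {qr.q, qr.r});  // tuple output: (Q, R)
  return builder.Build();
}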
tensorflow/compiler/tf2xla/kernels/BUILD

@@ -117,7 +117,6 @@ tf_kernel_library(
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/lib:broadcast",
-        "//tensorflow/compiler/tf2xla/lib:qr",
         "//tensorflow/compiler/tf2xla/lib:random",
         "//tensorflow/compiler/tf2xla/lib:scatter",
         "//tensorflow/compiler/tf2xla/lib:util",
@@ -140,6 +139,7 @@ tf_kernel_library(
         "//tensorflow/compiler/xla/client/lib:matrix",
         "//tensorflow/compiler/xla/client/lib:pooling",
         "//tensorflow/compiler/xla/client/lib:prng",
+        "//tensorflow/compiler/xla/client/lib:qr",
         "//tensorflow/compiler/xla/client/lib:sorting",
         "//tensorflow/compiler/xla/client/lib:triangular_solve",
         "//tensorflow/core:framework",
tensorflow/compiler/tf2xla/kernels/qr_op.cc

@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/tf2xla/lib/qr.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/qr.h"
 
 namespace tensorflow {
 namespace {
@@ -26,7 +26,7 @@ class QROp : public XlaOpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_));
   }
   void Compile(XlaOpKernelContext* ctx) override {
-    auto result = QRDecomposition(ctx->Input(0), full_matrices_);
+    auto result = xla::QRDecomposition(ctx->Input(0), full_matrices_);
     if (!result.ok()) {
       ctx->SetStatus(result.status());
       return;
tensorflow/compiler/tf2xla/lib/BUILD

@@ -46,28 +46,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "qr",
-    srcs = ["qr.cc"],
-    hdrs = ["qr.h"],
-    deps = [
-        ":util",
-        "//tensorflow/compiler/xla:literal_util",
-        "//tensorflow/compiler/xla:shape_util",
-        "//tensorflow/compiler/xla:status_macros",
-        "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla:xla_data_proto",
-        "//tensorflow/compiler/xla/client:xla_builder",
-        "//tensorflow/compiler/xla/client/lib:arithmetic",
-        "//tensorflow/compiler/xla/client/lib:constants",
-        "//tensorflow/compiler/xla/client/lib:loops",
-        "//tensorflow/compiler/xla/client/lib:math",
-        "//tensorflow/compiler/xla/client/lib:matrix",
-        "//tensorflow/compiler/xla/client/lib:slicing",
-        "//tensorflow/core:lib",
-    ],
-)
-
 cc_library(
     name = "scatter",
     srcs = ["scatter.cc"],
tensorflow/compiler/xla/client/lib/BUILD

@@ -234,6 +234,48 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "qr",
+    srcs = ["qr.cc"],
+    hdrs = ["qr.h"],
+    deps = [
+        ":arithmetic",
+        ":constants",
+        ":loops",
+        ":math",
+        ":matrix",
+        ":slicing",
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/core:lib",
+    ],
+)
+
+xla_test(
+    name = "qr_test",
+    srcs = ["qr_test.cc"],
+    tags = ["optonly"],
+    deps = [
+        ":matrix",
+        ":qr",
+        "//tensorflow/compiler/xla:array2d",
+        "//tensorflow/compiler/xla:array3d",
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/tests:client_library_test_base",
+        "//tensorflow/compiler/xla/tests:literal_test_util",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:test",
+    ],
+)
+
 cc_library(
     name = "slicing",
     srcs = ["slicing.cc"],
tensorflow/compiler/{tf2xla/lib => xla/client/lib}/qr.cc

@@ -13,12 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/tf2xla/lib/qr.h"
+#include "tensorflow/compiler/xla/client/lib/qr.h"
 
 #include <memory>
 #include <vector>
 
-#include "tensorflow/compiler/tf2xla/lib/util.h"
 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/lib/loops.h"
@@ -32,10 +31,18 @@ limitations under the License.
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/lib/core/errors.h"
 
-namespace tensorflow {
+namespace xla {
 
 namespace {
 
+std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
+                                 absl::Span<const int64> ys) {
+  std::vector<int64> output(xs.size() + ys.size());
+  std::copy(xs.begin(), xs.end(), output.begin());
+  std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
+  return output;
+}
+
 // Computes a Householder reflection of the form:
 // H = I - tau v v.T.
 // such that
@@ -65,52 +72,47 @@ namespace {
 // return (v, tau, beta)
 // TODO(phawkins): LAPACK's xLARFG implementation has code for handling
 // overflows in the norm/beta calculations. Perhaps do the same here.
-xla::Status House(xla::XlaOp x, xla::XlaOp k,
-                  absl::Span<const int64> batch_dims, const int64 m,
-                  xla::XlaOp* v, xla::XlaOp* tau, xla::XlaOp* beta) {
-  xla::XlaBuilder* const builder = x.builder();
-  TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x));
-  const xla::PrimitiveType type = x_shape.element_type();
+Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
+             const int64 m, XlaOp* v, XlaOp* tau, XlaOp* beta) {
+  XlaBuilder* const builder = x.builder();
+  TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
+  const PrimitiveType type = x_shape.element_type();
 
   std::vector<int64> batch_dim_ids(batch_dims.size());
   std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
   const int64 minor_dim = batch_dims.size();
 
-  xla::XlaOp zero = xla::ScalarLike(x, 0.0);
-  xla::XlaOp one = xla::ScalarLike(x, 1.0);
+  XlaOp zero = ScalarLike(x, 0.0);
+  XlaOp one = ScalarLike(x, 1.0);
 
   // alpha = x[k]
-  xla::XlaOp alpha =
-      xla::Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);
+  XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);
 
   // Compute x[k+1:] (padded with zeros in elements 0..k)
-  xla::XlaOp iota = xla::Iota(builder, xla::S32, m);
-  xla::XlaOp x_after_k =
-      xla::Mul(x, xla::ConvertElementType(xla::Gt(iota, k), type),
-               /*broadcast_dimensions=*/{minor_dim});
+  XlaOp iota = Iota(builder, S32, m);
+  XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type),
+                        /*broadcast_dimensions=*/{minor_dim});
 
   // sigma = np.dot(x[k+1:], x[k+1:])
-  auto sigma =
-      xla::Reduce(x_after_k * x_after_k, zero,
-                  xla::CreateScalarAddComputation(type, builder), {minor_dim});
+  auto sigma = Reduce(x_after_k * x_after_k, zero,
+                      CreateScalarAddComputation(type, builder), {minor_dim});
   // mu = np.sqrt(x[k]*x[k] + sigma)
-  auto mu = xla::Sqrt(xla::Square(alpha) + sigma);
+  auto mu = Sqrt(Square(alpha) + sigma);
 
-  auto sigma_is_zero = xla::Eq(sigma, zero);
+  auto sigma_is_zero = Eq(sigma, zero);
 
-  *beta = xla::Select(sigma_is_zero, alpha, -xla::Sign(alpha) * mu);
-  *tau = xla::Select(sigma_is_zero, xla::Broadcast(zero, batch_dims),
-                     (*beta - alpha) / *beta);
-  auto divisor = xla::Select(sigma_is_zero, xla::Broadcast(one, batch_dims),
-                             alpha - *beta);
+  *beta = Select(sigma_is_zero, alpha, -Sign(alpha) * mu);
+  *tau = Select(sigma_is_zero, Broadcast(zero, batch_dims),
+                (*beta - alpha) / *beta);
+  auto divisor =
+      Select(sigma_is_zero, Broadcast(one, batch_dims), alpha - *beta);
 
-  auto e_k = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, k), type),
-                            std::vector<int64>(batch_dims.size(), 1));
+  auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type),
+                       std::vector<int64>(batch_dims.size(), 1));
 
   // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
   // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
-  *v = e_k +
-       xla::Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids);
+  *v = e_k + Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids);
   return Status::OK();
 }
 
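For readers following House above: it computes the standard Householder reflector for column x and pivot index k, normalized so that v_k = 1. Restated editorially in LaTeX (this mirrors the Select logic shown; the notation is added here, not part of the change):

    \sigma = \sum_{i>k} x_i^2, \qquad
    \mu = \sqrt{x_k^2 + \sigma}, \qquad
    \beta = -\operatorname{sign}(x_k)\,\mu, \qquad
    \tau = \frac{\beta - x_k}{\beta},

    v = e_k + \frac{x_{k+1:}}{x_k - \beta}, \qquad
    (I - \tau v v^T)\,x = \beta\, e_k.

When \sigma = 0, the column is already zero below k; the code then selects \beta = x_k and \tau = 0, so the reflection degenerates to the identity (any non-zero divisor works, as the comment notes).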
@@ -143,90 +145,86 @@ xla::Status House(xla::XlaOp x, xla::XlaOp k,
 // return (q, vs, taus)
 struct QRBlockResult {
   // The factored R value
-  xla::XlaOp r;
+  XlaOp r;
 
   // Representation of the Householder matrices I - beta v v.T
-  xla::XlaOp taus;  // Shape: [..., n]
-  xla::XlaOp vs;    // Shape: [..., m, n]
+  XlaOp taus;  // Shape: [..., n]
+  XlaOp vs;    // Shape: [..., m, n]
 };
-xla::StatusOr<QRBlockResult> QRBlock(
-    xla::XlaOp a, xla::PrecisionConfig::Precision precision) {
-  xla::XlaBuilder* builder = a.builder();
-  TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
-  const int num_dims = xla::ShapeUtil::Rank(a_shape);
+StatusOr<QRBlockResult> QRBlock(XlaOp a, PrecisionConfig::Precision precision) {
+  XlaBuilder* builder = a.builder();
+  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
+  const int num_dims = ShapeUtil::Rank(a_shape);
   if (num_dims < 2) {
-    return errors::InvalidArgument("Arguments to QR must have rank >= 2: ",
-                                   num_dims);
+    return InvalidArgument("Argument to QR must have rank >= 2; got shape %s",
+                           a_shape.ToString());
   }
-  xla::PrimitiveType type = a_shape.element_type();
+  PrimitiveType type = a_shape.element_type();
 
-  const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2);
-  const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
+  const int64 n = ShapeUtil::GetDimension(a_shape, -1);
 
   const int64 num_batch_dims = num_dims - 2;
   std::vector<int64> batch_dims(num_batch_dims);
   for (int i = 0; i < num_batch_dims; ++i) {
-    batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i);
+    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
   }
 
   std::vector<int64> batch_dim_indices(num_batch_dims);
   std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
 
-  auto qr_body_fn =
-      [&](xla::XlaOp j, absl::Span<const xla::XlaOp> values,
-          xla::XlaBuilder* builder) -> xla::StatusOr<std::vector<xla::XlaOp>> {
+  auto qr_body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
+                        XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
     auto a = values[0];
     auto vs = values[1];
     auto taus = values[2];
 
     // v, beta = house(a[:, j], j)
     auto x = DynamicSliceInMinorDims(a, {j}, {1});
-    xla::XlaOp v, tau, beta;
-    TF_RETURN_IF_ERROR(House(xla::Collapse(x, {num_dims - 2, num_dims - 1}), j,
+    XlaOp v, tau, beta;
+    TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j,
                              batch_dims, m, &v, &tau, &beta));
 
     std::vector<int64> shape = batch_dims;
     shape.push_back(1);
     shape.push_back(m);
-    auto v_broadcast = xla::Reshape(v, shape);
+    auto v_broadcast = Reshape(v, shape);
     // a[:, :] -= tau * np.dot(v[:, np.newaxis],
     //                         np.dot(v[np.newaxis, :], a[:, :]))
     auto vva = BatchDot(v_broadcast, a, precision);
     vva = BatchDot(TransposeInMinorDims(v_broadcast), vva, precision);
-    a = a - xla::Mul(tau, vva,
-                     /*broadcast_dimensions=*/batch_dim_indices);
+    a = a - Mul(tau, vva,
+                /*broadcast_dimensions=*/batch_dim_indices);
 
     // It is more precise to populate column 'k' explicitly, rather than
     // computing it implicitly by applying the Householder transformation.
     // a[k,k] = beta
     // a[k+1:,k] = np.zeros([m-k-1], dtype=a.dtype)
-    auto iota = xla::Reshape(xla::Iota(a.builder(), xla::S32, m), {m, 1});
-    auto predecessor_mask = xla::ConvertElementType(xla::Lt(iota, j), type);
-    auto mask = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, j), type),
-                               std::vector<int64>(batch_dims.size(), 1));
-    auto new_x =
-        xla::Mul(x, predecessor_mask,
-                 /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
-        xla::Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices);
+    auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1});
+    auto predecessor_mask = ConvertElementType(Lt(iota, j), type);
+    auto mask = Broadcast(ConvertElementType(Eq(iota, j), type),
+                          std::vector<int64>(batch_dims.size(), 1));
+    auto new_x = Mul(x, predecessor_mask,
+                     /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
+                 Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices);
     a = DynamicUpdateSliceInMinorDims(a, new_x, {j});
 
     // vs[:, j] = v
     vs = DynamicUpdateSliceInMinorDims(
-        vs, xla::Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j});
+        vs, Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j});
     // taus[j] = tau
     taus = DynamicUpdateSliceInMinorDims(
-        taus, xla::Reshape(tau, ConcatVectors(batch_dims, {1})), {j});
-    return std::vector<xla::XlaOp>{a, vs, taus};
+        taus, Reshape(tau, ConcatVectors(batch_dims, {1})), {j});
+    return std::vector<XlaOp>{a, vs, taus};
   };
 
-  auto vs = xla::Zeros(builder, xla::ShapeUtil::MakeShape(
-                                    type, ConcatVectors(batch_dims, {m, n})));
-  auto taus = xla::Zeros(
-      builder, xla::ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n})));
+  auto vs = Zeros(
+      builder, ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n})));
+  auto taus = Zeros(builder,
+                    ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n})));
 
-  TF_ASSIGN_OR_RETURN(auto values,
-                      xla::ForEachIndex(std::min(m, n), xla::S32, qr_body_fn,
-                                        {a, vs, taus}, "qr", builder));
+  TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn,
+                                                {a, vs, taus}, "qr", builder));
 
   QRBlockResult result;
   result.r = values[0];
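Note that qr_body_fn never materializes the reflector H = I - \tau v v^T. The two BatchDot calls update the trailing columns with matrix-vector products only (again an editorial restatement of the code above, not part of the change):

    A \leftarrow (I - \tau v v^T)\, A = A - \tau\, v\,(v^T A),

which is O(mn) work per column rather than the O(m^2 n) cost of forming H explicitly and multiplying.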
@@ -250,24 +248,23 @@ xla::StatusOr<QRBlockResult> QRBlock(
 // return W
 // There is no need to return Y since at termination of the loop it is equal to
 // vs.
-xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
-    xla::PrimitiveType type, absl::Span<const int64> batch_dims, xla::XlaOp vs,
-    xla::XlaOp taus, int64 m, int64 n,
-    xla::PrecisionConfig::Precision precision) {
+StatusOr<XlaOp> ComputeWYRepresentation(PrimitiveType type,
+                                        absl::Span<const int64> batch_dims,
+                                        XlaOp vs, XlaOp taus, int64 m, int64 n,
+                                        PrecisionConfig::Precision precision) {
   std::vector<int64> batch_dim_indices(batch_dims.size());
   std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
   int64 n_index = batch_dims.size() + 1;
 
-  auto body_fn =
-      [&](xla::XlaOp j, absl::Span<const xla::XlaOp> values,
-          xla::XlaBuilder* builder) -> xla::StatusOr<std::vector<xla::XlaOp>> {
+  auto body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
+                     XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
     auto w = values[0];
     auto y = values[1];
     const auto vs = values[2];
     const auto taus = values[3];
 
     // Want j values in range [1, ... n).
-    j = j + xla::ConstantR0<int32>(builder, 1);
+    j = j + ConstantR0<int32>(builder, 1);
     // vs has shape [..., m, 1]
     auto v = DynamicSliceInMinorDims(vs, {j}, {1});
     // beta has shape [..., 1]
@@ -278,31 +275,31 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
     // wyv has shape [..., m, 1]
     auto wyv = BatchDot(w, yv, precision);
 
-    auto z = xla::Mul(
-        -beta, v + wyv,
-        /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
+    auto z = Mul(
+        -beta, v + wyv,
+        /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
 
     w = DynamicUpdateSliceInMinorDims(w, z, {j});
     y = DynamicUpdateSliceInMinorDims(y, v, {j});
 
-    return std::vector<xla::XlaOp>{w, y, vs, taus};
+    return std::vector<XlaOp>{w, y, vs, taus};
   };
 
-  xla::XlaBuilder* builder = vs.builder();
-  auto w = xla::Zeros(builder, xla::ShapeUtil::MakeShape(
-                                   type, ConcatVectors(batch_dims, {m, n})));
+  XlaBuilder* builder = vs.builder();
+  auto w = Zeros(builder,
+                 ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n})));
   auto y = w;
   auto v = SliceInMinorDims(vs, {0}, {1});
   auto beta = SliceInMinorDims(taus, {0}, {1});
   y = UpdateSliceInMinorDims(y, v, {0});
-  auto bv = xla::Mul(
-      -beta, v,
-      /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
+  auto bv =
+      Mul(-beta, v,
+          /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
   w = UpdateSliceInMinorDims(w, bv, {0});
 
   TF_ASSIGN_OR_RETURN(
-      auto values, xla::ForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus},
-                                     "wy", builder));
+      auto values,
+      ForEachIndex(n - 1, S32, body_fn, {w, y, vs, taus}, "wy", builder));
   return values[0];
 }
 
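ComputeWYRepresentation accumulates the compact WY form of the product of reflectors (in the sense of Bischof and Van Loan). Restated editorially: after processing column j the invariant is H_1 H_2 \cdots H_j = I - W Y^T, and each iteration appends one column to W and Y, which is exactly the z computed in body_fn above (where yv, computed just before the shown hunk, is Y^T v_j):

    z_j = -\beta_j \left( v_j + W\,(Y^T v_j) \right), \qquad
    W \leftarrow [\,W \;\; z_j\,], \qquad
    Y \leftarrow [\,Y \;\; v_j\,].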
@@ -323,34 +320,34 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
 // return (q, a)
 // TODO(phawkins): consider using UT transformations (in the form I - V U V')
 // rather than WY transformations.
-xla::StatusOr<QRDecompositionResult> QRDecomposition(
-    xla::XlaOp a, bool full_matrices, int64 block_size,
-    xla::PrecisionConfig::Precision precision) {
-  xla::XlaBuilder* builder = a.builder();
-  TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
-  const int num_dims = xla::ShapeUtil::Rank(a_shape);
+StatusOr<QRDecompositionResult> QRDecomposition(
+    XlaOp a, bool full_matrices, int64 block_size,
+    PrecisionConfig::Precision precision) {
+  XlaBuilder* builder = a.builder();
+  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
+  const int num_dims = ShapeUtil::Rank(a_shape);
   if (num_dims < 2) {
-    return errors::InvalidArgument("Arguments to QR must have rank >= 2: ",
-                                   num_dims);
+    return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s",
+                           a_shape.ToString());
   }
-  xla::PrimitiveType type = a_shape.element_type();
+  PrimitiveType type = a_shape.element_type();
 
-  const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2);
-  const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
+  const int64 n = ShapeUtil::GetDimension(a_shape, -1);
   const int64 p = std::min(m, n);
 
   if (block_size < 1) {
-    return errors::InvalidArgument(
-        "block_size argument to QR must be >= 1; got ", block_size);
+    return InvalidArgument("block_size argument to QR must be >= 1; got %d",
+                           block_size);
   }
 
   const int64 num_batch_dims = num_dims - 2;
   std::vector<int64> batch_dims(num_batch_dims);
   for (int i = 0; i < num_batch_dims; ++i) {
-    batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i);
+    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
   }
 
-  auto q = xla::Broadcast(xla::IdentityMatrix(builder, type, m, m), batch_dims);
+  auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
   for (int64 i = 0; i < p; i += block_size) {
     int64 k = std::min(block_size, p - i);
 
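The rest of the loop body falls outside the hunks shown here, but from the function's comments and the standard blocked Householder QR it describes, each iteration factors a panel of block_size columns with QRBlock, forms (W, Y) with ComputeWYRepresentation, and then applies the block reflector to the trailing columns and the accumulated Q. Schematically (an editorial sketch, not text from the diff):

    A_{[:,\, i+k:]} \leftarrow (I - W Y^T)^T A_{[:,\, i+k:]}, \qquad
    Q \leftarrow Q\,(I - W Y^T).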
@@ -393,4 +390,4 @@ xla::StatusOr<QRDecompositionResult> QRDecomposition(
   return result;
 }
 
-}  // namespace tensorflow
+}  // namespace xla
tensorflow/compiler/{tf2xla/lib => xla/client/lib}/qr.h

@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_
-#define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_
 
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
-namespace tensorflow {
+namespace xla {
 
 // Computes the QR decompositions of a batch of matrices. That is,
 // given a (batched) matrix a, computes an orthonormal matrix Q and an
@@ -29,14 +29,14 @@ namespace tensorflow {
 // the block size to use.
 // TODO(phawkins): handle the complex case.
 struct QRDecompositionResult {
-  xla::XlaOp q;
-  xla::XlaOp r;
+  XlaOp q;
+  XlaOp r;
 };
 
-xla::StatusOr<QRDecompositionResult> QRDecomposition(
-    xla::XlaOp a, bool full_matrices, int64 block_size = 128,
-    xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST);
+StatusOr<QRDecompositionResult> QRDecomposition(
+    XlaOp a, bool full_matrices, int64 block_size = 128,
+    PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST);
 
-}  // namespace tensorflow
+}  // namespace xla
 
-#endif  // TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_
+#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_
tensorflow/compiler/xla/client/lib/qr_test.cc (new file, 93 lines)

@@ -0,0 +1,93 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/qr.h"
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/client/lib/matrix.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace {
+
+using QrTest = xla::ClientLibraryTestBase;
+
+XLA_TEST_F(QrTest, Simple) {
+  xla::XlaBuilder builder(TestName());
+
+  xla::Array2D<float> a_vals({
+      {4, 6, 8, 10},
+      {6, 45, 54, 63},
+      {8, 54, 146, 166},
+      {10, 63, 166, 310},
+  });
+
+  xla::XlaOp a;
+  auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto result,
+      xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2));
+
+  // Verifies that the decomposition composes back to the original matrix.
+  //
+  // This isn't a terribly demanding test, (e.g., we should verify that Q is
+  // orthonormal and R is upper-triangular) but it's awkward to write such tests
+  // without more linear algebra libraries. It's easier to test the numerics
+  // from Python, anyway, where we have access to numpy and scipy.
+  xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST);
+
+  ComputeAndCompareR2<float>(&builder, a_vals, {a_data.get()},
+                             xla::ErrorSpec(1e-4, 1e-4));
+}
+
+XLA_TEST_F(QrTest, SimpleBatched) {
+  xla::XlaBuilder builder(TestName());
+
+  xla::Array3D<float> a_vals({
+      {
+          {4, 6, 8, 10},
+          {6, 45, 54, 63},
+          {8, 54, 146, 166},
+          {10, 63, 166, 310},
+      },
+      {
+          {16, 24, 8, 12},
+          {24, 61, 82, 48},
+          {8, 82, 456, 106},
+          {12, 48, 106, 62},
+      },
+  });
+
+  xla::XlaOp a;
+  auto a_data = CreateR3Parameter<float>(a_vals, 0, "a", &builder, &a);
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto result,
+      xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2));
+
+  xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST);
+
+  ComputeAndCompareR3<float>(&builder, a_vals, {a_data.get()},
+                             xla::ErrorSpec(1e-4, 1e-4));
+}
+
+}  // namespace
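The comment in QrTest.Simple notes that orthonormality of Q is not verified. For illustration, a hypothetical follow-up test in the same style (not part of this change; it assumes xla::TransposeInMinorDims from client/lib/matrix.h, which qr.cc itself uses, and reuses the 4x4 matrix above):

XLA_TEST_F(QrTest, QIsOrthonormal) {  // hypothetical, not in this commit
  xla::XlaBuilder builder(TestName());

  xla::Array2D<float> a_vals({
      {4, 6, 8, 10},
      {6, 45, 54, 63},
      {8, 54, 146, 166},
      {10, 63, 166, 310},
  });

  xla::XlaOp a;
  auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
  TF_ASSERT_OK_AND_ASSIGN(
      auto result,
      xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2));

  // Q^T Q should be (approximately) the 4x4 identity.
  xla::BatchDot(xla::TransposeInMinorDims(result.q), result.q,
                xla::PrecisionConfig::HIGHEST);

  xla::Array2D<float> expected(4, 4, 0.0f);
  for (int i = 0; i < 4; ++i) {
    expected(i, i) = 1.0f;
  }
  ComputeAndCompareR2<float>(&builder, expected, {a_data.get()},
                             xla::ErrorSpec(1e-4, 1e-4));
}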