[XLA] Move QR decomposition out of TF2XLA and into xla/client/lib.

Add a couple of simple C++ tests.
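For existing callers the change is mechanical: only the include path and the
namespace move; the signature is unchanged. A sketch of a typical call site
before and after (mirroring the qr_op kernel hunk below; `input` and
`full_matrices` are placeholder names):

    // Before:
    #include "tensorflow/compiler/tf2xla/lib/qr.h"
    auto result = QRDecomposition(input, full_matrices);

    // After:
    #include "tensorflow/compiler/xla/client/lib/qr.h"
    auto result = xla::QRDecomposition(input, full_matrices);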

PiperOrigin-RevId: 225044584
Peter Hawkins 2018-12-11 11:57:09 -08:00 committed by TensorFlower Gardener
parent a3ad14bbd2
commit 1390ba8f78
7 changed files with 250 additions and 140 deletions

tensorflow/compiler/tf2xla/kernels/BUILD

@@ -117,7 +117,6 @@ tf_kernel_library(
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/lib:broadcast",
-        "//tensorflow/compiler/tf2xla/lib:qr",
         "//tensorflow/compiler/tf2xla/lib:random",
         "//tensorflow/compiler/tf2xla/lib:scatter",
         "//tensorflow/compiler/tf2xla/lib:util",
@@ -140,6 +139,7 @@ tf_kernel_library(
         "//tensorflow/compiler/xla/client/lib:matrix",
         "//tensorflow/compiler/xla/client/lib:pooling",
         "//tensorflow/compiler/xla/client/lib:prng",
+        "//tensorflow/compiler/xla/client/lib:qr",
         "//tensorflow/compiler/xla/client/lib:sorting",
         "//tensorflow/compiler/xla/client/lib:triangular_solve",
         "//tensorflow/core:framework",

tensorflow/compiler/tf2xla/kernels/qr_op.cc

@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/tf2xla/lib/qr.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/qr.h"
 
 namespace tensorflow {
 namespace {
@@ -26,7 +26,7 @@ class QROp : public XlaOpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_));
   }
   void Compile(XlaOpKernelContext* ctx) override {
-    auto result = QRDecomposition(ctx->Input(0), full_matrices_);
+    auto result = xla::QRDecomposition(ctx->Input(0), full_matrices_);
    if (!result.ok()) {
      ctx->SetStatus(result.status());
      return;

tensorflow/compiler/tf2xla/lib/BUILD

@@ -46,28 +46,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "qr",
-    srcs = ["qr.cc"],
-    hdrs = ["qr.h"],
-    deps = [
-        ":util",
-        "//tensorflow/compiler/xla:literal_util",
-        "//tensorflow/compiler/xla:shape_util",
-        "//tensorflow/compiler/xla:status_macros",
-        "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla:xla_data_proto",
-        "//tensorflow/compiler/xla/client:xla_builder",
-        "//tensorflow/compiler/xla/client/lib:arithmetic",
-        "//tensorflow/compiler/xla/client/lib:constants",
-        "//tensorflow/compiler/xla/client/lib:loops",
-        "//tensorflow/compiler/xla/client/lib:math",
-        "//tensorflow/compiler/xla/client/lib:matrix",
-        "//tensorflow/compiler/xla/client/lib:slicing",
-        "//tensorflow/core:lib",
-    ],
-)
-
 cc_library(
     name = "scatter",
     srcs = ["scatter.cc"],

tensorflow/compiler/xla/client/lib/BUILD

@@ -234,6 +234,48 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "qr",
+    srcs = ["qr.cc"],
+    hdrs = ["qr.h"],
+    deps = [
+        ":arithmetic",
+        ":constants",
+        ":loops",
+        ":math",
+        ":matrix",
+        ":slicing",
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/core:lib",
+    ],
+)
+
+xla_test(
+    name = "qr_test",
+    srcs = ["qr_test.cc"],
+    tags = ["optonly"],
+    deps = [
+        ":matrix",
+        ":qr",
+        "//tensorflow/compiler/xla:array2d",
+        "//tensorflow/compiler/xla:array3d",
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/tests:client_library_test_base",
+        "//tensorflow/compiler/xla/tests:literal_test_util",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:test",
+    ],
+)
+
 cc_library(
     name = "slicing",
     srcs = ["slicing.cc"],

tensorflow/compiler/tf2xla/lib/qr.cc → tensorflow/compiler/xla/client/lib/qr.cc

@@ -13,12 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/tf2xla/lib/qr.h"
+#include "tensorflow/compiler/xla/client/lib/qr.h"
 
 #include <memory>
 #include <vector>
 
-#include "tensorflow/compiler/tf2xla/lib/util.h"
 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/lib/loops.h"
@@ -32,10 +31,18 @@ limitations under the License.
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/lib/core/errors.h"
 
-namespace tensorflow {
+namespace xla {
 
 namespace {
 
+std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
+                                 absl::Span<const int64> ys) {
+  std::vector<int64> output(xs.size() + ys.size());
+  std::copy(xs.begin(), xs.end(), output.begin());
+  std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
+  return output;
+}
+
 // Computes a Householder reflection of the form:
 // H = I - tau v v.T.
 // such that
@@ -65,52 +72,47 @@ namespace {
 //   return (v, tau, beta)
 // TODO(phawkins): LAPACK's xLARFG implementation has code for handling
 // overflows in the norm/beta calculations. Perhaps do the same here.
-xla::Status House(xla::XlaOp x, xla::XlaOp k,
-                  absl::Span<const int64> batch_dims, const int64 m,
-                  xla::XlaOp* v, xla::XlaOp* tau, xla::XlaOp* beta) {
-  xla::XlaBuilder* const builder = x.builder();
-  TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x));
-  const xla::PrimitiveType type = x_shape.element_type();
+Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
+             const int64 m, XlaOp* v, XlaOp* tau, XlaOp* beta) {
+  XlaBuilder* const builder = x.builder();
+  TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
+  const PrimitiveType type = x_shape.element_type();
 
   std::vector<int64> batch_dim_ids(batch_dims.size());
   std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
   const int64 minor_dim = batch_dims.size();
 
-  xla::XlaOp zero = xla::ScalarLike(x, 0.0);
-  xla::XlaOp one = xla::ScalarLike(x, 1.0);
+  XlaOp zero = ScalarLike(x, 0.0);
+  XlaOp one = ScalarLike(x, 1.0);
 
   // alpha = x[k]
-  xla::XlaOp alpha =
-      xla::Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);
+  XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);
 
   // Compute x[k+1:] (padded with zeros in elements 0..k)
-  xla::XlaOp iota = xla::Iota(builder, xla::S32, m);
-  xla::XlaOp x_after_k =
-      xla::Mul(x, xla::ConvertElementType(xla::Gt(iota, k), type),
+  XlaOp iota = Iota(builder, S32, m);
+  XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type),
                         /*broadcast_dimensions=*/{minor_dim});
 
   // sigma = np.dot(x[k+1:], x[k+1:])
-  auto sigma =
-      xla::Reduce(x_after_k * x_after_k, zero,
-                  xla::CreateScalarAddComputation(type, builder), {minor_dim});
+  auto sigma = Reduce(x_after_k * x_after_k, zero,
+                      CreateScalarAddComputation(type, builder), {minor_dim});
   // mu = np.sqrt(x[k]*x[k] + sigma)
-  auto mu = xla::Sqrt(xla::Square(alpha) + sigma);
+  auto mu = Sqrt(Square(alpha) + sigma);
 
-  auto sigma_is_zero = xla::Eq(sigma, zero);
+  auto sigma_is_zero = Eq(sigma, zero);
 
-  *beta = xla::Select(sigma_is_zero, alpha, -xla::Sign(alpha) * mu);
-  *tau = xla::Select(sigma_is_zero, xla::Broadcast(zero, batch_dims),
+  *beta = Select(sigma_is_zero, alpha, -Sign(alpha) * mu);
+  *tau = Select(sigma_is_zero, Broadcast(zero, batch_dims),
                 (*beta - alpha) / *beta);
-  auto divisor = xla::Select(sigma_is_zero, xla::Broadcast(one, batch_dims),
-                             alpha - *beta);
+  auto divisor =
+      Select(sigma_is_zero, Broadcast(one, batch_dims), alpha - *beta);
 
-  auto e_k = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, k), type),
+  auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type),
                        std::vector<int64>(batch_dims.size(), 1));
 
   // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
   // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
-  *v = e_k +
-       xla::Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids);
+  *v = e_k + Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids);
   return Status::OK();
 }
@@ -143,89 +145,85 @@ xla::Status House(xla::XlaOp x, xla::XlaOp k,
 //   return (q, vs, taus)
 struct QRBlockResult {
   // The factored R value
-  xla::XlaOp r;
+  XlaOp r;
 
   // Representation of the Householder matrices I - beta v v.T
-  xla::XlaOp taus;  // Shape: [..., n]
-  xla::XlaOp vs;    // Shape: [..., m, n]
+  XlaOp taus;  // Shape: [..., n]
+  XlaOp vs;    // Shape: [..., m, n]
 };
 
-xla::StatusOr<QRBlockResult> QRBlock(
-    xla::XlaOp a, xla::PrecisionConfig::Precision precision) {
-  xla::XlaBuilder* builder = a.builder();
-  TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
-  const int num_dims = xla::ShapeUtil::Rank(a_shape);
+StatusOr<QRBlockResult> QRBlock(XlaOp a, PrecisionConfig::Precision precision) {
+  XlaBuilder* builder = a.builder();
+  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
+  const int num_dims = ShapeUtil::Rank(a_shape);
   if (num_dims < 2) {
-    return errors::InvalidArgument("Arguments to QR must have rank >= 2: ",
-                                   num_dims);
+    return InvalidArgument("Argument to QR must have rank >= 2; got shape %s",
+                           a_shape.ToString());
   }
-  xla::PrimitiveType type = a_shape.element_type();
+  PrimitiveType type = a_shape.element_type();
 
-  const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2);
-  const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
+  const int64 n = ShapeUtil::GetDimension(a_shape, -1);
 
   const int64 num_batch_dims = num_dims - 2;
   std::vector<int64> batch_dims(num_batch_dims);
   for (int i = 0; i < num_batch_dims; ++i) {
-    batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i);
+    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
   }
 
   std::vector<int64> batch_dim_indices(num_batch_dims);
   std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
 
-  auto qr_body_fn =
-      [&](xla::XlaOp j, absl::Span<const xla::XlaOp> values,
-          xla::XlaBuilder* builder) -> xla::StatusOr<std::vector<xla::XlaOp>> {
+  auto qr_body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
+                        XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
     auto a = values[0];
    auto vs = values[1];
    auto taus = values[2];
 
    // v, beta = house(a[:, j], j)
    auto x = DynamicSliceInMinorDims(a, {j}, {1});
-    xla::XlaOp v, tau, beta;
-    TF_RETURN_IF_ERROR(House(xla::Collapse(x, {num_dims - 2, num_dims - 1}), j,
+    XlaOp v, tau, beta;
+    TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j,
                              batch_dims, m, &v, &tau, &beta));
 
    std::vector<int64> shape = batch_dims;
    shape.push_back(1);
    shape.push_back(m);
 
-    auto v_broadcast = xla::Reshape(v, shape);
+    auto v_broadcast = Reshape(v, shape);
    // a[:, :] -= tau * np.dot(v[:, np.newaxis],
    //                         np.dot(v[np.newaxis, :], a[:, :]))
    auto vva = BatchDot(v_broadcast, a, precision);
    vva = BatchDot(TransposeInMinorDims(v_broadcast), vva, precision);
-    a = a - xla::Mul(tau, vva,
+    a = a - Mul(tau, vva,
                 /*broadcast_dimensions=*/batch_dim_indices);
 
    // It is more precise to populate column 'k' explicitly, rather than
    // computing it implicitly by applying the Householder transformation.
    // a[k,k] = beta
    // a[k+1:,k] = np.zeros([m-k-1], dtype=a.dtype)
-    auto iota = xla::Reshape(xla::Iota(a.builder(), xla::S32, m), {m, 1});
-    auto predecessor_mask = xla::ConvertElementType(xla::Lt(iota, j), type);
-    auto mask = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, j), type),
+    auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1});
+    auto predecessor_mask = ConvertElementType(Lt(iota, j), type);
+    auto mask = Broadcast(ConvertElementType(Eq(iota, j), type),
                          std::vector<int64>(batch_dims.size(), 1));
-    auto new_x =
-        xla::Mul(x, predecessor_mask,
+    auto new_x = Mul(x, predecessor_mask,
                     /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
-        xla::Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices);
+                 Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices);
    a = DynamicUpdateSliceInMinorDims(a, new_x, {j});
 
    // vs[:, j] = v
    vs = DynamicUpdateSliceInMinorDims(
-        vs, xla::Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j});
+        vs, Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j});
    // taus[j] = tau
    taus = DynamicUpdateSliceInMinorDims(
-        taus, xla::Reshape(tau, ConcatVectors(batch_dims, {1})), {j});
-    return std::vector<xla::XlaOp>{a, vs, taus};
+        taus, Reshape(tau, ConcatVectors(batch_dims, {1})), {j});
+    return std::vector<XlaOp>{a, vs, taus};
  };
 
-  auto vs = xla::Zeros(builder, xla::ShapeUtil::MakeShape(
-                                    type, ConcatVectors(batch_dims, {m, n})));
-  auto taus = xla::Zeros(
-      builder, xla::ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n})));
+  auto vs = Zeros(
+      builder, ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n})));
+  auto taus = Zeros(builder,
+                    ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n})));
 
-  TF_ASSIGN_OR_RETURN(auto values,
-                      xla::ForEachIndex(std::min(m, n), xla::S32, qr_body_fn,
+  TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn,
                                                {a, vs, taus}, "qr", builder));
 
  QRBlockResult result;
@@ -250,24 +248,23 @@ xla::StatusOr<QRBlockResult> QRBlock(
 //     return W
 // There is no need to return Y since at termination of the loop it is equal to
 // vs.
-xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
-    xla::PrimitiveType type, absl::Span<const int64> batch_dims, xla::XlaOp vs,
-    xla::XlaOp taus, int64 m, int64 n,
-    xla::PrecisionConfig::Precision precision) {
+StatusOr<XlaOp> ComputeWYRepresentation(PrimitiveType type,
+                                        absl::Span<const int64> batch_dims,
+                                        XlaOp vs, XlaOp taus, int64 m, int64 n,
+                                        PrecisionConfig::Precision precision) {
   std::vector<int64> batch_dim_indices(batch_dims.size());
   std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
   int64 n_index = batch_dims.size() + 1;
 
-  auto body_fn =
-      [&](xla::XlaOp j, absl::Span<const xla::XlaOp> values,
-          xla::XlaBuilder* builder) -> xla::StatusOr<std::vector<xla::XlaOp>> {
+  auto body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
+                     XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
    auto w = values[0];
    auto y = values[1];
    const auto vs = values[2];
    const auto taus = values[3];
 
    // Want j values in range [1, ... n).
-    j = j + xla::ConstantR0<int32>(builder, 1);
+    j = j + ConstantR0<int32>(builder, 1);
    // vs has shape [..., m, 1]
    auto v = DynamicSliceInMinorDims(vs, {j}, {1});
    // beta has shape [..., 1]
@@ -278,31 +275,31 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
    // wyv has shape [..., m, 1]
    auto wyv = BatchDot(w, yv, precision);
 
-    auto z = xla::Mul(
+    auto z = Mul(
        -beta, v + wyv,
        /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
 
    w = DynamicUpdateSliceInMinorDims(w, z, {j});
    y = DynamicUpdateSliceInMinorDims(y, v, {j});
 
-    return std::vector<xla::XlaOp>{w, y, vs, taus};
+    return std::vector<XlaOp>{w, y, vs, taus};
  };
 
-  xla::XlaBuilder* builder = vs.builder();
-  auto w = xla::Zeros(builder, xla::ShapeUtil::MakeShape(
-                                   type, ConcatVectors(batch_dims, {m, n})));
+  XlaBuilder* builder = vs.builder();
+  auto w = Zeros(builder,
+                 ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n})));
  auto y = w;
  auto v = SliceInMinorDims(vs, {0}, {1});
  auto beta = SliceInMinorDims(taus, {0}, {1});
  y = UpdateSliceInMinorDims(y, v, {0});
-  auto bv = xla::Mul(
-      -beta, v,
+  auto bv =
+      Mul(-beta, v,
          /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
  w = UpdateSliceInMinorDims(w, bv, {0});
 
  TF_ASSIGN_OR_RETURN(
-      auto values, xla::ForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus},
-                                     "wy", builder));
+      auto values,
+      ForEachIndex(n - 1, S32, body_fn, {w, y, vs, taus}, "wy", builder));
  return values[0];
 }
@@ -323,34 +320,34 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
 //   return (q, a)
 // TODO(phawkins): consider using UT transformations (in the form I - V U V')
 // rather than WY transformations.
-xla::StatusOr<QRDecompositionResult> QRDecomposition(
-    xla::XlaOp a, bool full_matrices, int64 block_size,
-    xla::PrecisionConfig::Precision precision) {
-  xla::XlaBuilder* builder = a.builder();
-  TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
-  const int num_dims = xla::ShapeUtil::Rank(a_shape);
+StatusOr<QRDecompositionResult> QRDecomposition(
+    XlaOp a, bool full_matrices, int64 block_size,
+    PrecisionConfig::Precision precision) {
+  XlaBuilder* builder = a.builder();
+  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
+  const int num_dims = ShapeUtil::Rank(a_shape);
   if (num_dims < 2) {
-    return errors::InvalidArgument("Arguments to QR must have rank >= 2: ",
-                                   num_dims);
+    return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s",
+                           a_shape.ToString());
  }
-  xla::PrimitiveType type = a_shape.element_type();
+  PrimitiveType type = a_shape.element_type();
 
-  const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2);
-  const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
+  const int64 n = ShapeUtil::GetDimension(a_shape, -1);
  const int64 p = std::min(m, n);
  if (block_size < 1) {
-    return errors::InvalidArgument(
-        "block_size argument to QR must be >= 1; got ", block_size);
+    return InvalidArgument("block_size argument to QR must be >= 1; got %d",
+                           block_size);
  }
 
  const int64 num_batch_dims = num_dims - 2;
  std::vector<int64> batch_dims(num_batch_dims);
  for (int i = 0; i < num_batch_dims; ++i) {
-    batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i);
+    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
  }
 
-  auto q = xla::Broadcast(xla::IdentityMatrix(builder, type, m, m), batch_dims);
+  auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
  for (int64 i = 0; i < p; i += block_size) {
    int64 k = std::min(block_size, p - i);
@@ -393,4 +390,4 @@ xla::StatusOr<QRDecompositionResult> QRDecomposition(
 
  return result;
 }
 
-}  // namespace tensorflow
+}  // namespace xla
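Aside: reading the Householder construction off the code above, House computes

    beta = -sign(alpha) * sqrt(alpha^2 + sigma)
    tau  = (beta - alpha) / beta
    v    = e_k + x[k+1:] / (alpha - beta)

so that applying H = I - tau v v.T to x yields [x[0], ..., x[k-1], beta, 0, ..., 0];
QRBlock applies one such reflection per column, and ComputeWYRepresentation
aggregates the reflections into the blocked W/Y form used to update Q.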

tensorflow/compiler/tf2xla/lib/qr.h → tensorflow/compiler/xla/client/lib/qr.h

@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_
-#define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_
 
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
-namespace tensorflow {
+namespace xla {
 
 // Computes the QR decompositions of a batch of matrices. That is,
 // given a (batched) matrix a, computes an orthonormal matrix Q and an
@@ -29,14 +29,14 @@ namespace tensorflow {
 // the block size to use.
 // TODO(phawkins): handle the complex case.
 struct QRDecompositionResult {
-  xla::XlaOp q;
-  xla::XlaOp r;
+  XlaOp q;
+  XlaOp r;
 };
 
-xla::StatusOr<QRDecompositionResult> QRDecomposition(
-    xla::XlaOp a, bool full_matrices, int64 block_size = 128,
-    xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST);
+StatusOr<QRDecompositionResult> QRDecomposition(
+    XlaOp a, bool full_matrices, int64 block_size = 128,
+    PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST);
 
-}  // namespace tensorflow
+}  // namespace xla
 
-#endif  // TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_
+#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_
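For reference, a minimal sketch of driving the relocated entry point from a
client (hypothetical code; the parameter shape and error handling are
illustrative only, following the kernel and tests in this change):

    #include "tensorflow/compiler/xla/client/lib/qr.h"

    xla::XlaBuilder builder("qr");
    xla::XlaOp a = xla::Parameter(
        &builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {4, 4}), "a");
    auto result = xla::QRDecomposition(a, /*full_matrices=*/true);
    if (result.ok()) {
      xla::XlaOp q = result.ValueOrDie().q;  // orthonormal factor
      xla::XlaOp r = result.ValueOrDie().r;  // upper-triangular factor
    }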

tensorflow/compiler/xla/client/lib/qr_test.cc (new file)

@@ -0,0 +1,93 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/qr.h"
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/client/lib/matrix.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace {
+
+using QrTest = xla::ClientLibraryTestBase;
+
+XLA_TEST_F(QrTest, Simple) {
+  xla::XlaBuilder builder(TestName());
+
+  xla::Array2D<float> a_vals({
+      {4, 6, 8, 10},
+      {6, 45, 54, 63},
+      {8, 54, 146, 166},
+      {10, 63, 166, 310},
+  });
+
+  xla::XlaOp a;
+  auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto result,
+      xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2));
+
+  // Verifies that the decomposition composes back to the original matrix.
+  //
+  // This isn't a terribly demanding test, (e.g., we should verify that Q is
+  // orthonormal and R is upper-triangular) but it's awkward to write such tests
+  // without more linear algebra libraries. It's easier to test the numerics
+  // from Python, anyway, where we have access to numpy and scipy.
+  xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST);
+  ComputeAndCompareR2<float>(&builder, a_vals, {a_data.get()},
+                             xla::ErrorSpec(1e-4, 1e-4));
+}
+
+XLA_TEST_F(QrTest, SimpleBatched) {
+  xla::XlaBuilder builder(TestName());
+
+  xla::Array3D<float> a_vals({
+      {
+          {4, 6, 8, 10},
+          {6, 45, 54, 63},
+          {8, 54, 146, 166},
+          {10, 63, 166, 310},
+      },
+      {
+          {16, 24, 8, 12},
+          {24, 61, 82, 48},
+          {8, 82, 456, 106},
+          {12, 48, 106, 62},
+      },
+  });
+
+  xla::XlaOp a;
+  auto a_data = CreateR3Parameter<float>(a_vals, 0, "a", &builder, &a);
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto result,
+      xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2));
+
+  xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST);
+  ComputeAndCompareR3<float>(&builder, a_vals, {a_data.get()},
+                             xla::ErrorSpec(1e-4, 1e-4));
+}
+
+}  // namespace
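As the comment in Simple concedes, these tests only check that Q * R
reproduces the input. A stricter hypothetical follow-up could reuse the ops
already included here to also check that Q is orthonormal (sketch only;
identity_vals would be an explicitly constructed identity Array2D, not shown):

    // Q^T Q should be approximately the identity matrix.
    xla::XlaOp qtq = xla::BatchDot(xla::TransposeInMinorDims(result.q),
                                   result.q, xla::PrecisionConfig::HIGHEST);
    ComputeAndCompareR2<float>(&builder, identity_vals, {a_data.get()},
                               xla::ErrorSpec(1e-4, 1e-4));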