[XLA] Move QR decomposition out of TF2XLA and into xla/client/lib.

Add a couple of simple C++ tests.

PiperOrigin-RevId: 225044584
This commit is contained in:
Peter Hawkins 2018-12-11 11:57:09 -08:00 committed by TensorFlower Gardener
parent a3ad14bbd2
commit 1390ba8f78
7 changed files with 250 additions and 140 deletions

View File

@ -117,7 +117,6 @@ tf_kernel_library(
"//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:common",
"//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla/lib:broadcast", "//tensorflow/compiler/tf2xla/lib:broadcast",
"//tensorflow/compiler/tf2xla/lib:qr",
"//tensorflow/compiler/tf2xla/lib:random", "//tensorflow/compiler/tf2xla/lib:random",
"//tensorflow/compiler/tf2xla/lib:scatter", "//tensorflow/compiler/tf2xla/lib:scatter",
"//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/tf2xla/lib:util",
@ -140,6 +139,7 @@ tf_kernel_library(
"//tensorflow/compiler/xla/client/lib:matrix", "//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:pooling", "//tensorflow/compiler/xla/client/lib:pooling",
"//tensorflow/compiler/xla/client/lib:prng", "//tensorflow/compiler/xla/client/lib:prng",
"//tensorflow/compiler/xla/client/lib:qr",
"//tensorflow/compiler/xla/client/lib:sorting", "//tensorflow/compiler/xla/client/lib:sorting",
"//tensorflow/compiler/xla/client/lib:triangular_solve", "//tensorflow/compiler/xla/client/lib:triangular_solve",
"//tensorflow/core:framework", "//tensorflow/core:framework",

View File

@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/tf2xla/lib/qr.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/qr.h"
namespace tensorflow { namespace tensorflow {
namespace { namespace {
@ -26,7 +26,7 @@ class QROp : public XlaOpKernel {
OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_));
} }
void Compile(XlaOpKernelContext* ctx) override { void Compile(XlaOpKernelContext* ctx) override {
auto result = QRDecomposition(ctx->Input(0), full_matrices_); auto result = xla::QRDecomposition(ctx->Input(0), full_matrices_);
if (!result.ok()) { if (!result.ok()) {
ctx->SetStatus(result.status()); ctx->SetStatus(result.status());
return; return;

View File

@ -46,28 +46,6 @@ cc_library(
], ],
) )
cc_library(
name = "qr",
srcs = ["qr.cc"],
hdrs = ["qr.h"],
deps = [
":util",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/client/lib:arithmetic",
"//tensorflow/compiler/xla/client/lib:constants",
"//tensorflow/compiler/xla/client/lib:loops",
"//tensorflow/compiler/xla/client/lib:math",
"//tensorflow/compiler/xla/client/lib:matrix",
"//tensorflow/compiler/xla/client/lib:slicing",
"//tensorflow/core:lib",
],
)
cc_library( cc_library(
name = "scatter", name = "scatter",
srcs = ["scatter.cc"], srcs = ["scatter.cc"],

View File

@ -234,6 +234,48 @@ cc_library(
], ],
) )
cc_library(
name = "qr",
srcs = ["qr.cc"],
hdrs = ["qr.h"],
deps = [
":arithmetic",
":constants",
":loops",
":math",
":matrix",
":slicing",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/core:lib",
],
)
xla_test(
name = "qr_test",
srcs = ["qr_test.cc"],
tags = ["optonly"],
deps = [
":matrix",
":qr",
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:xla_builder",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
)
cc_library( cc_library(
name = "slicing", name = "slicing",
srcs = ["slicing.cc"], srcs = ["slicing.cc"],

View File

@ -13,12 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/compiler/tf2xla/lib/qr.h" #include "tensorflow/compiler/xla/client/lib/qr.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "tensorflow/compiler/tf2xla/lib/util.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/loops.h" #include "tensorflow/compiler/xla/client/lib/loops.h"
@ -32,10 +31,18 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
namespace tensorflow { namespace xla {
namespace { namespace {
std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
absl::Span<const int64> ys) {
std::vector<int64> output(xs.size() + ys.size());
std::copy(xs.begin(), xs.end(), output.begin());
std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
return output;
}
// Computes a Householder reflection of the form: // Computes a Householder reflection of the form:
// H = I - tau v v.T. // H = I - tau v v.T.
// such that // such that
@ -65,52 +72,47 @@ namespace {
// return (v, tau, beta) // return (v, tau, beta)
// TODO(phawkins): LAPACK's xLARFG implementation has code for handling // TODO(phawkins): LAPACK's xLARFG implementation has code for handling
// overflows in the norm/beta calculations. Perhaps do the same here. // overflows in the norm/beta calculations. Perhaps do the same here.
xla::Status House(xla::XlaOp x, xla::XlaOp k, Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
absl::Span<const int64> batch_dims, const int64 m, const int64 m, XlaOp* v, XlaOp* tau, XlaOp* beta) {
xla::XlaOp* v, xla::XlaOp* tau, xla::XlaOp* beta) { XlaBuilder* const builder = x.builder();
xla::XlaBuilder* const builder = x.builder(); TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x)); const PrimitiveType type = x_shape.element_type();
const xla::PrimitiveType type = x_shape.element_type();
std::vector<int64> batch_dim_ids(batch_dims.size()); std::vector<int64> batch_dim_ids(batch_dims.size());
std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0); std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
const int64 minor_dim = batch_dims.size(); const int64 minor_dim = batch_dims.size();
xla::XlaOp zero = xla::ScalarLike(x, 0.0); XlaOp zero = ScalarLike(x, 0.0);
xla::XlaOp one = xla::ScalarLike(x, 1.0); XlaOp one = ScalarLike(x, 1.0);
// alpha = x[k] // alpha = x[k]
xla::XlaOp alpha = XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);
xla::Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);
// Compute x[k+1:] (padded with zeros in elements 0..k) // Compute x[k+1:] (padded with zeros in elements 0..k)
xla::XlaOp iota = xla::Iota(builder, xla::S32, m); XlaOp iota = Iota(builder, S32, m);
xla::XlaOp x_after_k = XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type),
xla::Mul(x, xla::ConvertElementType(xla::Gt(iota, k), type), /*broadcast_dimensions=*/{minor_dim});
/*broadcast_dimensions=*/{minor_dim});
// sigma = np.dot(x[k+1:], x[k+1:]) // sigma = np.dot(x[k+1:], x[k+1:])
auto sigma = auto sigma = Reduce(x_after_k * x_after_k, zero,
xla::Reduce(x_after_k * x_after_k, zero, CreateScalarAddComputation(type, builder), {minor_dim});
xla::CreateScalarAddComputation(type, builder), {minor_dim});
// mu = np.sqrt(x[k]*x[k] + sigma) // mu = np.sqrt(x[k]*x[k] + sigma)
auto mu = xla::Sqrt(xla::Square(alpha) + sigma); auto mu = Sqrt(Square(alpha) + sigma);
auto sigma_is_zero = xla::Eq(sigma, zero); auto sigma_is_zero = Eq(sigma, zero);
*beta = xla::Select(sigma_is_zero, alpha, -xla::Sign(alpha) * mu); *beta = Select(sigma_is_zero, alpha, -Sign(alpha) * mu);
*tau = xla::Select(sigma_is_zero, xla::Broadcast(zero, batch_dims), *tau = Select(sigma_is_zero, Broadcast(zero, batch_dims),
(*beta - alpha) / *beta); (*beta - alpha) / *beta);
auto divisor = xla::Select(sigma_is_zero, xla::Broadcast(one, batch_dims), auto divisor =
alpha - *beta); Select(sigma_is_zero, Broadcast(one, batch_dims), alpha - *beta);
auto e_k = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, k), type), auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type),
std::vector<int64>(batch_dims.size(), 1)); std::vector<int64>(batch_dims.size(), 1));
// Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
// If sigma is zero, x[k+1:] is zero, so use any non-zero divisor. // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
*v = e_k + *v = e_k + Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids);
xla::Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids);
return Status::OK(); return Status::OK();
} }
@ -143,90 +145,86 @@ xla::Status House(xla::XlaOp x, xla::XlaOp k,
// return (q, vs, taus) // return (q, vs, taus)
struct QRBlockResult { struct QRBlockResult {
// The factored R value // The factored R value
xla::XlaOp r; XlaOp r;
// Representation of the Householder matrices I - beta v v.T // Representation of the Householder matrices I - beta v v.T
xla::XlaOp taus; // Shape: [..., n] XlaOp taus; // Shape: [..., n]
xla::XlaOp vs; // Shape: [..., m, n] XlaOp vs; // Shape: [..., m, n]
}; };
xla::StatusOr<QRBlockResult> QRBlock( StatusOr<QRBlockResult> QRBlock(XlaOp a, PrecisionConfig::Precision precision) {
xla::XlaOp a, xla::PrecisionConfig::Precision precision) { XlaBuilder* builder = a.builder();
xla::XlaBuilder* builder = a.builder(); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); const int num_dims = ShapeUtil::Rank(a_shape);
const int num_dims = xla::ShapeUtil::Rank(a_shape);
if (num_dims < 2) { if (num_dims < 2) {
return errors::InvalidArgument("Arguments to QR must have rank >= 2: ", return InvalidArgument("Argument to QR must have rank >= 2; got shape %s",
num_dims); a_shape.ToString());
} }
xla::PrimitiveType type = a_shape.element_type(); PrimitiveType type = a_shape.element_type();
const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2); const int64 m = ShapeUtil::GetDimension(a_shape, -2);
const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); const int64 n = ShapeUtil::GetDimension(a_shape, -1);
const int64 num_batch_dims = num_dims - 2; const int64 num_batch_dims = num_dims - 2;
std::vector<int64> batch_dims(num_batch_dims); std::vector<int64> batch_dims(num_batch_dims);
for (int i = 0; i < num_batch_dims; ++i) { for (int i = 0; i < num_batch_dims; ++i) {
batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i); batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
} }
std::vector<int64> batch_dim_indices(num_batch_dims); std::vector<int64> batch_dim_indices(num_batch_dims);
std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
auto qr_body_fn = auto qr_body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
[&](xla::XlaOp j, absl::Span<const xla::XlaOp> values, XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
xla::XlaBuilder* builder) -> xla::StatusOr<std::vector<xla::XlaOp>> {
auto a = values[0]; auto a = values[0];
auto vs = values[1]; auto vs = values[1];
auto taus = values[2]; auto taus = values[2];
// v, beta = house(a[:, j], j) // v, beta = house(a[:, j], j)
auto x = DynamicSliceInMinorDims(a, {j}, {1}); auto x = DynamicSliceInMinorDims(a, {j}, {1});
xla::XlaOp v, tau, beta; XlaOp v, tau, beta;
TF_RETURN_IF_ERROR(House(xla::Collapse(x, {num_dims - 2, num_dims - 1}), j, TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j,
batch_dims, m, &v, &tau, &beta)); batch_dims, m, &v, &tau, &beta));
std::vector<int64> shape = batch_dims; std::vector<int64> shape = batch_dims;
shape.push_back(1); shape.push_back(1);
shape.push_back(m); shape.push_back(m);
auto v_broadcast = xla::Reshape(v, shape); auto v_broadcast = Reshape(v, shape);
// a[:, :] -= tau * np.dot(v[:, np.newaxis], // a[:, :] -= tau * np.dot(v[:, np.newaxis],
// np.dot(v[np.newaxis, :], a[:, :])) // np.dot(v[np.newaxis, :], a[:, :]))
auto vva = BatchDot(v_broadcast, a, precision); auto vva = BatchDot(v_broadcast, a, precision);
vva = BatchDot(TransposeInMinorDims(v_broadcast), vva, precision); vva = BatchDot(TransposeInMinorDims(v_broadcast), vva, precision);
a = a - xla::Mul(tau, vva, a = a - Mul(tau, vva,
/*broadcast_dimensions=*/batch_dim_indices); /*broadcast_dimensions=*/batch_dim_indices);
// It is more precise to populate column 'k' explicitly, rather than // It is more precise to populate column 'k' explicitly, rather than
// computing it implicitly by applying the Householder transformation. // computing it implicitly by applying the Householder transformation.
// a[k,k] = beta // a[k,k] = beta
// a[k+1:,k] = np.zeros([m-k-1], dtype=a.dtype) // a[k+1:,k] = np.zeros([m-k-1], dtype=a.dtype)
auto iota = xla::Reshape(xla::Iota(a.builder(), xla::S32, m), {m, 1}); auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1});
auto predecessor_mask = xla::ConvertElementType(xla::Lt(iota, j), type); auto predecessor_mask = ConvertElementType(Lt(iota, j), type);
auto mask = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, j), type), auto mask = Broadcast(ConvertElementType(Eq(iota, j), type),
std::vector<int64>(batch_dims.size(), 1)); std::vector<int64>(batch_dims.size(), 1));
auto new_x = auto new_x = Mul(x, predecessor_mask,
xla::Mul(x, predecessor_mask, /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
/*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) + Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices);
xla::Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices);
a = DynamicUpdateSliceInMinorDims(a, new_x, {j}); a = DynamicUpdateSliceInMinorDims(a, new_x, {j});
// vs[:, j] = v // vs[:, j] = v
vs = DynamicUpdateSliceInMinorDims( vs = DynamicUpdateSliceInMinorDims(
vs, xla::Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j}); vs, Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j});
// taus[j] = tau // taus[j] = tau
taus = DynamicUpdateSliceInMinorDims( taus = DynamicUpdateSliceInMinorDims(
taus, xla::Reshape(tau, ConcatVectors(batch_dims, {1})), {j}); taus, Reshape(tau, ConcatVectors(batch_dims, {1})), {j});
return std::vector<xla::XlaOp>{a, vs, taus}; return std::vector<XlaOp>{a, vs, taus};
}; };
auto vs = xla::Zeros(builder, xla::ShapeUtil::MakeShape( auto vs = Zeros(
type, ConcatVectors(batch_dims, {m, n}))); builder, ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n})));
auto taus = xla::Zeros( auto taus = Zeros(builder,
builder, xla::ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n}))); ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n})));
TF_ASSIGN_OR_RETURN(auto values, TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn,
xla::ForEachIndex(std::min(m, n), xla::S32, qr_body_fn, {a, vs, taus}, "qr", builder));
{a, vs, taus}, "qr", builder));
QRBlockResult result; QRBlockResult result;
result.r = values[0]; result.r = values[0];
@ -250,24 +248,23 @@ xla::StatusOr<QRBlockResult> QRBlock(
// return W // return W
// There is no need to return Y since at termination of the loop it is equal to // There is no need to return Y since at termination of the loop it is equal to
// vs. // vs.
xla::StatusOr<xla::XlaOp> ComputeWYRepresentation( StatusOr<XlaOp> ComputeWYRepresentation(PrimitiveType type,
xla::PrimitiveType type, absl::Span<const int64> batch_dims, xla::XlaOp vs, absl::Span<const int64> batch_dims,
xla::XlaOp taus, int64 m, int64 n, XlaOp vs, XlaOp taus, int64 m, int64 n,
xla::PrecisionConfig::Precision precision) { PrecisionConfig::Precision precision) {
std::vector<int64> batch_dim_indices(batch_dims.size()); std::vector<int64> batch_dim_indices(batch_dims.size());
std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0); std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
int64 n_index = batch_dims.size() + 1; int64 n_index = batch_dims.size() + 1;
auto body_fn = auto body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
[&](xla::XlaOp j, absl::Span<const xla::XlaOp> values, XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
xla::XlaBuilder* builder) -> xla::StatusOr<std::vector<xla::XlaOp>> {
auto w = values[0]; auto w = values[0];
auto y = values[1]; auto y = values[1];
const auto vs = values[2]; const auto vs = values[2];
const auto taus = values[3]; const auto taus = values[3];
// Want j values in range [1, ... n). // Want j values in range [1, ... n).
j = j + xla::ConstantR0<int32>(builder, 1); j = j + ConstantR0<int32>(builder, 1);
// vs has shape [..., m, 1] // vs has shape [..., m, 1]
auto v = DynamicSliceInMinorDims(vs, {j}, {1}); auto v = DynamicSliceInMinorDims(vs, {j}, {1});
// beta has shape [..., 1] // beta has shape [..., 1]
@ -278,31 +275,31 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
// wyv has shape [..., m, 1] // wyv has shape [..., m, 1]
auto wyv = BatchDot(w, yv, precision); auto wyv = BatchDot(w, yv, precision);
auto z = xla::Mul( auto z = Mul(
-beta, v + wyv, -beta, v + wyv,
/*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
w = DynamicUpdateSliceInMinorDims(w, z, {j}); w = DynamicUpdateSliceInMinorDims(w, z, {j});
y = DynamicUpdateSliceInMinorDims(y, v, {j}); y = DynamicUpdateSliceInMinorDims(y, v, {j});
return std::vector<xla::XlaOp>{w, y, vs, taus}; return std::vector<XlaOp>{w, y, vs, taus};
}; };
xla::XlaBuilder* builder = vs.builder(); XlaBuilder* builder = vs.builder();
auto w = xla::Zeros(builder, xla::ShapeUtil::MakeShape( auto w = Zeros(builder,
type, ConcatVectors(batch_dims, {m, n}))); ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n})));
auto y = w; auto y = w;
auto v = SliceInMinorDims(vs, {0}, {1}); auto v = SliceInMinorDims(vs, {0}, {1});
auto beta = SliceInMinorDims(taus, {0}, {1}); auto beta = SliceInMinorDims(taus, {0}, {1});
y = UpdateSliceInMinorDims(y, v, {0}); y = UpdateSliceInMinorDims(y, v, {0});
auto bv = xla::Mul( auto bv =
-beta, v, Mul(-beta, v,
/*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index})); /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
w = UpdateSliceInMinorDims(w, bv, {0}); w = UpdateSliceInMinorDims(w, bv, {0});
TF_ASSIGN_OR_RETURN( TF_ASSIGN_OR_RETURN(
auto values, xla::ForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus}, auto values,
"wy", builder)); ForEachIndex(n - 1, S32, body_fn, {w, y, vs, taus}, "wy", builder));
return values[0]; return values[0];
} }
@ -323,34 +320,34 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
// return (q, a) // return (q, a)
// TODO(phawkins): consider using UT transformations (in the form I - V U V') // TODO(phawkins): consider using UT transformations (in the form I - V U V')
// rather than WY transformations. // rather than WY transformations.
xla::StatusOr<QRDecompositionResult> QRDecomposition( StatusOr<QRDecompositionResult> QRDecomposition(
xla::XlaOp a, bool full_matrices, int64 block_size, XlaOp a, bool full_matrices, int64 block_size,
xla::PrecisionConfig::Precision precision) { PrecisionConfig::Precision precision) {
xla::XlaBuilder* builder = a.builder(); XlaBuilder* builder = a.builder();
TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a)); TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
const int num_dims = xla::ShapeUtil::Rank(a_shape); const int num_dims = ShapeUtil::Rank(a_shape);
if (num_dims < 2) { if (num_dims < 2) {
return errors::InvalidArgument("Arguments to QR must have rank >= 2: ", return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s",
num_dims); a_shape.ToString());
} }
xla::PrimitiveType type = a_shape.element_type(); PrimitiveType type = a_shape.element_type();
const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2); const int64 m = ShapeUtil::GetDimension(a_shape, -2);
const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1); const int64 n = ShapeUtil::GetDimension(a_shape, -1);
const int64 p = std::min(m, n); const int64 p = std::min(m, n);
if (block_size < 1) { if (block_size < 1) {
return errors::InvalidArgument( return InvalidArgument("block_size argument to QR must be >= 1; got %d",
"block_size argument to QR must be >= 1; got ", block_size); block_size);
} }
const int64 num_batch_dims = num_dims - 2; const int64 num_batch_dims = num_dims - 2;
std::vector<int64> batch_dims(num_batch_dims); std::vector<int64> batch_dims(num_batch_dims);
for (int i = 0; i < num_batch_dims; ++i) { for (int i = 0; i < num_batch_dims; ++i) {
batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i); batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
} }
auto q = xla::Broadcast(xla::IdentityMatrix(builder, type, m, m), batch_dims); auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
for (int64 i = 0; i < p; i += block_size) { for (int64 i = 0; i < p; i += block_size) {
int64 k = std::min(block_size, p - i); int64 k = std::min(block_size, p - i);
@ -393,4 +390,4 @@ xla::StatusOr<QRDecompositionResult> QRDecomposition(
return result; return result;
} }
} // namespace tensorflow } // namespace xla

View File

@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_
#define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_
#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xla/xla_data.pb.h"
namespace tensorflow { namespace xla {
// Computes the QR decompositions of a batch of matrices. That is, // Computes the QR decompositions of a batch of matrices. That is,
// given a (batched) matrix a, computes an orthonormal matrix Q and an // given a (batched) matrix a, computes an orthonormal matrix Q and an
@ -29,14 +29,14 @@ namespace tensorflow {
// the block size to use. // the block size to use.
// TODO(phawkins): handle the complex case. // TODO(phawkins): handle the complex case.
struct QRDecompositionResult { struct QRDecompositionResult {
xla::XlaOp q; XlaOp q;
xla::XlaOp r; XlaOp r;
}; };
xla::StatusOr<QRDecompositionResult> QRDecomposition( StatusOr<QRDecompositionResult> QRDecomposition(
xla::XlaOp a, bool full_matrices, int64 block_size = 128, XlaOp a, bool full_matrices, int64 block_size = 128,
xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST); PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST);
} // namespace tensorflow } // namespace xla
#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_ #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_

View File

@ -0,0 +1,93 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/lib/qr.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace {
using QrTest = xla::ClientLibraryTestBase;
XLA_TEST_F(QrTest, Simple) {
xla::XlaBuilder builder(TestName());
xla::Array2D<float> a_vals({
{4, 6, 8, 10},
{6, 45, 54, 63},
{8, 54, 146, 166},
{10, 63, 166, 310},
});
xla::XlaOp a;
auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
TF_ASSERT_OK_AND_ASSIGN(
auto result,
xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2));
// Verifies that the decomposition composes back to the original matrix.
//
// This isn't a terribly demanding test, (e.g., we should verify that Q is
// orthonormal and R is upper-triangular) but it's awkward to write such tests
// without more linear algebra libraries. It's easier to test the numerics
// from Python, anyway, where we have access to numpy and scipy.
xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST);
ComputeAndCompareR2<float>(&builder, a_vals, {a_data.get()},
xla::ErrorSpec(1e-4, 1e-4));
}
XLA_TEST_F(QrTest, SimpleBatched) {
xla::XlaBuilder builder(TestName());
xla::Array3D<float> a_vals({
{
{4, 6, 8, 10},
{6, 45, 54, 63},
{8, 54, 146, 166},
{10, 63, 166, 310},
},
{
{16, 24, 8, 12},
{24, 61, 82, 48},
{8, 82, 456, 106},
{12, 48, 106, 62},
},
});
xla::XlaOp a;
auto a_data = CreateR3Parameter<float>(a_vals, 0, "a", &builder, &a);
TF_ASSERT_OK_AND_ASSIGN(
auto result,
xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2));
xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST);
ComputeAndCompareR3<float>(&builder, a_vals, {a_data.get()},
xla::ErrorSpec(1e-4, 1e-4));
}
} // namespace