[XLA] Move QR decomposition out of TF2XLA and into xla/client/lib.

Add a couple of simple C++ tests.

PiperOrigin-RevId: 225044584
Parent: a3ad14bbd2
Commit: 1390ba8f78
@@ -117,7 +117,6 @@ tf_kernel_library(
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
         "//tensorflow/compiler/tf2xla/lib:broadcast",
-        "//tensorflow/compiler/tf2xla/lib:qr",
         "//tensorflow/compiler/tf2xla/lib:random",
         "//tensorflow/compiler/tf2xla/lib:scatter",
         "//tensorflow/compiler/tf2xla/lib:util",
@@ -140,6 +139,7 @@ tf_kernel_library(
         "//tensorflow/compiler/xla/client/lib:matrix",
         "//tensorflow/compiler/xla/client/lib:pooling",
         "//tensorflow/compiler/xla/client/lib:prng",
+        "//tensorflow/compiler/xla/client/lib:qr",
         "//tensorflow/compiler/xla/client/lib:sorting",
         "//tensorflow/compiler/xla/client/lib:triangular_solve",
         "//tensorflow/core:framework",
@@ -13,9 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/tf2xla/lib/qr.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/client/lib/qr.h"
 
 namespace tensorflow {
 namespace {
@@ -26,7 +26,7 @@ class QROp : public XlaOpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_));
   }
   void Compile(XlaOpKernelContext* ctx) override {
-    auto result = QRDecomposition(ctx->Input(0), full_matrices_);
+    auto result = xla::QRDecomposition(ctx->Input(0), full_matrices_);
     if (!result.ok()) {
       ctx->SetStatus(result.status());
       return;
@@ -46,28 +46,6 @@ cc_library(
     ],
 )
 
-cc_library(
-    name = "qr",
-    srcs = ["qr.cc"],
-    hdrs = ["qr.h"],
-    deps = [
-        ":util",
-        "//tensorflow/compiler/xla:literal_util",
-        "//tensorflow/compiler/xla:shape_util",
-        "//tensorflow/compiler/xla:status_macros",
-        "//tensorflow/compiler/xla:statusor",
-        "//tensorflow/compiler/xla:xla_data_proto",
-        "//tensorflow/compiler/xla/client:xla_builder",
-        "//tensorflow/compiler/xla/client/lib:arithmetic",
-        "//tensorflow/compiler/xla/client/lib:constants",
-        "//tensorflow/compiler/xla/client/lib:loops",
-        "//tensorflow/compiler/xla/client/lib:math",
-        "//tensorflow/compiler/xla/client/lib:matrix",
-        "//tensorflow/compiler/xla/client/lib:slicing",
-        "//tensorflow/core:lib",
-    ],
-)
-
 cc_library(
     name = "scatter",
     srcs = ["scatter.cc"],
@@ -234,6 +234,48 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "qr",
+    srcs = ["qr.cc"],
+    hdrs = ["qr.h"],
+    deps = [
+        ":arithmetic",
+        ":constants",
+        ":loops",
+        ":math",
+        ":matrix",
+        ":slicing",
+        "//tensorflow/compiler/xla:literal_util",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/core:lib",
+    ],
+)
+
+xla_test(
+    name = "qr_test",
+    srcs = ["qr_test.cc"],
+    tags = ["optonly"],
+    deps = [
+        ":matrix",
+        ":qr",
+        "//tensorflow/compiler/xla:array2d",
+        "//tensorflow/compiler/xla:array3d",
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:test",
+        "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/tests:client_library_test_base",
+        "//tensorflow/compiler/xla/tests:literal_test_util",
+        "//tensorflow/compiler/xla/tests:xla_internal_test_main",
+        "//tensorflow/core:test",
+    ],
+)
+
 cc_library(
     name = "slicing",
     srcs = ["slicing.cc"],
@@ -13,12 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/tf2xla/lib/qr.h"
+#include "tensorflow/compiler/xla/client/lib/qr.h"
 
 #include <memory>
 #include <vector>
 
-#include "tensorflow/compiler/tf2xla/lib/util.h"
 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/lib/loops.h"
@@ -32,10 +31,18 @@ limitations under the License.
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/core/lib/core/errors.h"
 
-namespace tensorflow {
+namespace xla {
 
 namespace {
 
+std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
+                                 absl::Span<const int64> ys) {
+  std::vector<int64> output(xs.size() + ys.size());
+  std::copy(xs.begin(), xs.end(), output.begin());
+  std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
+  return output;
+}
+
 // Computes a Householder reflection of the form:
 // H = I - tau v v.T.
 // such that
@@ -65,52 +72,47 @@ namespace {
 // return (v, tau, beta)
 // TODO(phawkins): LAPACK's xLARFG implementation has code for handling
 // overflows in the norm/beta calculations. Perhaps do the same here.
-xla::Status House(xla::XlaOp x, xla::XlaOp k,
-                  absl::Span<const int64> batch_dims, const int64 m,
-                  xla::XlaOp* v, xla::XlaOp* tau, xla::XlaOp* beta) {
-  xla::XlaBuilder* const builder = x.builder();
-  TF_ASSIGN_OR_RETURN(xla::Shape x_shape, builder->GetShape(x));
-  const xla::PrimitiveType type = x_shape.element_type();
+Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
+             const int64 m, XlaOp* v, XlaOp* tau, XlaOp* beta) {
+  XlaBuilder* const builder = x.builder();
+  TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
+  const PrimitiveType type = x_shape.element_type();
 
   std::vector<int64> batch_dim_ids(batch_dims.size());
   std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
   const int64 minor_dim = batch_dims.size();
 
-  xla::XlaOp zero = xla::ScalarLike(x, 0.0);
-  xla::XlaOp one = xla::ScalarLike(x, 1.0);
+  XlaOp zero = ScalarLike(x, 0.0);
+  XlaOp one = ScalarLike(x, 1.0);
 
   // alpha = x[k]
-  xla::XlaOp alpha =
-      xla::Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);
+  XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);
 
   // Compute x[k+1:] (padded with zeros in elements 0..k)
-  xla::XlaOp iota = xla::Iota(builder, xla::S32, m);
-  xla::XlaOp x_after_k =
-      xla::Mul(x, xla::ConvertElementType(xla::Gt(iota, k), type),
-               /*broadcast_dimensions=*/{minor_dim});
+  XlaOp iota = Iota(builder, S32, m);
+  XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type),
+                        /*broadcast_dimensions=*/{minor_dim});
 
   // sigma = np.dot(x[k+1:], x[k+1:])
-  auto sigma =
-      xla::Reduce(x_after_k * x_after_k, zero,
-                  xla::CreateScalarAddComputation(type, builder), {minor_dim});
+  auto sigma = Reduce(x_after_k * x_after_k, zero,
+                      CreateScalarAddComputation(type, builder), {minor_dim});
   // mu = np.sqrt(x[k]*x[k] + sigma)
-  auto mu = xla::Sqrt(xla::Square(alpha) + sigma);
+  auto mu = Sqrt(Square(alpha) + sigma);
 
-  auto sigma_is_zero = xla::Eq(sigma, zero);
+  auto sigma_is_zero = Eq(sigma, zero);
 
-  *beta = xla::Select(sigma_is_zero, alpha, -xla::Sign(alpha) * mu);
-  *tau = xla::Select(sigma_is_zero, xla::Broadcast(zero, batch_dims),
-                     (*beta - alpha) / *beta);
-  auto divisor = xla::Select(sigma_is_zero, xla::Broadcast(one, batch_dims),
-                             alpha - *beta);
+  *beta = Select(sigma_is_zero, alpha, -Sign(alpha) * mu);
+  *tau = Select(sigma_is_zero, Broadcast(zero, batch_dims),
+                (*beta - alpha) / *beta);
+  auto divisor =
+      Select(sigma_is_zero, Broadcast(one, batch_dims), alpha - *beta);
 
-  auto e_k = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, k), type),
-                            std::vector<int64>(batch_dims.size(), 1));
+  auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type),
+                       std::vector<int64>(batch_dims.size(), 1));
 
   // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
   // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
-  *v = e_k +
-       xla::Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids);
+  *v = e_k + Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids);
   return Status::OK();
 }
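Note: the House() rewrite above only drops the xla:: qualifiers now that the file lives inside namespace xla; the algorithm itself is unchanged. For readers who want the numerics without the XLA op plumbing, here is a plain scalar C++ sketch of the same Householder computation (illustrative only; HouseScalar and its main() are not part of this commit):

#include <cmath>
#include <cstdio>
#include <vector>

// Given x and an index k, compute v, tau, beta so that the reflection
// H = I - tau v v^T maps x to [x[0..k-1], beta, 0, ..., 0].
void HouseScalar(const std::vector<double>& x, int k, std::vector<double>* v,
                 double* tau, double* beta) {
  const int m = static_cast<int>(x.size());
  const double alpha = x[k];
  // sigma = np.dot(x[k+1:], x[k+1:])
  double sigma = 0.0;
  for (int i = k + 1; i < m; ++i) sigma += x[i] * x[i];
  const double mu = std::sqrt(alpha * alpha + sigma);
  const double sign = (alpha > 0) - (alpha < 0);  // mirrors xla::Sign
  *beta = (sigma == 0.0) ? alpha : -sign * mu;
  *tau = (sigma == 0.0) ? 0.0 : (*beta - alpha) / *beta;
  // v = e_k ++ x[k+1:] / (alpha - beta); the divisor is irrelevant when
  // sigma == 0, since x[k+1:] is then all zeros.
  const double divisor = (sigma == 0.0) ? 1.0 : alpha - *beta;
  v->assign(m, 0.0);
  (*v)[k] = 1.0;
  for (int i = k + 1; i < m; ++i) (*v)[i] = x[i] / divisor;
}

int main() {
  std::vector<double> v;
  double tau, beta;
  HouseScalar({4, 6, 8, 10}, 0, &v, &tau, &beta);
  std::printf("beta=%g tau=%g v[1]=%g\n", beta, tau, v[1]);
  return 0;
}

Compiled standalone, this prints beta ~= -14.6969 and tau ~= 1.2722 for the first column of the test matrix used in qr_test.cc below.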
@@ -143,90 +145,86 @@ xla::Status House(xla::XlaOp x, xla::XlaOp k,
 // return (q, vs, taus)
 struct QRBlockResult {
   // The factored R value
-  xla::XlaOp r;
+  XlaOp r;
 
   // Representation of the Householder matrices I - beta v v.T
-  xla::XlaOp taus;  // Shape: [..., n]
-  xla::XlaOp vs;    // Shape: [..., m, n]
+  XlaOp taus;  // Shape: [..., n]
+  XlaOp vs;    // Shape: [..., m, n]
 };
-xla::StatusOr<QRBlockResult> QRBlock(
-    xla::XlaOp a, xla::PrecisionConfig::Precision precision) {
-  xla::XlaBuilder* builder = a.builder();
-  TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
-  const int num_dims = xla::ShapeUtil::Rank(a_shape);
+StatusOr<QRBlockResult> QRBlock(XlaOp a, PrecisionConfig::Precision precision) {
+  XlaBuilder* builder = a.builder();
+  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
+  const int num_dims = ShapeUtil::Rank(a_shape);
   if (num_dims < 2) {
-    return errors::InvalidArgument("Arguments to QR must have rank >= 2: ",
-                                   num_dims);
+    return InvalidArgument("Argument to QR must have rank >= 2; got shape %s",
+                           a_shape.ToString());
   }
-  xla::PrimitiveType type = a_shape.element_type();
+  PrimitiveType type = a_shape.element_type();
 
-  const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2);
-  const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
+  const int64 n = ShapeUtil::GetDimension(a_shape, -1);
 
   const int64 num_batch_dims = num_dims - 2;
   std::vector<int64> batch_dims(num_batch_dims);
   for (int i = 0; i < num_batch_dims; ++i) {
-    batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i);
+    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
   }
 
   std::vector<int64> batch_dim_indices(num_batch_dims);
   std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
 
-  auto qr_body_fn =
-      [&](xla::XlaOp j, absl::Span<const xla::XlaOp> values,
-          xla::XlaBuilder* builder) -> xla::StatusOr<std::vector<xla::XlaOp>> {
+  auto qr_body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
+                        XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
     auto a = values[0];
    auto vs = values[1];
    auto taus = values[2];
 
     // v, beta = house(a[:, j], j)
     auto x = DynamicSliceInMinorDims(a, {j}, {1});
-    xla::XlaOp v, tau, beta;
-    TF_RETURN_IF_ERROR(House(xla::Collapse(x, {num_dims - 2, num_dims - 1}), j,
+    XlaOp v, tau, beta;
+    TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j,
                              batch_dims, m, &v, &tau, &beta));
 
     std::vector<int64> shape = batch_dims;
     shape.push_back(1);
     shape.push_back(m);
-    auto v_broadcast = xla::Reshape(v, shape);
+    auto v_broadcast = Reshape(v, shape);
     // a[:, :] -= tau * np.dot(v[:, np.newaxis],
     //                         np.dot(v[np.newaxis, :], a[:, :]))
     auto vva = BatchDot(v_broadcast, a, precision);
     vva = BatchDot(TransposeInMinorDims(v_broadcast), vva, precision);
-    a = a - xla::Mul(tau, vva,
-                     /*broadcast_dimensions=*/batch_dim_indices);
+    a = a - Mul(tau, vva,
+                /*broadcast_dimensions=*/batch_dim_indices);
 
     // It is more precise to populate column 'k' explicitly, rather than
     // computing it implicitly by applying the Householder transformation.
     // a[k,k] = beta
     // a[k+1:,k] = np.zeros([m-k-1], dtype=a.dtype)
-    auto iota = xla::Reshape(xla::Iota(a.builder(), xla::S32, m), {m, 1});
-    auto predecessor_mask = xla::ConvertElementType(xla::Lt(iota, j), type);
-    auto mask = xla::Broadcast(xla::ConvertElementType(xla::Eq(iota, j), type),
-                               std::vector<int64>(batch_dims.size(), 1));
-    auto new_x =
-        xla::Mul(x, predecessor_mask,
-                 /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
-        xla::Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices);
+    auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1});
+    auto predecessor_mask = ConvertElementType(Lt(iota, j), type);
+    auto mask = Broadcast(ConvertElementType(Eq(iota, j), type),
+                          std::vector<int64>(batch_dims.size(), 1));
+    auto new_x = Mul(x, predecessor_mask,
+                     /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
+                 Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices);
     a = DynamicUpdateSliceInMinorDims(a, new_x, {j});
 
     // vs[:, j] = v
     vs = DynamicUpdateSliceInMinorDims(
-        vs, xla::Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j});
+        vs, Reshape(v, ConcatVectors(batch_dims, {m, 1})), {j});
     // taus[j] = tau
     taus = DynamicUpdateSliceInMinorDims(
-        taus, xla::Reshape(tau, ConcatVectors(batch_dims, {1})), {j});
-    return std::vector<xla::XlaOp>{a, vs, taus};
+        taus, Reshape(tau, ConcatVectors(batch_dims, {1})), {j});
+    return std::vector<XlaOp>{a, vs, taus};
   };
 
-  auto vs = xla::Zeros(builder, xla::ShapeUtil::MakeShape(
-                                    type, ConcatVectors(batch_dims, {m, n})));
-  auto taus = xla::Zeros(
-      builder, xla::ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n})));
+  auto vs = Zeros(
+      builder, ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n})));
+  auto taus = Zeros(builder,
+                    ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n})));
 
-  TF_ASSIGN_OR_RETURN(auto values,
-                      xla::ForEachIndex(std::min(m, n), xla::S32, qr_body_fn,
-                                        {a, vs, taus}, "qr", builder));
+  TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn,
+                                                {a, vs, taus}, "qr", builder));
 
   QRBlockResult result;
   result.r = values[0];
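Note: as the np.dot comments above spell out, each iteration of qr_body_fn applies the reflection as two matrix-vector products rather than materializing H. In math (a sketch matching those comments, not code from the commit):

$$a \leftarrow a - \tau\, v\,(v^{T} a), \qquad a_{k,k} \leftarrow \beta, \qquad a_{k+1:,\,k} \leftarrow 0,$$

where the two assignments on the right are the explicit column-k masking that the code notes is more precise than applying the transformation implicitly.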
@@ -250,24 +248,23 @@ xla::StatusOr<QRBlockResult> QRBlock(
 // return W
 // There is no need to return Y since at termination of the loop it is equal to
 // vs.
-xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
-    xla::PrimitiveType type, absl::Span<const int64> batch_dims, xla::XlaOp vs,
-    xla::XlaOp taus, int64 m, int64 n,
-    xla::PrecisionConfig::Precision precision) {
+StatusOr<XlaOp> ComputeWYRepresentation(PrimitiveType type,
+                                        absl::Span<const int64> batch_dims,
+                                        XlaOp vs, XlaOp taus, int64 m, int64 n,
+                                        PrecisionConfig::Precision precision) {
   std::vector<int64> batch_dim_indices(batch_dims.size());
   std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
   int64 n_index = batch_dims.size() + 1;
 
-  auto body_fn =
-      [&](xla::XlaOp j, absl::Span<const xla::XlaOp> values,
-          xla::XlaBuilder* builder) -> xla::StatusOr<std::vector<xla::XlaOp>> {
+  auto body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
+                     XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
     auto w = values[0];
     auto y = values[1];
     const auto vs = values[2];
     const auto taus = values[3];
 
     // Want j values in range [1, ... n).
-    j = j + xla::ConstantR0<int32>(builder, 1);
+    j = j + ConstantR0<int32>(builder, 1);
     // vs has shape [..., m, 1]
     auto v = DynamicSliceInMinorDims(vs, {j}, {1});
     // beta has shape [..., 1]
@@ -278,31 +275,31 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
     // wyv has shape [..., m, 1]
     auto wyv = BatchDot(w, yv, precision);
 
-    auto z = xla::Mul(
+    auto z = Mul(
         -beta, v + wyv,
         /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
 
     w = DynamicUpdateSliceInMinorDims(w, z, {j});
     y = DynamicUpdateSliceInMinorDims(y, v, {j});
 
-    return std::vector<xla::XlaOp>{w, y, vs, taus};
+    return std::vector<XlaOp>{w, y, vs, taus};
   };
 
-  xla::XlaBuilder* builder = vs.builder();
-  auto w = xla::Zeros(builder, xla::ShapeUtil::MakeShape(
-                                   type, ConcatVectors(batch_dims, {m, n})));
+  XlaBuilder* builder = vs.builder();
+  auto w = Zeros(builder,
+                 ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n})));
   auto y = w;
   auto v = SliceInMinorDims(vs, {0}, {1});
   auto beta = SliceInMinorDims(taus, {0}, {1});
   y = UpdateSliceInMinorDims(y, v, {0});
-  auto bv = xla::Mul(
-      -beta, v,
-      /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
+  auto bv =
+      Mul(-beta, v,
+          /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
   w = UpdateSliceInMinorDims(w, bv, {0});
 
   TF_ASSIGN_OR_RETURN(
-      auto values, xla::ForEachIndex(n - 1, xla::S32, body_fn, {w, y, vs, taus},
-                                     "wy", builder));
+      auto values,
+      ForEachIndex(n - 1, S32, body_fn, {w, y, vs, taus}, "wy", builder));
   return values[0];
 }
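Note: reading the initialization (w[:, 0] = -taus[0] * v_0, y[:, 0] = v_0; the code's `beta` variable here actually holds taus[0]) together with body_fn, the loop maintains the compact WY form of the accumulated reflections. A sketch of the invariant, assuming H_j = I - tau_j v_j v_j^T:

$$H_1 H_2 \cdots H_j = I + W Y^{T}, \qquad Y = [v_1, \ldots, v_j], \qquad w_j = -\tau_j \left( v_j + W Y^{T} v_j \right),$$

which is exactly the z computed above. At termination Y equals vs, which is why, as the comment says, only W needs to be returned.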
@@ -323,34 +320,34 @@ xla::StatusOr<xla::XlaOp> ComputeWYRepresentation(
 // return (q, a)
 // TODO(phawkins): consider using UT transformations (in the form I - V U V')
 // rather than WY transformations.
-xla::StatusOr<QRDecompositionResult> QRDecomposition(
-    xla::XlaOp a, bool full_matrices, int64 block_size,
-    xla::PrecisionConfig::Precision precision) {
-  xla::XlaBuilder* builder = a.builder();
-  TF_ASSIGN_OR_RETURN(xla::Shape a_shape, builder->GetShape(a));
-  const int num_dims = xla::ShapeUtil::Rank(a_shape);
+StatusOr<QRDecompositionResult> QRDecomposition(
+    XlaOp a, bool full_matrices, int64 block_size,
+    PrecisionConfig::Precision precision) {
+  XlaBuilder* builder = a.builder();
+  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
+  const int num_dims = ShapeUtil::Rank(a_shape);
   if (num_dims < 2) {
-    return errors::InvalidArgument("Arguments to QR must have rank >= 2: ",
-                                   num_dims);
+    return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s",
+                           a_shape.ToString());
   }
-  xla::PrimitiveType type = a_shape.element_type();
+  PrimitiveType type = a_shape.element_type();
 
-  const int64 m = xla::ShapeUtil::GetDimension(a_shape, -2);
-  const int64 n = xla::ShapeUtil::GetDimension(a_shape, -1);
+  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
+  const int64 n = ShapeUtil::GetDimension(a_shape, -1);
   const int64 p = std::min(m, n);
 
   if (block_size < 1) {
-    return errors::InvalidArgument(
-        "block_size argument to QR must be >= 1; got ", block_size);
+    return InvalidArgument("block_size argument to QR must be >= 1; got %d",
+                           block_size);
   }
 
   const int64 num_batch_dims = num_dims - 2;
   std::vector<int64> batch_dims(num_batch_dims);
   for (int i = 0; i < num_batch_dims; ++i) {
-    batch_dims[i] = xla::ShapeUtil::GetDimension(a_shape, i);
+    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
   }
 
-  auto q = xla::Broadcast(xla::IdentityMatrix(builder, type, m, m), batch_dims);
+  auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
   for (int64 i = 0; i < p; i += block_size) {
     int64 k = std::min(block_size, p - i);
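Note: the body of this panel loop is elided from the diff. Based on the helpers above, each iteration presumably factors one block_size-wide panel with QRBlock, converts its reflectors to W and Y via ComputeWYRepresentation, and then applies the aggregated transformation with matrix-matrix BatchDot products, roughly:

$$A_{:,\,i+k:} \leftarrow (I + W Y^{T})^{T} A_{:,\,i+k:}, \qquad Q \leftarrow Q\,(I + W Y^{T}).$$

This is the point of blocking: most of the work becomes large matrix products instead of rank-1 updates.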
@@ -393,4 +390,4 @@ xla::StatusOr<QRDecompositionResult> QRDecomposition(
   return result;
 }
 
-}  // namespace tensorflow
+}  // namespace xla
@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_
-#define TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_
+#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_
+#define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_
 
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
-namespace tensorflow {
+namespace xla {
 
 // Computes the QR decompositions of a batch of matrices. That is,
 // given a (batched) matrix a, computes an orthonormal matrix Q and an
@@ -29,14 +29,14 @@ namespace tensorflow {
 // the block size to use.
 // TODO(phawkins): handle the complex case.
 struct QRDecompositionResult {
-  xla::XlaOp q;
-  xla::XlaOp r;
+  XlaOp q;
+  XlaOp r;
 };
 
-xla::StatusOr<QRDecompositionResult> QRDecomposition(
-    xla::XlaOp a, bool full_matrices, int64 block_size = 128,
-    xla::PrecisionConfig::Precision precision = xla::PrecisionConfig::HIGHEST);
+StatusOr<QRDecompositionResult> QRDecomposition(
+    XlaOp a, bool full_matrices, int64 block_size = 128,
+    PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST);
 
-}  // namespace tensorflow
+}  // namespace xla
 
-#endif  // TENSORFLOW_COMPILER_TF2XLA_LIB_QR_H_
+#endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_QR_H_
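Note: with the header now in xla/client/lib, client code calls the decomposition directly in the xla namespace. A minimal usage sketch (BuildQrRoundTrip is a hypothetical helper mirroring what qr_test.cc below does; it is not part of this commit):

#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/qr.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"

// Emits ops computing Q * R from the QR factorization of `a`; for a
// well-conditioned input the product reproduces `a`.
xla::StatusOr<xla::XlaOp> BuildQrRoundTrip(xla::XlaOp a) {
  TF_ASSIGN_OR_RETURN(auto result,
                      xla::QRDecomposition(a, /*full_matrices=*/true));
  return xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST);
}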
tensorflow/compiler/xla/client/lib/qr_test.cc (new file, 93 lines)
@@ -0,0 +1,93 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/client/lib/qr.h"
+
+#include "tensorflow/compiler/xla/array2d.h"
+#include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/client/lib/matrix.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/compiler/xla/tests/test_macros.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace {
+
+using QrTest = xla::ClientLibraryTestBase;
+
+XLA_TEST_F(QrTest, Simple) {
+  xla::XlaBuilder builder(TestName());
+
+  xla::Array2D<float> a_vals({
+      {4, 6, 8, 10},
+      {6, 45, 54, 63},
+      {8, 54, 146, 166},
+      {10, 63, 166, 310},
+  });
+
+  xla::XlaOp a;
+  auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto result,
+      xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2));
+
+  // Verifies that the decomposition composes back to the original matrix.
+  //
+  // This isn't a terribly demanding test (e.g., we should verify that Q is
+  // orthonormal and R is upper-triangular) but it's awkward to write such tests
+  // without more linear algebra libraries. It's easier to test the numerics
+  // from Python, anyway, where we have access to numpy and scipy.
+  xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST);
+
+  ComputeAndCompareR2<float>(&builder, a_vals, {a_data.get()},
+                             xla::ErrorSpec(1e-4, 1e-4));
+}
+
+XLA_TEST_F(QrTest, SimpleBatched) {
+  xla::XlaBuilder builder(TestName());
+
+  xla::Array3D<float> a_vals({
+      {
+          {4, 6, 8, 10},
+          {6, 45, 54, 63},
+          {8, 54, 146, 166},
+          {10, 63, 166, 310},
+      },
+      {
+          {16, 24, 8, 12},
+          {24, 61, 82, 48},
+          {8, 82, 456, 106},
+          {12, 48, 106, 62},
+      },
+  });
+
+  xla::XlaOp a;
+  auto a_data = CreateR3Parameter<float>(a_vals, 0, "a", &builder, &a);
+  TF_ASSERT_OK_AND_ASSIGN(
+      auto result,
+      xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2));
+
+  xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST);
+
+  ComputeAndCompareR3<float>(&builder, a_vals, {a_data.get()},
+                             xla::ErrorSpec(1e-4, 1e-4));
+}
+
+}  // namespace