[XLA] Split QR decomposition along the same lines as LAPACK xGEQRF/xORGQR.

The current QR decomposition computes explicit Q and R matrices. It is sometimes useful to represent Q implicitly, as a sequence of elementary Householder transformations rather than their explicitly formed product, for example when computing determinants via the QR decomposition.

This change refactors the QR decomposition into two methods:
* `xla::Qr`, which, like LAPACK's xGEQRF, computes `R` and the elementary Householder transformations, packed into a matrix `a` and a vector `taus`.
* `xla::ProductOfElementaryHouseholderReflectors`, which, like LAPACK's xORGQR/xUNGQR, computes the product of the elementary Householder transformations to form Q.
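
As a usage sketch (illustrative only; this mirrors what the new `xla::QrExplicit` helper does internally, minus the slicing and padding needed for the `full_matrices` variants and for wide inputs):

#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/qr.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Sketch: explicit Q and R from the two new entry points, for a tall
// [..., m, n] input with m >= n. The real helper, xla::QrExplicit, also
// handles wide inputs and the full_matrices padding/slicing cases.
void ExplicitQr(xla::XlaOp a, xla::XlaOp& q, xla::XlaOp& r) {
  xla::QrDecomposition qr = xla::Qr(a);
  // R lives in the upper triangle of the packed factor.
  r = xla::UpperTriangle(qr.q_and_r);
  // Q is the product of the Householder reflectors stored below the diagonal.
  q = xla::ProductOfElementaryHouseholderReflectors(qr.q_and_r, qr.taus);
}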

This change also fixes a bug in the QR-based implementation of `xla::LogDet`, which sometimes computed an incorrect sign; the number of Householder transformations is now determined explicitly from `taus`.
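
The sign rule being relied on: each elementary reflector H = I - tau * v * v^T is an exact reflection with determinant -1 when tau != 0, and degenerates to the identity when tau == 0, so the sign contributed by Q follows from counting the nonzero entries of `taus`. Assuming that LAPACK-style convention, a standalone sketch of the rule (plain C++, not code from this change):

#include <vector>

// Sign of det(Q) for Q = H_0 * H_1 * ... * H_{k-1}, given the scalar
// factors: a reflector with tau != 0 has determinant -1, while tau == 0
// marks an identity factor that contributes nothing.
double SignOfDetQ(const std::vector<double>& taus) {
  double sign = 1.0;
  for (double tau : taus) {
    if (tau != 0.0) sign = -sign;
  }
  return sign;
}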

PiperOrigin-RevId: 359129909
Change-Id: I58919cabd6d6079114ddf32a1edf61c01e010d02
Peter Hawkins 2021-02-23 14:06:33 -08:00 committed by TensorFlower Gardener
parent 61d5053cb8
commit 89b614f023
13 changed files with 296 additions and 162 deletions

View File

@@ -42,15 +42,15 @@ class MatrixInverseOp : public XlaOpKernel {
     xla::XlaOp input = xla::MaybeTransposeInMinorDims(ctx->Input(0), adjoint_);
 
     // TODO(b/111271662): Using LU decomposition instead of QR should be faster.
-    auto qr = xla::QRDecomposition(input, /*full_matrices=*/false);
-    OP_REQUIRES_OK(ctx, qr.status());
+    xla::XlaOp q, r;
+    QrExplicit(input, /*full_matrices=*/false, q, r);
 
-    xla::XlaOp output = xla::TriangularSolve(
-        qr.ValueOrDie().r, xla::TransposeInMinorDims(qr.ValueOrDie().q),
-        /*left_side=*/true,
-        /*lower=*/false, /*unit_diagonal=*/false,
-        /*transpose_a=*/
-        xla::TriangularSolveOptions::NO_TRANSPOSE);
+    xla::XlaOp output =
+        xla::TriangularSolve(r, xla::TransposeInMinorDims(q),
+                             /*left_side=*/true,
+                             /*lower=*/false, /*unit_diagonal=*/false,
+                             /*transpose_a=*/
+                             xla::TriangularSolveOptions::NO_TRANSPOSE);
 
     ctx->SetOutput(0, output);
   }

View File

@@ -46,15 +46,15 @@ class MatrixSolveOp : public XlaOpKernel {
     xla::XlaOp rhs = ctx->Input(1);
 
     // TODO(b/111271662): Using LU decomposition instead of QR should be faster.
-    auto qr = xla::QRDecomposition(matrix, /*full_matrices=*/false);
-    OP_REQUIRES_OK(ctx, qr.status());
+    xla::XlaOp q, r;
+    xla::QrExplicit(matrix, /*full_matrices=*/false, q, r);
 
-    xla::XlaOp inv = xla::TriangularSolve(
-        qr.ValueOrDie().r, xla::TransposeInMinorDims(qr.ValueOrDie().q),
-        /*left_side=*/true,
-        /*lower=*/false, /*unit_diagonal=*/false,
-        /*transpose_a=*/
-        xla::TriangularSolveOptions::NO_TRANSPOSE);
+    xla::XlaOp inv =
+        xla::TriangularSolve(r, xla::TransposeInMinorDims(q),
+                             /*left_side=*/true,
+                             /*lower=*/false, /*unit_diagonal=*/false,
+                             /*transpose_a=*/
+                             xla::TriangularSolveOptions::NO_TRANSPOSE);
 
     xla::XlaOp output =
         xla::BatchDot(inv, adjoint_, rhs,

View File

@@ -26,13 +26,10 @@ class QROp : public XlaOpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_));
   }
   void Compile(XlaOpKernelContext* ctx) override {
-    auto result = xla::QRDecomposition(ctx->Input(0), full_matrices_);
-    if (!result.ok()) {
-      ctx->SetStatus(result.status());
-      return;
-    }
-    ctx->SetOutput(0, result.ValueOrDie().q);
-    ctx->SetOutput(1, result.ValueOrDie().r);
+    xla::XlaOp q, r;
+    xla::QrExplicit(ctx->Input(0), full_matrices_, q, r);
+    ctx->SetOutput(0, q);
+    ctx->SetOutput(1, r);
   }
 
  private:

View File

@@ -292,6 +292,7 @@ xla_test(
     srcs = ["qr_test.cc"],
     tags = ["optonly"],
     deps = [
+        ":constants",
         ":matrix",
         ":qr",
        "//tensorflow/compiler/xla:array2d",

View File

@@ -39,40 +39,23 @@ namespace xla {
 XlaOp LogDet(XlaOp a) {
   return a.builder()->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(Shape a_shape, a.builder()->GetShape(a));
-    // Compute the number of Householder transformations required on 'a' by
-    // determining the number of rows in 'a' that are already triangular. The
-    // determinant of Q is -1 ^ (number of Householder transformations).
-    auto rows = Iota(a.builder(), ShapeUtil::ChangeElementType(a_shape, S32),
-                     a_shape.rank() - 2);
-    auto cols = Iota(a.builder(), ShapeUtil::ChangeElementType(a_shape, S32),
-                     a_shape.rank() - 1);
-    auto in_lower_triangle = Lt(cols, rows);
-    auto is_zero = Eq(a, ScalarLike(a, 0));
-    auto num_zeros_in_triangle_per_row = Einsum(
-        ConvertElementType(And(in_lower_triangle, is_zero), S32), "...a->...");
-    TF_ASSIGN_OR_RETURN(auto row_shape,
-                        a.builder()->GetShape(num_zeros_in_triangle_per_row));
-    rows = Iota(a.builder(), row_shape, row_shape.rank() - 1);
-    auto num_triangle_rows =
-        Einsum(ConvertElementType(Eq(rows, num_zeros_in_triangle_per_row), S32),
-               "...a->...");
-    auto num_rows =
-        ScalarLike(num_triangle_rows, a_shape.dimensions(a_shape.rank() - 2));
-    TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, true));
+    auto qr = Qr(a);
 
-    // Get the sign and log of the determinant based on the values along the
-    // diagonal of R.
-    auto log_abs_det = Einsum(Log(Abs(qr.r)), "...aa->...");
+    // Get the sign and logarithm of the determinant based on the values along
+    // the diagonal of R and the number of zeros in taus.
+    auto log_abs_det = Einsum(Log(Abs(qr.q_and_r)), "...aa->...");
     auto sign_diag = Reduce(
-        Sign(Einsum(qr.r, "...aa->...a")),
+        Sign(Einsum(qr.q_and_r, "...aa->...a")),
         One(a.builder(), a_shape.element_type()),
         CreateScalarMultiplyComputation(a_shape.element_type(), a.builder()),
         {a_shape.rank() - 2});
-    return sign_diag * log_abs_det *
-           Select(ConvertElementType(Rem(num_rows - num_triangle_rows,
-                                         ScalarLike(num_triangle_rows, 2)),
-                                     PRED),
-                  ScalarLike(sign_diag, -1.0), ScalarLike(sign_diag, 1.0));
+    auto sign_taus = Reduce(
+        Select(Eq(qr.taus, ZerosLike(qr.taus)), FullLike(qr.taus, -1),
+               FullLike(qr.taus, 1)),
+        One(a.builder(), a_shape.element_type()),
+        CreateScalarMultiplyComputation(a_shape.element_type(), a.builder()),
+        {a_shape.rank() - 2});
+    return sign_diag * log_abs_det * sign_taus;
   });
 }

View File

@@ -61,7 +61,7 @@ XLA_TEST_F(LogDetTest, SimpleTriangle) {
       {4, 6, 8, 320},
   });
 
-  float expected = -15.9131355f;
+  float expected = 15.9131355f;
 
   xla::XlaOp a;
   auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);

View File

@@ -33,59 +33,115 @@ limitations under the License.
 namespace xla {
 
-namespace {
-
-}  // namespace
-
-// Block Householder QR Factorization. Algorithm 5.2.2 of Golub and van Loan.
-// def qr_blocked(a, block_size):
-//   m = a.shape[0]
-//   n = a.shape[1]
-//   q = np.eye(m)
-//   for i in xrange(0, min(m, n), block_size):
-//     k = min(block_size, min(m, n) - i)
-//     (a, taus) = qr(a[i:, i:i+k])
-//     y = np.eye(m, n) + np.tril(a, -1)
-//     t = CompactWYRepresentation(vs, taus, m-i, k)
-//     a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
-//     q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
-//   return (q, a)
-StatusOr<QRDecompositionResult> QRDecomposition(
-    XlaOp a, bool full_matrices, int64 block_size,
-    PrecisionConfig::Precision precision) {
-  XlaBuilder* builder = a.builder();
-  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
-  const int num_dims = a_shape.rank();
-  if (num_dims < 2) {
-    return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s",
-                           a_shape.ToString());
-  }
+QrDecomposition Qr(XlaOp a) {
+  auto result = [&]() -> StatusOr<QrDecomposition> {
+    XlaBuilder* builder = a.builder();
+    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
+    const int num_dims = a_shape.rank();
+    if (num_dims < 2) {
+      return InvalidArgument(
+          "Arguments to QR must have rank >= 2: got shape %s",
+          a_shape.ToString());
+    }
+    const int64 m = ShapeUtil::GetDimension(a_shape, -2);
+    const int64 n = ShapeUtil::GetDimension(a_shape, -1);
+
+    std::vector<int64> taus_dims(a_shape.dimensions().begin(),
+                                 a_shape.dimensions().end());
+    taus_dims.pop_back();
+    taus_dims.back() = std::min(m, n);
+    auto taus_shape = ShapeUtil::MakeShape(a_shape.element_type(), taus_dims);
+
+    Shape qr_shape = ShapeUtil::MakeTupleShape({a_shape, taus_shape});
+    auto qr = CustomCall(a.builder(), "Qr", {a}, qr_shape);
+    a = GetTupleElement(qr, 0);
+    auto taus = GetTupleElement(qr, 1);
+    return QrDecomposition{a, taus};
+  }();
+  if (!result.ok()) {
+    XlaOp error = a.builder()->ReportError(result.status());
+    return QrDecomposition{error, error};
+  }
+  return result.ValueOrDie();
+}
+
+XlaOp ProductOfElementaryHouseholderReflectors(XlaOp a, XlaOp taus) {
+  XlaBuilder* builder = a.builder();
+  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
+    TF_ASSIGN_OR_RETURN(Shape taus_shape, builder->GetShape(taus));
+    if (a_shape.rank() < 2) {
+      return InvalidArgument(
+          "Matrix `a` must have >= 2 dimensions: got shape %s",
+          a_shape.ToString());
+    }
+    if (taus_shape.rank() + 1 != a_shape.rank()) {
+      return InvalidArgument(
+          "Matrix `taus` must have one fewer dimension than `a`: got shapes "
+          "%s and %s",
+          taus_shape.ToString(), a_shape.ToString());
+    }
+    const int64 m = ShapeUtil::GetDimension(a_shape, -2);
+    const int64 n = ShapeUtil::GetDimension(a_shape, -1);
+    if (m < n) {
+      return InvalidArgument(
+          "Argument to product of elementary Householder "
+          "reflectors must have m >= n, got shape %s",
+          a_shape.ToString());
+    }
+    absl::Span<const int64> a_batch_dims =
+        absl::MakeConstSpan(a_shape.dimensions().begin(),
+                            a_shape.dimensions().begin() + a_shape.rank() - 2);
+    absl::Span<const int64> taus_batch_dims = absl::MakeConstSpan(
+        taus_shape.dimensions().begin(),
+        taus_shape.dimensions().begin() + taus_shape.rank() - 1);
+    const int64 k = ShapeUtil::GetDimension(taus_shape, -1);
+    if (a_shape.element_type() != taus_shape.element_type() ||
+        a_batch_dims != taus_batch_dims || k > n) {
+      return InvalidArgument("Invalid shape for `taus`, got a=%s and taus=%s",
+                             taus_shape.ToString(), a_shape.ToString());
+    }
+    return CustomCall(a.builder(), "ProductOfElementaryHouseholderReflectors",
+                      {a, taus}, a_shape);
+  });
+}
+
+void QrExplicit(XlaOp a, bool full_matrices, XlaOp& q, XlaOp& r) {
+  StatusOr<Shape> a_shape_or = a.builder()->GetShape(a);
+  if (!a_shape_or.ok()) {
+    q = a.builder()->ReportError(a_shape_or.status());
+    r = q;
+    return;
+  }
+  Shape a_shape = a_shape_or.ValueOrDie();
   const int64 m = ShapeUtil::GetDimension(a_shape, -2);
   const int64 n = ShapeUtil::GetDimension(a_shape, -1);
   const int64 p = std::min(m, n);
-  if (block_size < 1) {
-    return InvalidArgument("block_size argument to QR must be >= 1; got %d",
-                           block_size);
-  }
-
-  Shape q_shape = a_shape;
-  q_shape.mutable_dimensions().back() = m;
-  Shape qr_shape = ShapeUtil::MakeTupleShape({q_shape, a_shape});
-  auto qr = CustomCall(a.builder(), "QrDecomposition", {a}, qr_shape);
-  auto q = GetTupleElement(qr, 0);
-  auto r = GetTupleElement(qr, 1);
-
-  // full_matrices is false when only a partial result is needed. Slice to the
-  // needed dimensions here.
-  if (!full_matrices) {
-    q = SliceInMinorDims(q, {0, 0}, {m, p});
-    r = SliceInMinorDims(r, {0, 0}, {p, n});
-  }
-  return QRDecompositionResult{q, r};
+  auto qr = Qr(a);
+  if (full_matrices) {
+    XlaOp t;
+    if (m < n) {
+      t = SliceInMinorDims(qr.q_and_r, {0, 0}, {m, m});
+    } else {
+      t = PadInDim(qr.q_and_r, Zero(a.builder(), a_shape.element_type()),
+                   a_shape.dimensions_size() - 1, /*pad_lo=*/0,
+                   /*pad_hi=*/m - n);
+    }
+    q = ProductOfElementaryHouseholderReflectors(t, qr.taus);
+    r = UpperTriangle(qr.q_and_r);
+  } else {
+    XlaOp t;
+    if (m < n) {
+      t = SliceInMinorDims(qr.q_and_r, {0, 0}, {m, m});
+    } else {
+      t = qr.q_and_r;
+    }
+    q = ProductOfElementaryHouseholderReflectors(t, qr.taus);
+    q = SliceInMinorDims(q, {0, 0}, {m, p});
+    r = UpperTriangle(SliceInMinorDims(qr.q_and_r, {0, 0}, {p, n}));
+  }
 }
 
 }  // namespace xla

View File

@@ -25,17 +25,27 @@ namespace xla {
 // given a (batched) matrix a, computes an orthonormal matrix Q and an
 // upper-triangular matrix R such that a = QR.
 // `a` must be a (batched) matrix of size [..., m, n].
-// The algorithm implements a blocked QR decomposition; `block_size` is
-// the block size to use.
 // TODO(phawkins): handle the complex case.
-struct QRDecompositionResult {
-  XlaOp q;
-  XlaOp r;
+struct QrDecomposition {
+  // A matrix with the same shape as the input matrix `a`, whose upper triangle
+  // (inclusive of the diagonal) is the matrix R, and whose lower triangle
+  // (exclusive of the diagonal) contains the elementary Householder
+  // reflectors. This is the same output format as used by LAPACK's xGEQRF
+  // routine.
+  XlaOp q_and_r;
+
+  // A vector of shape [..., min(m, n)] containing the scalar factors of the
+  // elementary Householder reflectors.
+  XlaOp taus;
 };
 
-StatusOr<QRDecompositionResult> QRDecomposition(
-    XlaOp a, bool full_matrices, int64 block_size = 128,
-    PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST);
+QrDecomposition Qr(XlaOp a);
+
+// Given `a` and `taus` as returned by `Qr`, compute the product of the
+// elementary Householder reflectors (i.e., the matrix Q of the QR
+// decomposition). The equivalent LAPACK routine is xORGQR/xUNGQR.
+XlaOp ProductOfElementaryHouseholderReflectors(XlaOp a, XlaOp taus);
+
+// Helper that combines `Qr` and `ProductOfElementaryHouseholderReflectors` to
+// compute explicit matrices `q` and `r`.
+void QrExplicit(XlaOp a, bool full_matrices, XlaOp& q, XlaOp& r);
 
 }  // namespace xla
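
To make the packed format concrete, here is a plain-C++ sketch (illustration only, not part of this change) of the xORGQR-style reconstruction that `ProductOfElementaryHouseholderReflectors` performs, for a single unbatched real matrix:

#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Rebuild Q = H_0 * H_1 * ... * H_{k-1} from the packed factorization:
// column j of `a` holds the reflector vector v_j below the diagonal (with
// v_j[j] implicitly 1), and taus[j] is the scalar factor of
// H_j = I - taus[j] * v_j * v_j^T.
Matrix ReconstructQ(const Matrix& a, const std::vector<double>& taus) {
  const size_t m = a.size();
  Matrix q(m, std::vector<double>(m, 0.0));
  for (size_t i = 0; i < m; ++i) q[i][i] = 1.0;  // Start from the identity.
  // Apply reflectors right-to-left so that Q = H_0 (H_1 (... (H_{k-1} I))).
  for (size_t j = taus.size(); j-- > 0;) {
    std::vector<double> v(m, 0.0);
    v[j] = 1.0;
    for (size_t i = j + 1; i < m; ++i) v[i] = a[i][j];
    for (size_t c = 0; c < m; ++c) {  // Q <- Q - taus[j] * v * (v^T Q).
      double dot = 0.0;
      for (size_t i = j; i < m; ++i) dot += v[i] * q[i][c];
      for (size_t i = j; i < m; ++i) q[i][c] -= taus[j] * dot * v[i];
    }
  }
  return q;
}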

View File

@@ -17,6 +17,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/array3d.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
+#include "tensorflow/compiler/xla/client/lib/matrix.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"

@@ -36,31 +37,44 @@ using QrTest = xla::ClientLibraryTestBase;
 XLA_TEST_F(QrTest, Simple) {
   // Test fails with TensorFloat-32 enabled
   tensorflow::enable_tensor_float_32_execution(false);
-  xla::XlaBuilder builder(TestName());
 
-  xla::Array2D<float> a_vals({
+  xla::Array2D<float> data({
       {4, 6, 8, 10},
       {6, 45, 54, 63},
      {8, 54, 146, 166},
       {10, 63, 166, 310},
   });
 
-  xla::XlaOp a;
-  auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto result,
-      xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2));
-
-  // Verifies that the decomposition composes back to the original matrix.
-  //
-  // This isn't a terribly demanding test, (e.g., we should verify that Q is
-  // orthonormal and R is upper-triangular) but it's awkward to write such tests
-  // without more linear algebra libraries. It's easier to test the numerics
-  // from Python, anyway, where we have access to numpy and scipy.
-  xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST);
-  ComputeAndCompareR2<float>(&builder, a_vals, {a_data.get()},
-                             xla::ErrorSpec(1e-4, 1e-4));
+  for (bool full_matrices : {false, true}) {
+    for (xla::int64 m : {3, 4}) {
+      for (xla::int64 n : {3, 4}) {
+        xla::XlaBuilder builder(TestName());
+        xla::XlaOp a, q, r;
+        xla::Array<float> a_vals = data.Slice({0, 0}, {m, n});
+        auto a_data = CreateParameter<float>(a_vals, 0, "a", &builder, &a);
+        xla::QrExplicit(a, full_matrices, q, r);
+
+        // Verifies that the decomposition composes back to the original
+        // matrix.
+        //
+        // This isn't a terribly demanding test, (e.g., we should verify that Q
+        // is orthonormal and R is upper-triangular) but it's awkward to write
+        // such tests without more linear algebra libraries. It's easier to
+        // test the numerics from Python, anyway, where we have access to numpy
+        // and scipy.
+        xla::BatchDot(q, r, xla::PrecisionConfig::HIGHEST);
+        TF_ASSERT_OK_AND_ASSIGN(xla::Shape q_shape, builder.GetShape(q));
+        TF_ASSERT_OK_AND_ASSIGN(xla::Shape r_shape, builder.GetShape(r));
+        EXPECT_EQ(q_shape,
+                  xla::ShapeUtil::MakeShape(
+                      xla::F32, {m, full_matrices ? m : std::min(m, n)}));
+        EXPECT_EQ(r_shape,
+                  xla::ShapeUtil::MakeShape(
+                      xla::F32, {full_matrices ? m : std::min(m, n), n}));
+        ComputeAndCompare<float>(&builder, a_vals, {a_data.get()},
+                                 xla::ErrorSpec(1e-4, 1e-4));
+      }
+    }
+  }
 }
@@ -74,11 +88,9 @@ XLA_TEST_F(QrTest, ZeroDiagonal) {
       {1, 1, 0},
   });
 
-  xla::XlaOp a;
+  xla::XlaOp a, q, r;
   auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto result,
-      xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/8));
+  xla::QrExplicit(a, /*full_matrices=*/true, q, r);
 
   // Verifies that the decomposition composes back to the original matrix.
   //
@@ -86,7 +98,7 @@ XLA_TEST_F(QrTest, ZeroDiagonal) {
   // orthonormal and R is upper-triangular) but it's awkward to write such tests
   // without more linear algebra libraries. It's easier to test the numerics
   // from Python, anyway, where we have access to numpy and scipy.
-  xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST);
+  xla::BatchDot(q, r, xla::PrecisionConfig::HIGHEST);
   ComputeAndCompareR2<float>(&builder, a_vals, {a_data.get()},
                              xla::ErrorSpec(1e-4, 1e-4));

@@ -112,13 +124,11 @@ XLA_TEST_F(QrTest, SimpleBatched) {
       },
   });
 
-  xla::XlaOp a;
+  xla::XlaOp a, q, r;
   auto a_data = CreateR3Parameter<float>(a_vals, 0, "a", &builder, &a);
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto result,
-      xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2));
+  xla::QrExplicit(a, /*full_matrices=*/true, q, r);
 
-  xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST);
+  xla::BatchDot(q, r, xla::PrecisionConfig::HIGHEST);
   ComputeAndCompareR3<float>(&builder, a_vals, {a_data.get()},
                              xla::ErrorSpec(1e-4, 1e-4));
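
As the test comments concede, orthonormality of Q is never asserted directly. A hypothetical extra check (sketch only, not in this commit) could be built from ops this change already uses, comparing Q^T Q against an identity matrix:

// Hypothetical extra check: Q^T Q should be close to the identity. The
// result could be compared against xla::IdentityMatrix of matching size
// with an appropriate error tolerance.
xla::XlaOp qtq = xla::BatchDot(xla::TransposeInMinorDims(q), q,
                               xla::PrecisionConfig::HIGHEST);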

View File

@@ -192,8 +192,9 @@ void BuildOpsSubmodule(py::module* m) {
   ops.def(
       "QR",
       [](XlaOp a, bool full_matrices) -> StatusOr<std::pair<XlaOp, XlaOp>> {
-        TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices));
-        return std::make_pair(qr.q, qr.r);
+        XlaOp q, r;
+        QrExplicit(a, full_matrices, q, r);
+        return std::make_pair(q, r);
       },
       py::arg("operand"), py::arg("full_matrices"));
   ops.def(

View File

@@ -1964,6 +1964,7 @@ cc_library(
         "//tensorflow/compiler/xla/client/lib:loops",
         "//tensorflow/compiler/xla/client/lib:math",
         "//tensorflow/compiler/xla/client/lib:matrix",
+        "//tensorflow/compiler/xla/client/lib:qr",
        "//tensorflow/compiler/xla/client/lib:slicing",
         "//tensorflow/core:lib",
         "@com_google_absl//absl/container:flat_hash_map",

View File

@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/lib/loops.h"
 #include "tensorflow/compiler/xla/client/lib/math.h"
 #include "tensorflow/compiler/xla/client/lib/matrix.h"
+#include "tensorflow/compiler/xla/client/lib/qr.h"
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"

@@ -172,7 +173,7 @@ Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
 //     a[j+1:, j] = v[j+1:]
 //     taus[j] = tau
 //   return (a, taus)
-StatusOr<QrExpander::QrResult> QrExpander::QrBlock(
+StatusOr<QrDecomposition> QrExpander::QrBlock(
     XlaOp a, PrecisionConfig::Precision precision) {
   XlaBuilder* builder = a.builder();
   TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));

@@ -269,8 +270,8 @@ StatusOr<QrExpander::QrResult> QrExpander::QrBlock(
   TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn,
                                                 {a, taus}, "qr", builder));
 
-  QrResult result;
-  result.a = values[0];
+  QrDecomposition result;
+  result.q_and_r = values[0];
   result.taus = values[1];
   return result;
 }

@@ -372,18 +373,21 @@ StatusOr<XlaOp> QrExpander::BuildQrDecomposition(
     batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
   }
 
-  auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
+  std::vector<int64> taus_dims = batch_dims;
+  taus_dims.push_back(p);
+  auto taus = Zeros(builder, ShapeUtil::MakeShape(type, taus_dims));
   for (int64 i = 0; i < p; i += block_size) {
     int64 k = std::min(block_size, p - i);
 
     auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k});
     TF_ASSIGN_OR_RETURN(auto qr_block, QrBlock(a_block, precision));
-    auto y = Add(
-        IdentityMatrix(builder, type, m - i, k),
-        Select(TriangleMask(qr_block.a, -1), qr_block.a, ZerosLike(qr_block.a)),
-        /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1});
+    auto y = Add(IdentityMatrix(builder, type, m - i, k),
+                 Select(TriangleMask(qr_block.q_and_r, -1), qr_block.q_and_r,
+                        ZerosLike(qr_block.q_and_r)),
+                 /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1});
 
-    a = UpdateSliceInMinorDims(a, qr_block.a, {i, i});
+    a = UpdateSliceInMinorDims(a, qr_block.q_and_r, {i, i});
+    taus = UpdateSliceInMinorDims(taus, qr_block.taus, {i});
 
     // Compute the I + Y @ T @ Y^t block representation of a product of
     // Householder matrices.

@@ -401,8 +405,64 @@ StatusOr<XlaOp> QrExpander::BuildQrDecomposition(
       a_update = BatchDot(yt, a_update, precision);
       a_panel = a_panel + a_update;
       a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});
     }
+  }
+
+  return Tuple(builder, {a, taus});
+}
+
+StatusOr<XlaOp> QrExpander::ProductOfElementaryHouseholderReflectors(
+    XlaOp a, XlaOp taus, int64 block_size,
+    PrecisionConfig::Precision precision) {
+  XlaBuilder* builder = a.builder();
+  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
+  TF_ASSIGN_OR_RETURN(Shape taus_shape, builder->GetShape(taus));
+  const int num_dims = a_shape.rank();
+  if (num_dims < 2) {
+    return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s",
+                           a_shape.ToString());
+  }
+  PrimitiveType type = a_shape.element_type();
+
+  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
+  int64 n = ShapeUtil::GetDimension(a_shape, -1);
+  const int64 p = ShapeUtil::GetDimension(taus_shape, -1);
+  if (m < n) {
+    return InvalidArgument(
+        "Argument to product of elementary Householder "
+        "reflectors must have m >= n, got shape %s",
+        a_shape.ToString());
+  }
+  if (block_size < 1) {
+    return InvalidArgument("block_size argument to QR must be >= 1; got %d",
+                           block_size);
+  }
+
+  const int64 num_batch_dims = num_dims - 2;
+  std::vector<int64> batch_dims(num_batch_dims);
+  for (int i = 0; i < num_batch_dims; ++i) {
+    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
+  }
+
+  auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
+  for (int64 i = 0; i < p; i += block_size) {
+    int64 k = std::min(block_size, p - i);
+
+    auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k});
+    auto y = Add(IdentityMatrix(builder, type, m - i, k),
+                 Select(TriangleMask(a_block, -1), a_block, ZerosLike(a_block)),
+                 /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1});
+
+    // Compute the I + Y @ T @ Y^t block representation of a product of
+    // Householder matrices.
+    auto taus_block = SliceInMinorDims(taus, {i}, {i + k});
+    TF_ASSIGN_OR_RETURN(
+        auto t, CompactWYRepresentation(type, batch_dims, y, taus_block, m - i,
+                                        k, precision));
+
+    // q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T)
+    auto yt = BatchDot(y, /*transpose_x=*/false, MaybeConjugate(t, true),
+                       /*transpose_y=*/true, precision);
+    auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
+    auto q_update = BatchDot(q_panel, y, precision);
+    q_update =
@@ -411,19 +471,26 @@ StatusOr<XlaOp> QrExpander::BuildQrDecomposition(
     q_panel = q_panel + q_update;
     q = UpdateSliceInMinorDims(q, q_panel, {0, i});
   }
 
-  return Tuple(builder, {q, UpperTriangle(a)});
+  q = SliceInMinorDims(q, {0, 0}, {m, n});
+  return q;
 }
 
+static const char* kQrCustomCallName = "Qr";
+static const char* kHouseholderProductCustomCallName =
+    "ProductOfElementaryHouseholderReflectors";
+
 bool QrExpander::InstructionMatchesPattern(HloInstruction* instruction) {
   return instruction->opcode() == HloOpcode::kCustomCall &&
-         instruction->custom_call_target() == "QrDecomposition";
+         (instruction->custom_call_target() == kQrCustomCallName ||
+          instruction->custom_call_target() ==
+              kHouseholderProductCustomCallName);
 }
 
 StatusOr<HloInstruction*> QrExpander::ExpandInstruction(
     HloInstruction* instruction) {
   const string name =
-      absl::StrFormat("xla.qr_%s", instruction->operand(0)->shape().ToString());
+      absl::StrFormat("xla.%s_%s", instruction->custom_call_target(),
+                      instruction->operand(0)->shape().ToString());
 
   HloModule* module = instruction->parent()->parent();

@@ -441,13 +508,25 @@ StatusOr<HloInstruction*> QrExpander::ExpandInstruction(
   // into our HloModule. Ideally we would avoid the protocol buffer step;
   // that is left as an exercise for future work.
   XlaBuilder builder(name);
+  TF_RET_CHECK(instruction->operand_count() >= 1);
   XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
-  TF_ASSIGN_OR_RETURN(
-      XlaOp l, BuildQrDecomposition(a,
-                                    /*block_size=*/128,
-                                    /*precision=*/PrecisionConfig::HIGHEST));
+  XlaOp result;
+  if (instruction->custom_call_target() == kQrCustomCallName) {
+    TF_RET_CHECK(instruction->operand_count() == 1);
+    TF_ASSIGN_OR_RETURN(
+        result, BuildQrDecomposition(a,
+                                     /*block_size=*/128,
+                                     /*precision=*/PrecisionConfig::HIGHEST));
+  } else {
+    TF_RET_CHECK(instruction->operand_count() == 2);
+    XlaOp taus =
+        Parameter(&builder, 1, instruction->operand(1)->shape(), "taus");
+    TF_ASSIGN_OR_RETURN(result, ProductOfElementaryHouseholderReflectors(
+                                    a, taus, /*block_size=*/128,
+                                    /*precision=*/PrecisionConfig::HIGHEST));
+  }
 
-  TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(l));
+  TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(result));
 
   TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                       xla_computation.GetProgramShape());

View File

@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XLA_SERVICE_QR_EXPANDER_H_
 
 #include "absl/container/flat_hash_map.h"
+#include "tensorflow/compiler/xla/client/lib/qr.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/service/op_expander_pass.h"

@@ -32,17 +33,8 @@ class QrExpander : public OpExpanderPass {
   StatusOr<HloInstruction*> ExpandInstruction(
       HloInstruction* instruction) override;
 
-  struct QrResult {
-    // The upper-triangular matrix R, packed together with the lower-triangular
-    // elementary Householder reflectors `vs` below the diagonal.
-    XlaOp a;
-    // Representation of the Householder matrices I - beta v v.T
-    XlaOp taus;  // Shape: [..., min(m, n)]
-  };
-
-  virtual StatusOr<QrResult> QrBlock(XlaOp a,
-                                     PrecisionConfig::Precision precision);
+  virtual StatusOr<QrDecomposition> QrBlock(
+      XlaOp a, PrecisionConfig::Precision precision);
 
   virtual StatusOr<XlaOp> CompactWYRepresentation(
       PrimitiveType type, absl::Span<const int64> batch_dims, XlaOp vs,

@@ -52,6 +44,10 @@ class QrExpander : public OpExpanderPass {
   StatusOr<XlaOp> BuildQrDecomposition(XlaOp a, int64 block_size,
                                        PrecisionConfig::Precision precision);
 
+  StatusOr<XlaOp> ProductOfElementaryHouseholderReflectors(
+      XlaOp a, XlaOp taus, int64 block_size,
+      PrecisionConfig::Precision precision);
+
   // Mapping from op signatures to existing computations.
   absl::flat_hash_map<string, HloComputation*> computation_cache_;
 };