[XLA] Split QR decomposition along the same lines as LAPACK xGEQRF/xORGQR.
The current QR decomposition computes explicit Q and R matrices. Sometimes it is useful to represent Q implicitly as a product of elementary Householder transformations, rather than explicitly returning their product, for example when computing determinants using the QR decomposition.

This change refactors the QR decomposition into two methods:
* `xla::Qr`, which, like LAPACK's xGEQRF, computes R and the elementary Householder transformations, packed into two matrices `a` and `taus`.
* `xla::ProductOfElementaryHouseholderReflectors`, which, like LAPACK's xORGQR/xUNGQR, computes the product of elementary Householder transformations to form Q.

This change also fixes a bug in the QR-based implementation of `xla::LogDet`, which sometimes computed an incorrect sign: the number of Householder transformations is now determined explicitly from `taus`.

PiperOrigin-RevId: 359129909
Change-Id: I58919cabd6d6079114ddf32a1edf61c01e010d02
This commit is contained in:
parent 61d5053cb8
commit 89b614f023
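A minimal sketch of how the new pieces compose (illustrative only: `BuildQrGraph` is a hypothetical helper, and the sketch assumes a (batched) input with m >= n; `Qr`, `ProductOfElementaryHouseholderReflectors`, and `QrExplicit` are the functions this change introduces):

#include "tensorflow/compiler/xla/client/lib/qr.h"

// Hypothetical usage sketch, not part of the change.
void BuildQrGraph(xla::XlaOp a) {
  // xGEQRF-style: R lives in the upper triangle of qr.q_and_r, the
  // Householder reflectors are packed below the diagonal, and the scalar
  // factors are in qr.taus.
  xla::QrDecomposition qr = xla::Qr(a);

  // xORGQR/xUNGQR-style: form Q explicitly only when it is actually needed
  // (the packed input must have m >= n).
  xla::XlaOp q_explicit =
      xla::ProductOfElementaryHouseholderReflectors(qr.q_and_r, qr.taus);

  // Or use the one-shot helper, which slices/pads the packed result as
  // needed before forming Q.
  xla::XlaOp q, r;
  xla::QrExplicit(a, /*full_matrices=*/false, q, r);
}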
@@ -42,15 +42,15 @@ class MatrixInverseOp : public XlaOpKernel {
     xla::XlaOp input = xla::MaybeTransposeInMinorDims(ctx->Input(0), adjoint_);
 
     // TODO(b/111271662): Using LU decomposition instead of QR should be faster.
-    auto qr = xla::QRDecomposition(input, /*full_matrices=*/false);
-    OP_REQUIRES_OK(ctx, qr.status());
+    xla::XlaOp q, r;
+    QrExplicit(input, /*full_matrices=*/false, q, r);
 
-    xla::XlaOp output = xla::TriangularSolve(
-        qr.ValueOrDie().r, xla::TransposeInMinorDims(qr.ValueOrDie().q),
-        /*left_side=*/true,
-        /*lower=*/false, /*unit_diagonal=*/false,
-        /*transpose_a=*/
-        xla::TriangularSolveOptions::NO_TRANSPOSE);
+    xla::XlaOp output =
+        xla::TriangularSolve(r, xla::TransposeInMinorDims(q),
+                             /*left_side=*/true,
+                             /*lower=*/false, /*unit_diagonal=*/false,
+                             /*transpose_a=*/
+                             xla::TriangularSolveOptions::NO_TRANSPOSE);
     ctx->SetOutput(0, output);
   }
 
@@ -46,15 +46,15 @@ class MatrixSolveOp : public XlaOpKernel {
     xla::XlaOp rhs = ctx->Input(1);
 
     // TODO(b/111271662): Using LU decomposition instead of QR should be faster.
-    auto qr = xla::QRDecomposition(matrix, /*full_matrices=*/false);
-    OP_REQUIRES_OK(ctx, qr.status());
+    xla::XlaOp q, r;
+    xla::QrExplicit(matrix, /*full_matrices=*/false, q, r);
 
-    xla::XlaOp inv = xla::TriangularSolve(
-        qr.ValueOrDie().r, xla::TransposeInMinorDims(qr.ValueOrDie().q),
-        /*left_side=*/true,
-        /*lower=*/false, /*unit_diagonal=*/false,
-        /*transpose_a=*/
-        xla::TriangularSolveOptions::NO_TRANSPOSE);
+    xla::XlaOp inv =
+        xla::TriangularSolve(r, xla::TransposeInMinorDims(q),
+                             /*left_side=*/true,
+                             /*lower=*/false, /*unit_diagonal=*/false,
+                             /*transpose_a=*/
+                             xla::TriangularSolveOptions::NO_TRANSPOSE);
 
     xla::XlaOp output =
         xla::BatchDot(inv, adjoint_, rhs,
@@ -26,13 +26,10 @@ class QROp : public XlaOpKernel {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("full_matrices", &full_matrices_));
   }
   void Compile(XlaOpKernelContext* ctx) override {
-    auto result = xla::QRDecomposition(ctx->Input(0), full_matrices_);
-    if (!result.ok()) {
-      ctx->SetStatus(result.status());
-      return;
-    }
-    ctx->SetOutput(0, result.ValueOrDie().q);
-    ctx->SetOutput(1, result.ValueOrDie().r);
+    xla::XlaOp q, r;
+    xla::QrExplicit(ctx->Input(0), full_matrices_, q, r);
+    ctx->SetOutput(0, q);
+    ctx->SetOutput(1, r);
   }
 
  private:
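One pattern visible in all three kernels above: the `OP_REQUIRES_OK` boilerplate disappears because `QrExplicit` returns void and funnels failures through the `XlaBuilder` (the `ReportError`/`ReportErrorOrReturn` calls in the new `qr.cc` further down). A hedged sketch of that pattern, with a hypothetical function name:

// Sketch of builder-carried error propagation, not part of the change.
// If QrExplicit hits an invalid input, q and r become error ops and the
// builder records the status; dependent ops may still be constructed, and
// XlaBuilder::Build() returns the first recorded error instead of a
// computation, so per-call status checks are no longer needed.
void SketchBuilderErrorFlow(xla::XlaOp a) {
  xla::XlaOp q, r;
  xla::QrExplicit(a, /*full_matrices=*/false, q, r);  // no StatusOr to check
}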
@@ -292,6 +292,7 @@ xla_test(
     srcs = ["qr_test.cc"],
     tags = ["optonly"],
     deps = [
+        ":constants",
         ":matrix",
         ":qr",
         "//tensorflow/compiler/xla:array2d",
@@ -39,40 +39,23 @@ namespace xla {
 XlaOp LogDet(XlaOp a) {
   return a.builder()->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     TF_ASSIGN_OR_RETURN(Shape a_shape, a.builder()->GetShape(a));
-    // Compute the number of Householder transformations required on 'a' by
-    // determining the number of rows in 'a' that are already triangular. The
-    // determinant of Q is -1 ^ (number of Householder transfomations)
-    auto rows = Iota(a.builder(), ShapeUtil::ChangeElementType(a_shape, S32),
-                     a_shape.rank() - 2);
-    auto cols = Iota(a.builder(), ShapeUtil::ChangeElementType(a_shape, S32),
-                     a_shape.rank() - 1);
-    auto in_lower_triangle = Lt(cols, rows);
-    auto is_zero = Eq(a, ScalarLike(a, 0));
-    auto num_zeros_in_triangle_per_row = Einsum(
-        ConvertElementType(And(in_lower_triangle, is_zero), S32), "...a->...");
-    TF_ASSIGN_OR_RETURN(auto row_shape,
-                        a.builder()->GetShape(num_zeros_in_triangle_per_row));
-    rows = Iota(a.builder(), row_shape, row_shape.rank() - 1);
-    auto num_triangle_rows =
-        Einsum(ConvertElementType(Eq(rows, num_zeros_in_triangle_per_row), S32),
-               "...a->...");
-    auto num_rows =
-        ScalarLike(num_triangle_rows, a_shape.dimensions(a_shape.rank() - 2));
+    auto qr = Qr(a);
 
-    TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, true));
-    // Get the and log of the determinant based on the values along the diagonal
-    // of R.
-    auto log_abs_det = Einsum(Log(Abs(qr.r)), "...aa->...");
+    // Get the sign and logarithm of the determinant based on the values along
+    // the diagonal of R and the number of zeros in taus.
+    auto log_abs_det = Einsum(Log(Abs(qr.q_and_r)), "...aa->...");
     auto sign_diag = Reduce(
-        Sign(Einsum(qr.r, "...aa->...a")),
+        Sign(Einsum(qr.q_and_r, "...aa->...a")),
         One(a.builder(), a_shape.element_type()),
         CreateScalarMultiplyComputation(a_shape.element_type(), a.builder()),
         {a_shape.rank() - 2});
-    return sign_diag * log_abs_det *
-           Select(ConvertElementType(Rem(num_rows - num_triangle_rows,
-                                         ScalarLike(num_triangle_rows, 2)),
-                                     PRED),
-                  ScalarLike(sign_diag, -1.0), ScalarLike(sign_diag, 1.0));
+    auto sign_taus = Reduce(
+        Select(Eq(qr.taus, ZerosLike(qr.taus)), FullLike(qr.taus, -1),
+               FullLike(qr.taus, 1)),
+        One(a.builder(), a_shape.element_type()),
+        CreateScalarMultiplyComputation(a_shape.element_type(), a.builder()),
+        {a_shape.rank() - 2});
+    return sign_diag * log_abs_det * sign_taus;
   });
 }
 
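A sketch of the arithmetic behind the sign fix (my gloss, not text from the change): for $A = QR$ with $Q = H_1 \cdots H_k$ a product of elementary Householder reflectors $H_i = I - \tau_i v_i v_i^T$,

$$\log\lvert\det A\rvert = \sum_i \log\lvert r_{ii}\rvert, \qquad \operatorname{sign}(\det A) = \Big(\prod_i \operatorname{sign}(r_{ii})\Big)\,\det Q, \qquad \det Q = (-1)^{\#\{\text{nontrivial reflectors}\}},$$

since a reflector with $\tau_i = 0$ is the identity while a genuine reflection has determinant $-1$. The old code estimated the reflector count structurally, by counting rows of `a` that are already triangular, and that heuristic is what sometimes produced the wrong sign; the new code reads the count directly off the zero pattern of `taus`.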
@@ -61,7 +61,7 @@ XLA_TEST_F(LogDetTest, SimpleTriangle) {
       {4, 6, 8, 320},
   });
 
-  float expected = -15.9131355f;
+  float expected = 15.9131355f;
 
   xla::XlaOp a;
   auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
@@ -33,59 +33,115 @@ limitations under the License.
 
 namespace xla {
 
-namespace {
-
-}  // namespace
-
-// Block Householder QR Factorization. Algorithm 5.2.2 of Golub and van Loan.
-// def qr_blocked(a, block_size):
-//   m = a.shape[0]
-//   n = a.shape[1]
-//   q = np.eye(m)
-//   for i in xrange(0, min(m, n), block_size):
-//     k = min(block_size, min(m, n) - s)
-//     (a, taus) = qr(a[i:, i:i+k])
-//     y = np.eye(m, n) + np.tril(a, -1)
-//     t = CompactWYRepresentation(vs, taus, m-i, k)
-//     a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
-//     q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
-//   return (q, a)
-StatusOr<QRDecompositionResult> QRDecomposition(
-    XlaOp a, bool full_matrices, int64 block_size,
-    PrecisionConfig::Precision precision) {
-  XlaBuilder* builder = a.builder();
-  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
-  const int num_dims = a_shape.rank();
-  if (num_dims < 2) {
-    return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s",
-                           a_shape.ToString());
-  }
-  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
-  const int64 n = ShapeUtil::GetDimension(a_shape, -1);
-  const int64 p = std::min(m, n);
-
-  if (block_size < 1) {
-    return InvalidArgument("block_size argument to QR must be >= 1; got %d",
-                           block_size);
-  }
-
-  Shape q_shape = a_shape;
-  q_shape.mutable_dimensions().back() = m;
-
-  Shape qr_shape = ShapeUtil::MakeTupleShape({q_shape, a_shape});
-  auto qr = CustomCall(a.builder(), "QrDecomposition", {a}, qr_shape);
-  auto q = GetTupleElement(qr, 0);
-  auto r = GetTupleElement(qr, 1);
-
-  // full_matrices is false when only a partial result in needed. Slice to the
-  // needed dimensions here.
-  if (!full_matrices) {
-    q = SliceInMinorDims(q, {0, 0}, {m, p});
-    r = SliceInMinorDims(r, {0, 0}, {p, n});
-  }
-  return QRDecompositionResult{q, r};
-}
+QrDecomposition Qr(XlaOp a) {
+  auto result = [&]() -> StatusOr<QrDecomposition> {
+    XlaBuilder* builder = a.builder();
+    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
+    const int num_dims = a_shape.rank();
+    if (num_dims < 2) {
+      return InvalidArgument(
+          "Arguments to QR must have rank >= 2: got shape %s",
+          a_shape.ToString());
+    }
+    const int64 m = ShapeUtil::GetDimension(a_shape, -2);
+    const int64 n = ShapeUtil::GetDimension(a_shape, -1);
+
+    std::vector<int64> taus_dims(a_shape.dimensions().begin(),
+                                 a_shape.dimensions().end());
+    taus_dims.pop_back();
+    taus_dims.back() = std::min(m, n);
+    auto taus_shape = ShapeUtil::MakeShape(a_shape.element_type(), taus_dims);
+
+    Shape qr_shape = ShapeUtil::MakeTupleShape({a_shape, taus_shape});
+    auto qr = CustomCall(a.builder(), "Qr", {a}, qr_shape);
+    a = GetTupleElement(qr, 0);
+    auto taus = GetTupleElement(qr, 1);
+    return QrDecomposition{a, taus};
+  }();
+  if (!result.ok()) {
+    XlaOp error = a.builder()->ReportError(result.status());
+    return QrDecomposition{error, error};
+  }
+  return result.ValueOrDie();
+}
+
+XlaOp ProductOfElementaryHouseholderReflectors(XlaOp a, XlaOp taus) {
+  XlaBuilder* builder = a.builder();
+  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
+    TF_ASSIGN_OR_RETURN(Shape taus_shape, builder->GetShape(taus));
+    if (a_shape.rank() < 2) {
+      return InvalidArgument(
+          "Matrix `a` must have >= 2 dimensions: got shape %s",
+          a_shape.ToString());
+    }
+    if (taus_shape.rank() + 1 != a_shape.rank()) {
+      return InvalidArgument(
+          "Matrix `taus` must have one fewer dimension than `a`: got shapes "
+          "%s and %s",
+          taus_shape.ToString(), a_shape.ToString());
+    }
+    const int64 m = ShapeUtil::GetDimension(a_shape, -2);
+    const int64 n = ShapeUtil::GetDimension(a_shape, -1);
+    if (m < n) {
+      return InvalidArgument(
+          "Argument to product of elementary Householder "
+          "reflectors must have m >= n, got shape %s",
+          a_shape.ToString());
+    }
+    absl::Span<const int64> a_batch_dims =
+        absl::MakeConstSpan(a_shape.dimensions().begin(),
+                            a_shape.dimensions().begin() + a_shape.rank() - 2);
+    absl::Span<const int64> taus_batch_dims = absl::MakeConstSpan(
+        taus_shape.dimensions().begin(),
+        taus_shape.dimensions().begin() + taus_shape.rank() - 1);
+    const int64 k = ShapeUtil::GetDimension(taus_shape, -1);
+    if (a_shape.element_type() != taus_shape.element_type() ||
+        a_batch_dims != taus_batch_dims || k > n) {
+      return InvalidArgument("Invalid shape for `taus`, got a=%s and taus=%s",
+                             taus_shape.ToString(), a_shape.ToString());
+    }
+    return CustomCall(a.builder(), "ProductOfElementaryHouseholderReflectors",
+                      {a, taus}, a_shape);
+  });
+}
+
+void QrExplicit(XlaOp a, bool full_matrices, XlaOp& q, XlaOp& r) {
+  StatusOr<Shape> a_shape_or = a.builder()->GetShape(a);
+  if (!a_shape_or.ok()) {
+    q = a.builder()->ReportError(a_shape_or.status());
+    r = q;
+    return;
+  }
+  Shape a_shape = a_shape_or.ValueOrDie();
+  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
+  const int64 n = ShapeUtil::GetDimension(a_shape, -1);
+  const int64 p = std::min(m, n);
+
+  auto qr = Qr(a);
+  if (full_matrices) {
+    XlaOp t;
+    if (m < n) {
+      t = SliceInMinorDims(qr.q_and_r, {0, 0}, {m, m});
+    } else {
+      t = PadInDim(qr.q_and_r, Zero(a.builder(), a_shape.element_type()),
+                   a_shape.dimensions_size() - 1, /*pad_lo=*/0,
+                   /*pad_hi=*/m - n);
+    }
+    q = ProductOfElementaryHouseholderReflectors(t, qr.taus);
+    r = UpperTriangle(qr.q_and_r);
+  } else {
+    XlaOp t;
+    if (m < n) {
+      t = SliceInMinorDims(qr.q_and_r, {0, 0}, {m, m});
+    } else {
+      t = qr.q_and_r;
+    }
+    q = ProductOfElementaryHouseholderReflectors(t, qr.taus);
+    r = UpperTriangle(SliceInMinorDims(qr.q_and_r, {0, 0}, {p, n}));
+  }
+}
 
 }  // namespace xla
 
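A shape summary of the `QrExplicit` cases above (my reading of the code, with $p = \min(m, n)$). The product routine returns as many columns of $Q$ as its input matrix has, so:

* `full_matrices` with $m > n$: `q_and_r` is padded with zero columns from $[\ldots, m, n]$ to $[\ldots, m, m]$ (the extra columns carry no reflector data), giving $Q \in [\ldots, m, m]$ and $R = \operatorname{triu}(\texttt{q\_and\_r}) \in [\ldots, m, n]$.
* $m < n$ (either mode): only the leading $[\ldots, m, m]$ block holds reflectors, so it is sliced out before forming $Q$.
* Reduced mode: $Q \in [\ldots, m, p]$ and $R \in [\ldots, p, n]$, matching the shape checks added to `qr_test.cc` further down.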
@@ -25,17 +25,27 @@ namespace xla {
 // given a (batched) matrix a, computes an orthonormal matrix Q and an
 // upper-triangular matrix R such that a = QR.
 // `a` must be a (batched) matrix of size [..., m, n].
-// The algorithm implements a blocked QR decomposition; `block_size` is
-// the block size to use.
-// TODO(phawkins): handle the complex case.
-struct QRDecompositionResult {
-  XlaOp q;
-  XlaOp r;
+struct QrDecomposition {
+  // A matrix with the same shape as the input matrix `a`, whose upper triangle
+  // (inclusive of the diagonal) is the matrix R, and whose lower triangle
+  // (exclusive of the diagonal) contains the elementary Householder reflectors.
+  // This is the same output format as used by LAPACK's xGEQRF routine.
+  XlaOp q_and_r;
+  // A vector of shape [..., min(m, n)] containing the scalar factors of the
+  // elementary Householder reflectors.
+  XlaOp taus;
 };
 
-StatusOr<QRDecompositionResult> QRDecomposition(
-    XlaOp a, bool full_matrices, int64 block_size = 128,
-    PrecisionConfig::Precision precision = PrecisionConfig::HIGHEST);
+QrDecomposition Qr(XlaOp a);
+
+// Given `a` and `taus` as returned by `Qr`, compute the product of
+// the elementary Householder reflectors (i.e., the matrix Q of the QR
+// decomposition). The equivalent LAPACK routine is xORGQR/xUNGQR.
+XlaOp ProductOfElementaryHouseholderReflectors(XlaOp a, XlaOp taus);
+
+// Helper that combines `Qr` and `ProductOfElementaryHouseholderReflectors` to
+// compute explicit matrices `q` and `r`.
+void QrExplicit(XlaOp a, bool full_matrices, XlaOp& q, XlaOp& r);
 
 }  // namespace xla
 
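For concreteness, a worked instance of the packed layout the struct describes (standard xGEQRF convention; the entries are symbolic): for a $3 \times 3$ input,

$$\texttt{q\_and\_r} = \begin{pmatrix} r_{11} & r_{12} & r_{13} \\ v^{(1)}_2 & r_{22} & r_{23} \\ v^{(1)}_3 & v^{(2)}_3 & r_{33} \end{pmatrix}, \qquad \texttt{taus} = (\tau_1, \tau_2, \tau_3),$$

where reflector $i$ has $v^{(i)}_i = 1$ implicitly, zeros above, and the stored entries below, and $Q = H_1 H_2 H_3$ with $H_i = I - \tau_i\, v^{(i)} (v^{(i)})^T$.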
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/array2d.h"
 #include "tensorflow/compiler/xla/array3d.h"
+#include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/lib/matrix.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"
@@ -36,31 +36,44 @@ using QrTest = xla::ClientLibraryTestBase;
 XLA_TEST_F(QrTest, Simple) {
   // Test fails with TensorFloat-32 enabled
   tensorflow::enable_tensor_float_32_execution(false);
-  xla::XlaBuilder builder(TestName());
 
-  xla::Array2D<float> a_vals({
+  xla::Array2D<float> data({
       {4, 6, 8, 10},
       {6, 45, 54, 63},
       {8, 54, 146, 166},
       {10, 63, 166, 310},
   });
 
-  xla::XlaOp a;
-  auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto result,
-      xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2));
+  for (bool full_matrices : {false, true}) {
+    for (xla::int64 m : {3, 4}) {
+      for (xla::int64 n : {3, 4}) {
+        xla::XlaBuilder builder(TestName());
+        xla::XlaOp a, q, r;
+        xla::Array<float> a_vals = data.Slice({0, 0}, {m, n});
+        auto a_data = CreateParameter<float>(a_vals, 0, "a", &builder, &a);
+        xla::QrExplicit(a, full_matrices, q, r);
 
-  // Verifies that the decomposition composes back to the original matrix.
-  //
-  // This isn't a terribly demanding test, (e.g., we should verify that Q is
-  // orthonormal and R is upper-triangular) but it's awkward to write such tests
-  // without more linear algebra libraries. It's easier to test the numerics
-  // from Python, anyway, where we have access to numpy and scipy.
-  xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST);
-
-  ComputeAndCompareR2<float>(&builder, a_vals, {a_data.get()},
-                             xla::ErrorSpec(1e-4, 1e-4));
+        // Verifies that the decomposition composes back to the original matrix.
+        //
+        // This isn't a terribly demanding test, (e.g., we should verify that Q
+        // is orthonormal and R is upper-triangular) but it's awkward to write
+        // such tests without more linear algebra libraries. It's easier to test
+        // the numerics from Python, anyway, where we have access to numpy and
+        // scipy.
+        xla::BatchDot(q, r, xla::PrecisionConfig::HIGHEST);
+        TF_ASSERT_OK_AND_ASSIGN(xla::Shape q_shape, builder.GetShape(q));
+        TF_ASSERT_OK_AND_ASSIGN(xla::Shape r_shape, builder.GetShape(r));
+        EXPECT_EQ(q_shape,
+                  xla::ShapeUtil::MakeShape(
+                      xla::F32, {m, full_matrices ? m : std::min(m, n)}));
+        EXPECT_EQ(r_shape,
+                  xla::ShapeUtil::MakeShape(
+                      xla::F32, {full_matrices ? m : std::min(m, n), n}));
+        ComputeAndCompare<float>(&builder, a_vals, {a_data.get()},
+                                 xla::ErrorSpec(1e-4, 1e-4));
+      }
+    }
+  }
 }
 
@@ -74,11 +88,9 @@ XLA_TEST_F(QrTest, ZeroDiagonal) {
       {1, 1, 0},
   });
 
-  xla::XlaOp a;
+  xla::XlaOp a, q, r;
   auto a_data = CreateR2Parameter<float>(a_vals, 0, "a", &builder, &a);
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto result,
-      xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/8));
+  xla::QrExplicit(a, /*full_matrices=*/true, q, r);
 
   // Verifies that the decomposition composes back to the original matrix.
   //
@@ -86,7 +98,7 @@ XLA_TEST_F(QrTest, ZeroDiagonal) {
   // orthonormal and R is upper-triangular) but it's awkward to write such tests
   // without more linear algebra libraries. It's easier to test the numerics
   // from Python, anyway, where we have access to numpy and scipy.
-  xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST);
+  xla::BatchDot(q, r, xla::PrecisionConfig::HIGHEST);
 
   ComputeAndCompareR2<float>(&builder, a_vals, {a_data.get()},
                              xla::ErrorSpec(1e-4, 1e-4));
@@ -112,13 +124,11 @@ XLA_TEST_F(QrTest, SimpleBatched) {
       },
   });
 
-  xla::XlaOp a;
+  xla::XlaOp a, q, r;
   auto a_data = CreateR3Parameter<float>(a_vals, 0, "a", &builder, &a);
-  TF_ASSERT_OK_AND_ASSIGN(
-      auto result,
-      xla::QRDecomposition(a, /*full_matrices=*/true, /*block_size=*/2));
+  xla::QrExplicit(a, /*full_matrices=*/true, q, r);
 
-  xla::BatchDot(result.q, result.r, xla::PrecisionConfig::HIGHEST);
+  xla::BatchDot(q, r, xla::PrecisionConfig::HIGHEST);
 
   ComputeAndCompareR3<float>(&builder, a_vals, {a_data.get()},
                              xla::ErrorSpec(1e-4, 1e-4));
 
@@ -192,8 +192,9 @@ void BuildOpsSubmodule(py::module* m) {
   ops.def(
       "QR",
       [](XlaOp a, bool full_matrices) -> StatusOr<std::pair<XlaOp, XlaOp>> {
-        TF_ASSIGN_OR_RETURN(auto qr, QRDecomposition(a, full_matrices));
-        return std::make_pair(qr.q, qr.r);
+        XlaOp q, r;
+        QrExplicit(a, full_matrices, q, r);
+        return std::make_pair(q, r);
       },
       py::arg("operand"), py::arg("full_matrices"));
   ops.def(
@@ -1964,6 +1964,7 @@ cc_library(
         "//tensorflow/compiler/xla/client/lib:loops",
         "//tensorflow/compiler/xla/client/lib:math",
         "//tensorflow/compiler/xla/client/lib:matrix",
+        "//tensorflow/compiler/xla/client/lib:qr",
         "//tensorflow/compiler/xla/client/lib:slicing",
        "//tensorflow/core:lib",
        "@com_google_absl//absl/container:flat_hash_map",
@@ -23,6 +23,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/client/lib/loops.h"
 #include "tensorflow/compiler/xla/client/lib/math.h"
 #include "tensorflow/compiler/xla/client/lib/matrix.h"
+#include "tensorflow/compiler/xla/client/lib/qr.h"
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/literal.h"
@@ -172,7 +173,7 @@ Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
 //     a[j+1:, j] = v[j+1:]
 //     taus[j] = tau
 //   return (a, taus)
-StatusOr<QrExpander::QrResult> QrExpander::QrBlock(
+StatusOr<QrDecomposition> QrExpander::QrBlock(
     XlaOp a, PrecisionConfig::Precision precision) {
   XlaBuilder* builder = a.builder();
   TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
@@ -269,8 +270,8 @@ StatusOr<QrExpander::QrResult> QrExpander::QrBlock(
   TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn,
                                                 {a, taus}, "qr", builder));
 
-  QrResult result;
-  result.a = values[0];
+  QrDecomposition result;
+  result.q_and_r = values[0];
   result.taus = values[1];
   return result;
 }
@@ -372,18 +373,21 @@ StatusOr<XlaOp> QrExpander::BuildQrDecomposition(
     batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
   }
 
-  auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
+  std::vector<int64> taus_dims = batch_dims;
+  taus_dims.push_back(p);
+  auto taus = Zeros(builder, ShapeUtil::MakeShape(type, taus_dims));
   for (int64 i = 0; i < p; i += block_size) {
     int64 k = std::min(block_size, p - i);
 
     auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k});
     TF_ASSIGN_OR_RETURN(auto qr_block, QrBlock(a_block, precision));
-    auto y = Add(
-        IdentityMatrix(builder, type, m - i, k),
-        Select(TriangleMask(qr_block.a, -1), qr_block.a, ZerosLike(qr_block.a)),
-        /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1});
+    auto y = Add(IdentityMatrix(builder, type, m - i, k),
+                 Select(TriangleMask(qr_block.q_and_r, -1), qr_block.q_and_r,
+                        ZerosLike(qr_block.q_and_r)),
+                 /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1});
 
-    a = UpdateSliceInMinorDims(a, qr_block.a, {i, i});
+    a = UpdateSliceInMinorDims(a, qr_block.q_and_r, {i, i});
+    taus = UpdateSliceInMinorDims(taus, qr_block.taus, {i});
 
     // Compute the I + Y @ T @ Y^t block representation of a product of
     // Householder matrices.
@@ -401,8 +405,64 @@ StatusOr<XlaOp> QrExpander::BuildQrDecomposition(
     a_update = BatchDot(yt, a_update, precision);
     a_panel = a_panel + a_update;
     a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});
   }
 
+  return Tuple(builder, {a, taus});
+}
+
+StatusOr<XlaOp> QrExpander::ProductOfElementaryHouseholderReflectors(
+    XlaOp a, XlaOp taus, int64 block_size,
+    PrecisionConfig::Precision precision) {
+  XlaBuilder* builder = a.builder();
+  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
+  TF_ASSIGN_OR_RETURN(Shape taus_shape, builder->GetShape(taus));
+  const int num_dims = a_shape.rank();
+  if (num_dims < 2) {
+    return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s",
+                           a_shape.ToString());
+  }
+  PrimitiveType type = a_shape.element_type();
+
+  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
+  int64 n = ShapeUtil::GetDimension(a_shape, -1);
+  const int64 p = ShapeUtil::GetDimension(taus_shape, -1);
+  if (m < n) {
+    return InvalidArgument(
+        "Argument to product of elementary Householder "
+        "reflectors must have m >= n, got shape %s",
+        a_shape.ToString());
+  }
+
+  if (block_size < 1) {
+    return InvalidArgument("block_size argument to QR must be >= 1; got %d",
+                           block_size);
+  }
+
+  const int64 num_batch_dims = num_dims - 2;
+  std::vector<int64> batch_dims(num_batch_dims);
+  for (int i = 0; i < num_batch_dims; ++i) {
+    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
+  }
+
+  auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
+  for (int64 i = 0; i < p; i += block_size) {
+    int64 k = std::min(block_size, p - i);
+
+    auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k});
+    auto y = Add(IdentityMatrix(builder, type, m - i, k),
+                 Select(TriangleMask(a_block, -1), a_block, ZerosLike(a_block)),
+                 /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1});
+
+    // Compute the I + Y @ T @ Y^t block representation of a product of
+    // Householder matrices.
+    auto taus_block = SliceInMinorDims(taus, {i}, {i + k});
+
+    TF_ASSIGN_OR_RETURN(
+        auto t, CompactWYRepresentation(type, batch_dims, y, taus_block, m - i,
+                                        k, precision));
+    // q[:, i:] += (q[:, i:] @ y) @ np.conj((y @ np.conj(t.T)).T)
+    auto yt = BatchDot(y, /*transpose_x=*/false, MaybeConjugate(t, true),
+                       /*transpose_y=*/true, precision);
+    auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
+    auto q_update = BatchDot(q_panel, y, precision);
+    q_update =
@@ -411,19 +471,26 @@ StatusOr<XlaOp> QrExpander::BuildQrDecomposition(
     q_panel = q_panel + q_update;
     q = UpdateSliceInMinorDims(q, q_panel, {0, i});
   }
 
-  return Tuple(builder, {q, UpperTriangle(a)});
+  q = SliceInMinorDims(q, {0, 0}, {m, n});
+  return q;
 }
 
+static const char* kQrCustomCallName = "Qr";
+static const char* kHouseholderProductCustomCallName =
+    "ProductOfElementaryHouseholderReflectors";
+
 bool QrExpander::InstructionMatchesPattern(HloInstruction* instruction) {
   return instruction->opcode() == HloOpcode::kCustomCall &&
-         instruction->custom_call_target() == "QrDecomposition";
+         (instruction->custom_call_target() == kQrCustomCallName ||
+          instruction->custom_call_target() ==
+              kHouseholderProductCustomCallName);
 }
 
 StatusOr<HloInstruction*> QrExpander::ExpandInstruction(
     HloInstruction* instruction) {
   const string name =
-      absl::StrFormat("xla.qr_%s", instruction->operand(0)->shape().ToString());
+      absl::StrFormat("xla.%s_%s", instruction->custom_call_target(),
+                      instruction->operand(0)->shape().ToString());
 
   HloModule* module = instruction->parent()->parent();
 
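Background on the block update used on both sides of this split (the trailing-panel update in `BuildQrDecomposition` and the Q accumulation in the new routine): it is the standard compact WY representation of a product of Householder reflectors (Schreiber and Van Loan), written here with the code comment's sign convention `I + Y @ T @ Y^t`, i.e. with any minus signs absorbed into $T$:

$$H_1 H_2 \cdots H_k = I + Y T Y^T,$$

where $Y$ stacks the reflector vectors as a unit lower-trapezoidal matrix and $T$ is a small $k \times k$ triangular factor. The payoff is that applying $k$ reflectors costs a few matrix-matrix products per block instead of $k$ rank-1 updates.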
@@ -441,13 +508,25 @@ StatusOr<HloInstruction*> QrExpander::ExpandInstruction(
   // into our HloModule. Ideally we would avoid the protocol buffer step;
   // that is left as an exercise for future work.
   XlaBuilder builder(name);
   TF_RET_CHECK(instruction->operand_count() >= 1);
   XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
-  TF_ASSIGN_OR_RETURN(
-      XlaOp l, BuildQrDecomposition(a,
-                                    /*block_size=*/128,
-                                    /*precision=*/PrecisionConfig::HIGHEST));
+  XlaOp result;
+  if (instruction->custom_call_target() == kQrCustomCallName) {
+    TF_RET_CHECK(instruction->operand_count() == 1);
+    TF_ASSIGN_OR_RETURN(
+        result, BuildQrDecomposition(a,
+                                     /*block_size=*/128,
+                                     /*precision=*/PrecisionConfig::HIGHEST));
+  } else {
+    TF_RET_CHECK(instruction->operand_count() == 2);
+    XlaOp taus =
+        Parameter(&builder, 1, instruction->operand(1)->shape(), "taus");
+    TF_ASSIGN_OR_RETURN(result, ProductOfElementaryHouseholderReflectors(
+                                    a, taus, /*block_size=*/128,
+                                    /*precision=*/PrecisionConfig::HIGHEST));
+  }
 
-  TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(l));
+  TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(result));
 
   TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                       xla_computation.GetProgramShape());
 
@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_XLA_SERVICE_QR_EXPANDER_H_
 
 #include "absl/container/flat_hash_map.h"
+#include "tensorflow/compiler/xla/client/lib/qr.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/service/op_expander_pass.h"
 
@@ -32,17 +33,8 @@ class QrExpander : public OpExpanderPass {
   StatusOr<HloInstruction*> ExpandInstruction(
       HloInstruction* instruction) override;
 
-  struct QrResult {
-    // The upper-triangular matrix R, packed together with the lower-triangular
-    // elementary Householder reflectors `vs` below the diagonal.
-    XlaOp a;
-
-    // Representation of the Householder matrices I - beta v v.T
-    XlaOp taus;  // Shape: [..., min(m, n)]
-  };
-
-  virtual StatusOr<QrResult> QrBlock(XlaOp a,
-                                     PrecisionConfig::Precision precision);
+  virtual StatusOr<QrDecomposition> QrBlock(
+      XlaOp a, PrecisionConfig::Precision precision);
 
   virtual StatusOr<XlaOp> CompactWYRepresentation(
       PrimitiveType type, absl::Span<const int64> batch_dims, XlaOp vs,
@@ -52,6 +44,10 @@ class QrExpander : public OpExpanderPass {
   StatusOr<XlaOp> BuildQrDecomposition(XlaOp a, int64 block_size,
                                        PrecisionConfig::Precision precision);
 
+  StatusOr<XlaOp> ProductOfElementaryHouseholderReflectors(
+      XlaOp a, XlaOp taus, int64 block_size,
+      PrecisionConfig::Precision precision);
+
   // Mapping from op signatures to existing computations.
   absl::flat_hash_map<string, HloComputation*> computation_cache_;
 };