[XLA] Use the compact WY representation in the implementation of blocked QR decompositions.

Compact WY representations are described in:
Schreiber, Robert, and Charles Van Loan. "A storage-efficient WY representation for products of Householder transformations." SIAM Journal on Scientific and Statistical Computing 10.1 (1989): 53-57.

The compact WY representation is more storage efficient, requiring calculation of an nxn triangular matrix, where n is the block size (e.g., 128), instead of an mxn matrix where m is the number of matrix rows.

PiperOrigin-RevId: 330711085
Change-Id: Ideac239ff118ee6ac2fd1397b731a40e11d6ecd7
This commit is contained in:
Peter Hawkins 2020-09-09 06:35:02 -07:00 committed by TensorFlower Gardener
parent 1f50fde327
commit dc5718967e

View File

@ -258,33 +258,36 @@ StatusOr<QRBlockResult> QRBlock(XlaOp a, PrecisionConfig::Precision precision) {
return result;
}
// Computes W and Y such that I-WY is equivalent to the sequence of Householder
// transformations given by vs and taus.
// Golub and van Loan, "Matrix Computations", algorithm 5.1.2.
// Computes T such that (I - Y @ T @ Y^t) is a product of the elementary
// Householder reflectors given by `vs` and `taus`.
//
// Schreiber, Robert, and Charles Van Loan. "A storage-efficient WY
// representation for products of Householder transformations." SIAM Journal on
// Scientific and Statistical Computing 10.1 (1989): 53-57.
//
// m, n = vs.shape[-2:]
// t = np.zeros((n, n))
// Y = np.zeros([m, n])
// W = np.zeros([m, n])
// t[0, 0] = -taus[0]
// Y[:, 0] = vs[:, 0]
// W[:, 0] = -taus[0] * vs[:, 0]
// for j in xrange(1, n):
// v = vs[:, j]
// z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v))
// W[:, j] = z
// Y[:, j] = v
// return W
// There is no need to return Y since at termination of the loop it is equal to
// vs.
StatusOr<XlaOp> ComputeWYRepresentation(PrimitiveType type,
// for i in range(1, n):
// z = -taus[i] * np.dot(t, np.dot(Y.T, vs[:, i]))
// Y[:, i] = vs[:, i]
// t[:i, i] = z[:i]
// t[i, i] = -taus[i]
StatusOr<XlaOp> CompactWYRepresentation(PrimitiveType type,
absl::Span<const int64> batch_dims,
XlaOp vs, XlaOp taus, int64 m, int64 n,
PrecisionConfig::Precision precision) {
std::vector<int64> batch_dim_indices(batch_dims.size());
std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
int64 m_index = batch_dims.size();
int64 n_index = batch_dims.size() + 1;
auto body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
// w has shape [..., m, n]
auto w = values[0];
auto t = values[0];
const auto vs = values[1];
const auto taus = values[2];
@ -303,31 +306,37 @@ StatusOr<XlaOp> ComputeWYRepresentation(PrimitiveType type,
auto y = Select(Ge(iota_mn, j), ZerosLike(vs), vs);
// yv has shape [..., n, 1]
auto yv = BatchDot(y, true, v, false, precision);
// wyv has shape [..., m, 1]
auto wyv = BatchDot(w, yv, precision);
auto yv =
BatchDot(y, /*transpose_x=*/true, v, /*transpose_y=*/false, precision);
// wyv has shape [..., n, 1]
auto wyv = BatchDot(t, yv, precision);
auto z = Mul(
-beta, v + wyv,
-beta, wyv,
/*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
beta = BroadcastInDim(beta, ConcatVectors(batch_dims, {n, 1}),
ConcatVectors(batch_dim_indices, {n_index}));
auto iota_n = Iota(
builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {n, 1})),
m_index);
w = DynamicUpdateSliceInMinorDims(w, z, {j});
z = Select(Lt(iota_n, j), z, Select(Eq(iota_n, j), -beta, ZerosLike(beta)));
return std::vector<XlaOp>{w, vs, taus};
t = DynamicUpdateSliceInMinorDims(t, z, {j});
return std::vector<XlaOp>{t, vs, taus};
};
XlaBuilder* builder = vs.builder();
auto w = Zeros(builder,
ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n})));
auto v = SliceInMinorDims(vs, {0}, {1});
auto t = Zeros(builder,
ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n, n})));
auto beta = SliceInMinorDims(taus, {0}, {1});
auto bv =
Mul(-beta, v,
/*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
w = UpdateSliceInMinorDims(w, bv, {0});
beta = BroadcastInDim(beta, ConcatVectors(batch_dims, {1, 1}),
ConcatVectors(batch_dim_indices, {n_index}));
t = UpdateSliceInMinorDims(t, -beta, {0});
TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(n - 1, S32, body_fn,
{w, vs, taus}, "wy", builder));
{t, vs, taus}, "wy", builder));
return values[0];
}
@ -342,12 +351,10 @@ StatusOr<XlaOp> ComputeWYRepresentation(PrimitiveType type,
// k = min(block_size, min(m, n) - s)
// (a, vs, taus) = qr(a[i:, i:i+k])
// y = vs
// w = ComputeWYRepresentation(vs, taus, m-i, k)
// a[i:, i+r:] += np.dot(y, np.dot(w.T, a[i:, i+k:]))
// q[:, i:] += np.dot(q[:, i:], np.dot(w, y.T))
// t = CompactWYRepresentation(vs, taus, m-i, k)
// a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
// q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
// return (q, a)
// TODO(phawkins): consider using UT transformations (in the form I - V U V')
// rather than WY transformations.
StatusOr<QRDecompositionResult> QRDecomposition(
XlaOp a, bool full_matrices, int64 block_size,
PrecisionConfig::Precision precision) {
@ -384,24 +391,28 @@ StatusOr<QRDecompositionResult> QRDecomposition(
a = UpdateSliceInMinorDims(a, qr_block.r, {i, i});
// Compute the I-WY block representation of a product of Householder
// matrices.
// Compute the I + Y @ T @ Y^t block representation of a product of
// Householder matrices.
TF_ASSIGN_OR_RETURN(
auto w, ComputeWYRepresentation(type, batch_dims, qr_block.vs,
auto t, CompactWYRepresentation(type, batch_dims, qr_block.vs,
qr_block.taus, m - i, k, precision));
auto y = qr_block.vs;
// a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:]))
// a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
auto yt =
BatchDot(y, /*transpose_x=*/false, t, /*transpose_y=*/true, precision);
auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
auto a_update = BatchDot(w, true, a_panel, false, precision);
a_update = BatchDot(y, a_update, precision);
auto a_update = BatchDot(y, /*transpose_x=*/true, a_panel,
/*transpose_y=*/false, precision);
a_update = BatchDot(yt, a_update, precision);
a_panel = a_panel + a_update;
a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});
// q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T))
// q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
auto q_update = BatchDot(q_panel, w, precision);
q_update = BatchDot(q_update, false, y, true, precision);
auto q_update = BatchDot(q_panel, y, precision);
q_update = BatchDot(q_update, /*transpose_x=*/false, yt,
/*transpose_y=*/true, precision);
q_panel = q_panel + q_update;
q = UpdateSliceInMinorDims(q, q_panel, {0, i});
}