[XLA] Use the compact WY representation in the implementation of blocked QR decompositions.
Compact WY representations are described in: Schreiber, Robert, and Charles Van Loan. "A storage-efficient WY representation for products of Householder transformations." SIAM Journal on Scientific and Statistical Computing 10.1 (1989): 53-57. The compact WY representation is more storage efficient, requiring calculation of an nxn triangular matrix, where n is the block size (e.g., 128), instead of an mxn matrix where m is the number of matrix rows. PiperOrigin-RevId: 330711085 Change-Id: Ideac239ff118ee6ac2fd1397b731a40e11d6ecd7
This commit is contained in:
parent
1f50fde327
commit
dc5718967e
@ -258,33 +258,36 @@ StatusOr<QRBlockResult> QRBlock(XlaOp a, PrecisionConfig::Precision precision) {
|
||||
return result;
|
||||
}
|
||||
|
||||
// Computes W and Y such that I-WY is equivalent to the sequence of Householder
|
||||
// transformations given by vs and taus.
|
||||
// Golub and van Loan, "Matrix Computations", algorithm 5.1.2.
|
||||
// Computes T such that (I + Y @ T @ Y^t) is a product of the elementary
|
||||
// Householder reflectors given by `vs` and `taus`.
|
||||
//
|
||||
// Schreiber, Robert, and Charles Van Loan. "A storage-efficient WY
|
||||
// representation for products of Householder transformations." SIAM Journal on
|
||||
// Scientific and Statistical Computing 10.1 (1989): 53-57.
|
||||
//
|
||||
// m, n = vs.shape[-2:]
|
||||
// t = np.zeros((n, n))
|
||||
// Y = np.zeros([m, n])
|
||||
// W = np.zeros([m, n])
|
||||
// t[0, 0] = -taus[0]
|
||||
// Y[:, 0] = vs[:, 0]
|
||||
// W[:, 0] = -taus[0] * vs[:, 0]
|
||||
// for j in xrange(1, n):
|
||||
// v = vs[:, j]
|
||||
// z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v))
|
||||
// W[:, j] = z
|
||||
// Y[:, j] = v
|
||||
// return W
|
||||
// There is no need to return Y since at termination of the loop it is equal to
|
||||
// vs.
|
||||
StatusOr<XlaOp> ComputeWYRepresentation(PrimitiveType type,
|
||||
// for i in range(1, n):
|
||||
// z = -taus[i] * np.dot(t, np.dot(Y.T, vs[:, i]))
|
||||
// Y[:, i] = vs[:, i]
|
||||
// t[:i, i] = z[:i]
|
||||
// t[i, i] = -taus[i]
|
||||
StatusOr<XlaOp> CompactWYRepresentation(PrimitiveType type,
|
||||
absl::Span<const int64> batch_dims,
|
||||
XlaOp vs, XlaOp taus, int64 m, int64 n,
|
||||
PrecisionConfig::Precision precision) {
|
||||
std::vector<int64> batch_dim_indices(batch_dims.size());
|
||||
std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
|
||||
int64 m_index = batch_dims.size();
|
||||
int64 n_index = batch_dims.size() + 1;
|
||||
|
||||
auto body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
|
||||
XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
|
||||
// t has shape [..., n, n] (w, in the replaced WY code, had shape [..., m, n])
|
||||
auto w = values[0];
|
||||
auto t = values[0];
|
||||
const auto vs = values[1];
|
||||
const auto taus = values[2];
|
||||
|
||||
@ -303,31 +306,37 @@ StatusOr<XlaOp> ComputeWYRepresentation(PrimitiveType type,
|
||||
auto y = Select(Ge(iota_mn, j), ZerosLike(vs), vs);
|
||||
|
||||
// yv has shape [..., n, 1]
|
||||
auto yv = BatchDot(y, true, v, false, precision);
|
||||
// wyv has shape [..., m, 1]
|
||||
auto wyv = BatchDot(w, yv, precision);
|
||||
auto yv =
|
||||
BatchDot(y, /*transpose_x=*/true, v, /*transpose_y=*/false, precision);
|
||||
// wyv has shape [..., n, 1]
|
||||
auto wyv = BatchDot(t, yv, precision);
|
||||
|
||||
auto z = Mul(
|
||||
-beta, v + wyv,
|
||||
-beta, wyv,
|
||||
/*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
|
||||
beta = BroadcastInDim(beta, ConcatVectors(batch_dims, {n, 1}),
|
||||
ConcatVectors(batch_dim_indices, {n_index}));
|
||||
auto iota_n = Iota(
|
||||
builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {n, 1})),
|
||||
m_index);
|
||||
|
||||
w = DynamicUpdateSliceInMinorDims(w, z, {j});
|
||||
z = Select(Lt(iota_n, j), z, Select(Eq(iota_n, j), -beta, ZerosLike(beta)));
|
||||
|
||||
return std::vector<XlaOp>{w, vs, taus};
|
||||
t = DynamicUpdateSliceInMinorDims(t, z, {j});
|
||||
|
||||
return std::vector<XlaOp>{t, vs, taus};
|
||||
};
|
||||
|
||||
XlaBuilder* builder = vs.builder();
|
||||
auto w = Zeros(builder,
|
||||
ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n})));
|
||||
auto v = SliceInMinorDims(vs, {0}, {1});
|
||||
auto t = Zeros(builder,
|
||||
ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n, n})));
|
||||
auto beta = SliceInMinorDims(taus, {0}, {1});
|
||||
auto bv =
|
||||
Mul(-beta, v,
|
||||
/*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
|
||||
w = UpdateSliceInMinorDims(w, bv, {0});
|
||||
beta = BroadcastInDim(beta, ConcatVectors(batch_dims, {1, 1}),
|
||||
ConcatVectors(batch_dim_indices, {n_index}));
|
||||
t = UpdateSliceInMinorDims(t, -beta, {0});
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(n - 1, S32, body_fn,
|
||||
{w, vs, taus}, "wy", builder));
|
||||
{t, vs, taus}, "wy", builder));
|
||||
return values[0];
|
||||
}
|
||||
|
||||
@ -342,12 +351,10 @@ StatusOr<XlaOp> ComputeWYRepresentation(PrimitiveType type,
|
||||
// k = min(block_size, min(m, n) - i)
|
||||
// (a, vs, taus) = qr(a[i:, i:i+k])
|
||||
// y = vs
|
||||
// w = ComputeWYRepresentation(vs, taus, m-i, k)
|
||||
// a[i:, i+r:] += np.dot(y, np.dot(w.T, a[i:, i+k:]))
|
||||
// q[:, i:] += np.dot(q[:, i:], np.dot(w, y.T))
|
||||
// t = CompactWYRepresentation(vs, taus, m-i, k)
|
||||
// a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
|
||||
// q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
|
||||
// return (q, a)
|
||||
// TODO(phawkins): consider using UT transformations (in the form I - V U V')
|
||||
// rather than WY transformations.
|
||||
StatusOr<QRDecompositionResult> QRDecomposition(
|
||||
XlaOp a, bool full_matrices, int64 block_size,
|
||||
PrecisionConfig::Precision precision) {
|
||||
@ -384,24 +391,28 @@ StatusOr<QRDecompositionResult> QRDecomposition(
|
||||
|
||||
a = UpdateSliceInMinorDims(a, qr_block.r, {i, i});
|
||||
|
||||
// Compute the I-WY block representation of a product of Householder
|
||||
// matrices.
|
||||
// Compute the I + Y @ T @ Y^t block representation of a product of
|
||||
// Householder matrices.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto w, ComputeWYRepresentation(type, batch_dims, qr_block.vs,
|
||||
auto t, CompactWYRepresentation(type, batch_dims, qr_block.vs,
|
||||
qr_block.taus, m - i, k, precision));
|
||||
auto y = qr_block.vs;
|
||||
|
||||
// a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:]))
|
||||
// a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
|
||||
auto yt =
|
||||
BatchDot(y, /*transpose_x=*/false, t, /*transpose_y=*/true, precision);
|
||||
auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
|
||||
auto a_update = BatchDot(w, true, a_panel, false, precision);
|
||||
a_update = BatchDot(y, a_update, precision);
|
||||
auto a_update = BatchDot(y, /*transpose_x=*/true, a_panel,
|
||||
/*transpose_y=*/false, precision);
|
||||
a_update = BatchDot(yt, a_update, precision);
|
||||
a_panel = a_panel + a_update;
|
||||
a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});
|
||||
|
||||
// q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T))
|
||||
// q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
|
||||
auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
|
||||
auto q_update = BatchDot(q_panel, w, precision);
|
||||
q_update = BatchDot(q_update, false, y, true, precision);
|
||||
auto q_update = BatchDot(q_panel, y, precision);
|
||||
q_update = BatchDot(q_update, /*transpose_x=*/false, yt,
|
||||
/*transpose_y=*/true, precision);
|
||||
q_panel = q_panel + q_update;
|
||||
q = UpdateSliceInMinorDims(q, q_panel, {0, i});
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user