diff --git a/tensorflow/compiler/xla/client/lib/qr.cc b/tensorflow/compiler/xla/client/lib/qr.cc
index b2eecbac309..8ef1f1fb351 100644
--- a/tensorflow/compiler/xla/client/lib/qr.cc
+++ b/tensorflow/compiler/xla/client/lib/qr.cc
@@ -258,33 +258,36 @@ StatusOr<QRBlockResult> QRBlock(XlaOp a, PrecisionConfig::Precision precision) {
   return result;
 }
 
-// Computes W and Y such that I-WY is equivalent to the sequence of Householder
-// transformations given by vs and taus.
-// Golub and van Loan, "Matrix Computations", algorithm 5.1.2.
+// Computes T such that (I - Y @ T @ Y^t) is a product of the elementary
+// Householder reflectors given by `vs` and `taus`.
+//
+// Schreiber, Robert, and Charles Van Loan. "A storage-efficient WY
+// representation for products of Householder transformations." SIAM Journal on
+// Scientific and Statistical Computing 10.1 (1989): 53-57.
+//
+// m, n = vs.shape[-2:]
+// t = np.zeros((n, n))
 // Y = np.zeros([m, n])
-// W = np.zeros([m, n])
+// t[0, 0] = -taus[0]
 // Y[:, 0] = vs[:, 0]
-// W[:, 0] = -taus[0] * vs[:, 0]
-// for j in xrange(1, n):
-//   v = vs[:, j]
-//   z = -taus[j] * v - taus[j] * np.dot(W, np.dot(Y.T, v))
-//   W[:, j] = z
-//   Y[:, j] = v
-// return W
-// There is no need to return Y since at termination of the loop it is equal to
-// vs.
-StatusOr<XlaOp> ComputeWYRepresentation(PrimitiveType type,
+// for i in range(1, n):
+//   z = -taus[i] * np.dot(t, np.dot(Y.T, vs[:, i]))
+//   Y[:, i] = vs[:, i]
+//   t[:i, i] = z[:i]
+//   t[i, i] = -taus[i]
+StatusOr<XlaOp> CompactWYRepresentation(PrimitiveType type,
                                         absl::Span<const int64> batch_dims,
                                         XlaOp vs, XlaOp taus, int64 m, int64 n,
                                         PrecisionConfig::Precision precision) {
   std::vector<int64> batch_dim_indices(batch_dims.size());
   std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
+  int64 m_index = batch_dims.size();
   int64 n_index = batch_dims.size() + 1;
 
   auto body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
                      XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
     // w has shape [..., m, n]
-    auto w = values[0];
+    auto t = values[0];
     const auto vs = values[1];
     const auto taus = values[2];
 
@@ -303,31 +306,37 @@ StatusOr<XlaOp> ComputeWYRepresentation(PrimitiveType type,
 
     auto y = Select(Ge(iota_mn, j), ZerosLike(vs), vs);
     // yv has shape [..., n, 1]
-    auto yv = BatchDot(y, true, v, false, precision);
-    // wyv has shape [..., m, 1]
-    auto wyv = BatchDot(w, yv, precision);
+    auto yv =
+        BatchDot(y, /*transpose_x=*/true, v, /*transpose_y=*/false, precision);
+    // wyv has shape [..., n, 1]
+    auto wyv = BatchDot(t, yv, precision);
 
     auto z = Mul(
-        -beta, v + wyv,
+        -beta, wyv,
         /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
+    beta = BroadcastInDim(beta, ConcatVectors(batch_dims, {n, 1}),
+                          ConcatVectors(batch_dim_indices, {n_index}));
+    auto iota_n = Iota(
+        builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {n, 1})),
+        m_index);
 
-    w = DynamicUpdateSliceInMinorDims(w, z, {j});
+    z = Select(Lt(iota_n, j), z, Select(Eq(iota_n, j), -beta, ZerosLike(beta)));
 
-    return std::vector<XlaOp>{w, vs, taus};
+    t = DynamicUpdateSliceInMinorDims(t, z, {j});
+
+    return std::vector<XlaOp>{t, vs, taus};
   };
 
   XlaBuilder* builder = vs.builder();
-  auto w = Zeros(builder,
-                 ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {m, n})));
-  auto v = SliceInMinorDims(vs, {0}, {1});
+  auto t = Zeros(builder,
+                 ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n, n})));
   auto beta = SliceInMinorDims(taus, {0}, {1});
-  auto bv =
-      Mul(-beta, v,
-          /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {n_index}));
-  w = UpdateSliceInMinorDims(w, bv, {0});
+  beta = BroadcastInDim(beta, ConcatVectors(batch_dims, {1, 1}),
+                        ConcatVectors(batch_dim_indices, {n_index}));
+  t = UpdateSliceInMinorDims(t, -beta, {0});
 
   TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(n - 1, S32, body_fn,
-                                                {w, vs, taus}, "wy", builder));
+                                                {t, vs, taus}, "wy", builder));
   return values[0];
 }
 
@@ -342,12 +351,10 @@ StatusOr<XlaOp> ComputeWYRepresentation(PrimitiveType type,
 //     k = min(block_size, min(m, n) - s)
 //     (a, vs, taus) = qr(a[i:, i:i+k])
 //     y = vs
-//     w = ComputeWYRepresentation(vs, taus, m-i, k)
-//     a[i:, i+r:] += np.dot(y, np.dot(w.T, a[i:, i+k:]))
-//     q[:, i:] += np.dot(q[:, i:], np.dot(w, y.T))
+//     t = CompactWYRepresentation(vs, taus, m-i, k)
+//     a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
+//     q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
 //   return (q, a)
-// TODO(phawkins): consider using UT transformations (in the form I - V U V')
-// rather than WY transformations.
 StatusOr<QRDecompositionResult> QRDecomposition(
     XlaOp a, bool full_matrices, int64 block_size,
     PrecisionConfig::Precision precision) {
@@ -384,24 +391,28 @@ StatusOr<QRDecompositionResult> QRDecomposition(
 
     a = UpdateSliceInMinorDims(a, qr_block.r, {i, i});
 
-    // Compute the I-WY block representation of a product of Householder
-    // matrices.
+    // Compute the I + Y @ T @ Y^t block representation of a product of
+    // Householder matrices.
     TF_ASSIGN_OR_RETURN(
-        auto w, ComputeWYRepresentation(type, batch_dims, qr_block.vs,
+        auto t, CompactWYRepresentation(type, batch_dims, qr_block.vs,
                                         qr_block.taus, m - i, k, precision));
 
     auto y = qr_block.vs;
-    // a[i:, i+k:] += np.dot(Y, np.dot(W.T, a[i:, i+k:]))
+    // a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
+    auto yt =
+        BatchDot(y, /*transpose_x=*/false, t, /*transpose_y=*/true, precision);
     auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
-    auto a_update = BatchDot(w, true, a_panel, false, precision);
-    a_update = BatchDot(y, a_update, precision);
+    auto a_update = BatchDot(y, /*transpose_x=*/true, a_panel,
+                             /*transpose_y=*/false, precision);
+    a_update = BatchDot(yt, a_update, precision);
     a_panel = a_panel + a_update;
     a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});
 
-    // q[:, i:] += np.dot(np.dot(q[:, i:], W), Y.T))
+    // q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
     auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
-    auto q_update = BatchDot(q_panel, w, precision);
-    q_update = BatchDot(q_update, false, y, true, precision);
+    auto q_update = BatchDot(q_panel, y, precision);
+    q_update = BatchDot(q_update, /*transpose_x=*/false, yt,
+                        /*transpose_y=*/true, precision);
     q_panel = q_panel + q_update;
     q = UpdateSliceInMinorDims(q, q_panel, {0, i});
   }
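
Note (not part of the patch): the pseudocode in the new CompactWYRepresentation comment can be sanity-checked with a small standalone NumPy sketch. The sketch below follows the I + Y @ T @ Y^t form used in the updated QRDecomposition comment and assumes reflectors of the form H_i = I - taus[i] * v_i @ v_i.T; the helper name compact_wy and the test setup are illustrative only and are not taken from the XLA code.

import numpy as np


def compact_wy(vs, taus):
    """Builds Y and T so that I + Y @ T @ Y.T == H_0 @ H_1 @ ... @ H_{n-1}."""
    m, n = vs.shape
    t = np.zeros((n, n))
    y = np.zeros((m, n))
    t[0, 0] = -taus[0]
    y[:, 0] = vs[:, 0]
    for i in range(1, n):
        # Only the already-filled columns of y contribute here, mirroring the
        # masking done with Select() in the XLA loop body.
        z = -taus[i] * np.dot(t, np.dot(y.T, vs[:, i]))
        y[:, i] = vs[:, i]
        t[:i, i] = z[:i]
        t[i, i] = -taus[i]
    return y, t


rng = np.random.default_rng(0)
m, n = 6, 3
# LAPACK-style Householder vectors: unit diagonal, zeros above it.
vs = np.tril(rng.standard_normal((m, n)), -1)
vs[np.arange(n), np.arange(n)] = 1.0
taus = 2.0 / np.sum(vs * vs, axis=0)  # makes each reflector orthogonal

y, t = compact_wy(vs, taus)
q_block = np.eye(m) + y @ t @ y.T

# Apply the reflectors one by one and compare against the blocked form.
q_explicit = np.eye(m)
for i in range(n):
    v = vs[:, i:i + 1]
    q_explicit = q_explicit @ (np.eye(m) - taus[i] * (v @ v.T))

print(np.allclose(q_block, q_explicit))  # expected: True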