From 1214a3035e078cab2627369d2a0fd81194cf8f09 Mon Sep 17 00:00:00 2001
From: Peter Hawkins
Date: Mon, 21 Sep 2020 17:25:52 -0700
Subject: [PATCH] [XLA] Change the QR decomposition implementation to lower to
 a CustomCall operator that is later expanded by a new QrExpander pass.

This allows backends to lower the QR decomposition differently if they
choose.

PiperOrigin-RevId: 332972250
Change-Id: I24f9c7ed02a2a207e052cb713947baa31d909425
---
 tensorflow/compiler/xla/client/lib/qr.cc      | 327 +------------
 tensorflow/compiler/xla/service/BUILD         |  23 +
 tensorflow/compiler/xla/service/cpu/BUILD     |   1 +
 .../compiler/xla/service/cpu/cpu_compiler.cc  |   2 +
 tensorflow/compiler/xla/service/gpu/BUILD     |   1 +
 .../compiler/xla/service/gpu/gpu_compiler.cc  |   3 +
 .../compiler/xla/service/interpreter/BUILD    |   1 +
 .../xla/service/interpreter/compiler.cc       |   2 +
 .../compiler/xla/service/qr_expander.cc       | 434 ++++++++++++++++++
 tensorflow/compiler/xla/service/qr_expander.h |  61 +++
 10 files changed, 536 insertions(+), 319 deletions(-)
 create mode 100644 tensorflow/compiler/xla/service/qr_expander.cc
 create mode 100644 tensorflow/compiler/xla/service/qr_expander.h
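At the HLO level the change is deliberately small: instead of emitting the
blocked Householder expansion inline, the client now emits a single CustomCall
that the new pass expands. A condensed sketch of the new client-side lowering,
extracted from the qr.cc hunk below (the input `a` has shape [..., m, n]):

    // q takes the shape of `a` with its last dimension changed to m.
    Shape q_shape = a_shape;
    q_shape.mutable_dimensions().back() = m;
    Shape qr_shape = ShapeUtil::MakeTupleShape({q_shape, a_shape});
    auto qr = CustomCall(a.builder(), "QrDecomposition", {a}, qr_shape);
    auto q = GetTupleElement(qr, 0);  // [..., m, m]
    auto r = GetTupleElement(qr, 1);  // [..., m, n]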
diff --git a/tensorflow/compiler/xla/client/lib/qr.cc b/tensorflow/compiler/xla/client/lib/qr.cc
index 09fa465a865..88a17f94a24 100644
--- a/tensorflow/compiler/xla/client/lib/qr.cc
+++ b/tensorflow/compiler/xla/client/lib/qr.cc
@@ -35,277 +35,7 @@ namespace xla {
 
 namespace {
 
-std::vector<int64> ConcatVectors(absl::Span<const int64> xs,
-                                 absl::Span<const int64> ys) {
-  std::vector<int64> output(xs.size() + ys.size());
-  std::copy(xs.begin(), xs.end(), output.begin());
-  std::copy(ys.begin(), ys.end(), output.begin() + xs.size());
-  return output;
-}
-
-// Computes a Householder reflection of the form:
-// H = I - tau v v.T.
-// such that
-// H . ( x1  ) = ( x1   )
-//     ( x2  ) = ( x2   )
-//     ( ... ) = ( ...  )
-//     ( xk  ) = ( beta )
-//     ( ... )   ( 0    )
-//     ( ... )   ( 0    )
-// Unlike the usual formulation, we allow the caller to supply 'k' rather than
-// only providing the relevant part of 'x' to maintain XLA's static shape
-// invariant. In addition, the implementation supports batching.
-// Pseudo-code, without batching:
-//   alpha = x[k]
-//   x_copy = np.copy(x)
-//   x_copy[:k+1] = 0
-//   xnorm = norm2(x_copy)
-//   if xnorm == 0:
-//     beta = alpha
-//     tau = 0
-//     v = np.zeros_like(x)
-//   else:
-//     beta = - np.sign(alpha) * dlapy2(alpha, xnorm)
-//     tau = (beta - alpha) / beta
-//     v = x / (alpha - beta)
-//     v[k] = 1
-//   return (v, tau, beta)
-// TODO(phawkins): LAPACK's xLARFG implementation has code for handling
-// overflows in the norm/beta calculations. Perhaps do the same here.
-Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
-             const int64 m, XlaOp* v, XlaOp* tau, XlaOp* beta) {
-  XlaBuilder* const builder = x.builder();
-  TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
-  const PrimitiveType type = x_shape.element_type();
-
-  std::vector<int64> batch_dim_ids(batch_dims.size());
-  std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
-  const int64 minor_dim = batch_dims.size();
-
-  XlaOp zero = ScalarLike(x, 0.0);
-  XlaOp one = ScalarLike(x, 1.0);
-
-  // alpha = x[k]
-  XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);
-
-  // Compute x[k+1:] (padded with zeros in elements 0..k)
-  XlaOp iota = Iota(builder, S32, m);
-  XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type),
-                        /*broadcast_dimensions=*/{minor_dim});
-
-  // sigma = np.dot(x[k+1:], x[k+1:])
-  auto sigma = Reduce(x_after_k * x_after_k, zero,
-                      CreateScalarAddComputation(type, builder), {minor_dim});
-  // mu = np.sqrt(x[k]*x[k] + sigma)
-  auto mu = Sqrt(Square(alpha) + sigma);
-
-  auto sigma_is_zero = Eq(sigma, zero);
-
-  *beta = Select(sigma_is_zero, alpha, Select(Lt(alpha, zero), one, -one) * mu);
-  *tau = Select(sigma_is_zero, Broadcast(zero, batch_dims),
-                (*beta - alpha) / *beta);
-  auto divisor =
-      Select(sigma_is_zero, Broadcast(one, batch_dims), alpha - *beta);
-
-  auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type),
-                       std::vector<int64>(batch_dims.size(), 1));
-
-  // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
-  // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
-  *v = e_k + Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids);
-  return Status::OK();
-}
-
-// Householder QR decomposition. Algorithm 5.2.1 from Golub and Van
-// Loan "Matrix Computations", 4th Edition. This is an unblocked implementation
-// used as an inner routine of the blocked implementation.
-// Algorithm is adapted slightly so the shapes inside the loop are static, at
-// the cost of some redundant computation. Since this is used as an inner block
-// kernel, accumulates the Householder transformations (vs, taus) rather than
-// the matrix q.
-// Equivalent Python code, without batching:
-// def qr(a):
-//   m = a.shape[0]
-//   n = a.shape[1]
-//   taus = np.zeros([n])
-//   for j in xrange(min(m, n)):
-//     v, tau, beta = house(a[:, j], j)
-//     a[:, j+1:] -= tau * np.dot(v[:, np.newaxis],
-//                                np.dot(v[np.newaxis, :], a[:, j+1:]))
-//     # Form column j explicitly rather than relying on the precision of the
-//     # Householder update.
-//     a[j, j] = beta
-//     a[j+1:, j] = v[j+1:]
-//     taus[j] = tau
-//   return (a, taus)
-struct QRBlockResult {
-  // The upper-triangular matrix R, packed together with the lower-triangular
-  // elementary Householder reflectors `vs` below the diagonal.
-  XlaOp a;
-
-  // Representation of the Householder matrices I - beta v v.T
-  XlaOp taus;  // Shape: [..., n]
-};
-
-StatusOr<QRBlockResult> QRBlock(XlaOp a, PrecisionConfig::Precision precision) {
-  XlaBuilder* builder = a.builder();
-  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
-  const int num_dims = a_shape.rank();
-  if (num_dims < 2) {
-    return InvalidArgument("Argument to QR must have rank >= 2; got shape %s",
-                           a_shape.ToString());
-  }
-  PrimitiveType type = a_shape.element_type();
-
-  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
-  const int64 n = ShapeUtil::GetDimension(a_shape, -1);
-
-  const int64 num_batch_dims = num_dims - 2;
-  std::vector<int64> batch_dims(num_batch_dims);
-  for (int i = 0; i < num_batch_dims; ++i) {
-    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
-  }
-
-  std::vector<int64> batch_dim_indices(num_batch_dims);
-  std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
-
-  auto qr_body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
-                        XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
-    auto a = values[0];
-    auto taus = values[1];
-
-    // v, tau, beta = house(a[:, j], j)
-    auto x = DynamicSliceInMinorDims(a, {j}, {1});
-    XlaOp v, tau, beta;
-    TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j,
-                             batch_dims, m, &v, &tau, &beta));
-
-    const int64 minor_dim = batch_dims.size();
-    auto iota_mn = Iota(
-        builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {m, n})),
-        minor_dim + 1);
-
-    std::vector<int64> shape = batch_dims;
-    shape.push_back(1);
-    shape.push_back(m);
-    auto v_broadcast = Reshape(v, shape);
-    // a[:, j+1:] -= tau * (v[:, np.newaxis] @ (v[np.newaxis, :] @ a[:, j+1:]))
-    // We use masking rather than a loop-variant shape to handle the j+1:
-    // indexing.
-    auto vva = BatchDot(v_broadcast, Select(Lt(j, iota_mn), a, ZerosLike(a)),
-                        precision);
-    vva = BatchDot(v_broadcast, true, vva, false, precision);
-    a = a - Mul(tau, vva,
-                /*broadcast_dimensions=*/batch_dim_indices);
-
-    // a[j, j] = beta
-    // a[j+1:,j] = v[j+1:]
-    auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1});
-    auto predecessor_mask = ConvertElementType(Lt(iota, j), type);
-    auto mask = Broadcast(ConvertElementType(Eq(iota, j), type),
-                          std::vector<int64>(batch_dims.size(), 1));
-    auto successor_mask = Gt(Iota(a.builder(), S32, m), j);
-    auto new_x = Mul(x, predecessor_mask,
-                     /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
-                 Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices);
-    new_x = Add(
-        new_x, Select(Broadcast(successor_mask, batch_dims), v, ZerosLike(v)),
-        /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {minor_dim}));
-    // Update a[:,j]
-    std::vector<int64> dim_ids(num_dims);
-    std::iota(dim_ids.begin(), dim_ids.end(), 0);
-    new_x = BroadcastInDim(new_x, ConcatVectors(batch_dims, {m, n}),
-                           /*broadcast_dimensions=*/dim_ids);
-    a = Select(Eq(iota_mn, j), new_x, a);
-
-    // taus[j] = tau
-    std::vector<int64> tau_broadcast_dims(batch_dims.size());
-    std::iota(tau_broadcast_dims.begin(), tau_broadcast_dims.end(), 0);
-
-    auto iota_n =
-        Iota(builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {n})),
-             minor_dim);
-    auto taus_zeros = ZerosLike(taus);
-    auto taus_update = Select(
-        Eq(iota_n, j),
-        Add(taus_zeros, tau, /*broadcast_dimensions=*/tau_broadcast_dims),
-        taus_zeros);
-    taus = taus + taus_update;
-    return std::vector<XlaOp>{a, taus};
-  };
-
-  auto taus = Zeros(builder,
-                    ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {n})));
-
-  TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn,
-                                                {a, taus}, "qr", builder));
-
-  QRBlockResult result;
-  result.a = values[0];
-  result.taus = values[1];
-  return result;
-}
-
-// Computes T such that (I - Y @ T @ Y^t) is a product of the elementary
-// Householder reflectors given by `vs` and `taus`.
-//
-// Schreiber, Robert, and Charles Van Loan. "A storage-efficient WY
-// representation for products of Householder transformations." SIAM Journal on
-// Scientific and Statistical Computing 10.1 (1989): 53-57.
-//
-// def compact_wy(vs, taus):
-//   m, n = vs.shape[-2:]
-//   t = np.eye(n) * -taus
-//   # We premultiply Y.T @ vs, since we would prefer to compute a single matrix
-//   # multiplication to many matrix-vector products.
-//   vtv = -taus[None, :] * np.triu(vs.T @ vs, 1) + np.eye(n)
-//   for i in range(1, n):
-//     t[:, i] = np.dot(t, vtv[:, i])
-//   return t
-StatusOr<XlaOp> CompactWYRepresentation(PrimitiveType type,
-                                        absl::Span<const int64> batch_dims,
-                                        XlaOp vs, XlaOp taus, int64 m, int64 n,
-                                        PrecisionConfig::Precision precision) {
-  std::vector<int64> batch_dim_indices(batch_dims.size());
-  std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
-  int64 n_index = batch_dims.size() + 1;
-
-  auto body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
-                     XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
-    // w has shape [..., m, n]
-    auto t = values[0];
-    const auto vtv = values[1];
-
-    // Want j values in range [1, ... n).
-    j = j + ConstantR0<int32>(builder, 1);
-
-    // yv has shape [..., n, 1]
-    auto yv = DynamicSliceInMinorDims(vtv, {j}, {1});
-
-    // wyv has shape [..., n, 1]
-    auto z = BatchDot(t, yv, precision);
-
-    t = DynamicUpdateSliceInMinorDims(t, z, {j});
-
-    return std::vector<XlaOp>{t, vtv};
-  };
-
-  XlaBuilder* builder = vs.builder();
-
-  auto tau_scale = BroadcastInDim(-taus, ConcatVectors(batch_dims, {1, n}),
-                                  ConcatVectors(batch_dim_indices, {n_index}));
-
-  auto eye = Broadcast(IdentityMatrix(builder, type, n, n), batch_dims);
-  auto t = eye * tau_scale;
-
-  auto vtv =
-      BatchDot(vs, /*transpose_x=*/true, vs, /*transpose_y=*/false, precision);
-  vtv = Select(TriangleMask(vtv, 0), ZerosLike(vtv), vtv) * tau_scale;
-  vtv = vtv + eye;
-
-  TF_ASSIGN_OR_RETURN(
-      auto values, ForEachIndex(n - 1, S32, body_fn, {t, vtv}, "wy", builder));
-  return values[0];
-}
 
 }  // namespace
 
@@ -332,8 +62,6 @@ StatusOr<QRDecompositionResult> QRDecomposition(
     return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s",
                            a_shape.ToString());
   }
-  PrimitiveType type = a_shape.element_type();
-
   const int64 m = ShapeUtil::GetDimension(a_shape, -2);
   const int64 n = ShapeUtil::GetDimension(a_shape, -1);
   const int64 p = std::min(m, n);
@@ -343,60 +71,21 @@ StatusOr<QRDecompositionResult> QRDecomposition(
                            block_size);
   }
 
-  const int64 num_batch_dims = num_dims - 2;
-  std::vector<int64> batch_dims(num_batch_dims);
-  for (int i = 0; i < num_batch_dims; ++i) {
-    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
-  }
+  Shape q_shape = a_shape;
+  q_shape.mutable_dimensions().back() = m;
 
-  auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
-  for (int64 i = 0; i < p; i += block_size) {
-    int64 k = std::min(block_size, p - i);
-
-    auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k});
-    TF_ASSIGN_OR_RETURN(auto qr_block, QRBlock(a_block, precision));
-    auto y = Add(
-        IdentityMatrix(builder, type, m - i, k),
-        Select(TriangleMask(qr_block.a, -1), qr_block.a, ZerosLike(qr_block.a)),
-        /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1});
-
-    a = UpdateSliceInMinorDims(a, qr_block.a, {i, i});
-
-    // Compute the I + Y @ T @ Y^t block representation of a product of
-    // Householder matrices.
-    TF_ASSIGN_OR_RETURN(
-        auto t, CompactWYRepresentation(type, batch_dims, y, qr_block.taus,
-                                        m - i, k, precision));
-
-    // a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
-    auto yt =
-        BatchDot(y, /*transpose_x=*/false, t, /*transpose_y=*/true, precision);
-    auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
-    auto a_update = BatchDot(y, /*transpose_x=*/true, a_panel,
-                             /*transpose_y=*/false, precision);
-    a_update = BatchDot(yt, a_update, precision);
-    a_panel = a_panel + a_update;
-    a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});
-
-    // q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
-    auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
-    auto q_update = BatchDot(q_panel, y, precision);
-    q_update = BatchDot(q_update, /*transpose_x=*/false, yt,
-                        /*transpose_y=*/true, precision);
-    q_panel = q_panel + q_update;
-    q = UpdateSliceInMinorDims(q, q_panel, {0, i});
-  }
-  QRDecompositionResult result;
+  Shape qr_shape = ShapeUtil::MakeTupleShape({q_shape, a_shape});
+  auto qr = CustomCall(a.builder(), "QrDecomposition", {a}, qr_shape);
+  auto q = GetTupleElement(qr, 0);
+  auto r = GetTupleElement(qr, 1);
 
   // full_matrices is false when only a partial result is needed. Slice to the
   // needed dimensions here.
   if (!full_matrices) {
     q = SliceInMinorDims(q, {0, 0}, {m, p});
-    a = SliceInMinorDims(a, {0, 0}, {p, n});
+    r = SliceInMinorDims(r, {0, 0}, {p, n});
   }
-  result.q = q;
-  result.r = UpperTriangle(a);
-  return result;
+  return QRDecompositionResult{q, r};
 }
 
 }  // namespace xla
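The public QRDecomposition entry point keeps its existing signature; only its
lowering changed. A minimal usage sketch (assuming the declaration in qr.h
with its current defaults for block_size and precision; not part of this
patch):

    #include "tensorflow/compiler/xla/client/lib/qr.h"
    #include "tensorflow/compiler/xla/client/xla_builder.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    xla::XlaBuilder b("qr");
    // a: f32[6, 4], so m = 6, n = 4, p = min(m, n) = 4.
    xla::XlaOp a = xla::Parameter(
        &b, 0, xla::ShapeUtil::MakeShape(xla::F32, {6, 4}), "a");
    auto qr = xla::QRDecomposition(a, /*full_matrices=*/false).ValueOrDie();
    // qr.q : f32[6, 4]  (sliced from the full [6, 6] Q)
    // qr.r : f32[4, 4]  (sliced from the full [6, 4] R)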
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 13233262c12..8031477e7aa 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -1952,6 +1952,29 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "qr_expander",
+    srcs = ["qr_expander.cc"],
+    hdrs = ["qr_expander.h"],
+    deps = [
+        ":op_expander_pass",
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:util",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/client/lib:arithmetic",
+        "//tensorflow/compiler/xla/client/lib:constants",
+        "//tensorflow/compiler/xla/client/lib:loops",
+        "//tensorflow/compiler/xla/client/lib:math",
+        "//tensorflow/compiler/xla/client/lib:matrix",
+        "//tensorflow/compiler/xla/client/lib:slicing",
+        "//tensorflow/core:lib",
+        "@com_google_absl//absl/container:flat_hash_map",
+    ],
+)
+
 cc_library(
     name = "convolution_4d_expander",
     srcs = ["convolution_4d_expander.cc"],
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index d45f1c1aeb9..31c03483a9e 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -170,6 +170,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:buffer_assignment",
         "//tensorflow/compiler/xla/service:call_inliner",
         "//tensorflow/compiler/xla/service:cholesky_expander",
+        "//tensorflow/compiler/xla/service:qr_expander",
         "//tensorflow/compiler/xla/service:conditional_simplifier",
         "//tensorflow/compiler/xla/service:convolution_group_converter",
         "//tensorflow/compiler/xla/service:dot_decomposer",
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index e6c72e60636..0260b3926c7 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -103,6 +103,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
 #include "tensorflow/compiler/xla/service/logistic_expander.h"
 #include "tensorflow/compiler/xla/service/map_inliner.h"
+#include "tensorflow/compiler/xla/service/qr_expander.h"
 #include "tensorflow/compiler/xla/service/reshape_mover.h"
 #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
 #include "tensorflow/compiler/xla/service/rng_expander.h"
@@ -281,6 +282,7 @@ Status CpuCompiler::RunHloPassesThroughLayoutAssn(
   pipeline.AddPass();
   pipeline.AddPass<CholeskyExpander>();
+  pipeline.AddPass<QrExpander>();
   pipeline.AddPass<TriangularSolveExpander>();
 
   // Inline computations with a single call site.
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index aa316404715..b74945246ce 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -1230,6 +1230,7 @@ cc_library(
         "//tensorflow/compiler/xla/service:hlo_verifier",
         "//tensorflow/compiler/xla/service:llvm_compiler",
         "//tensorflow/compiler/xla/service:logistic_expander",
+        "//tensorflow/compiler/xla/service:qr_expander",
         "//tensorflow/compiler/xla/service:rng_bit_generator_expander",
         "//tensorflow/compiler/xla/service:rng_expander",
         "//tensorflow/compiler/xla/service:slice_sinker",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index b8b127c8075..7b0c0974f9c 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -92,6 +92,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_verifier.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
 #include "tensorflow/compiler/xla/service/logistic_expander.h"
+#include "tensorflow/compiler/xla/service/qr_expander.h"
 #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
 #include "tensorflow/compiler/xla/service/rng_expander.h"
 #include "tensorflow/compiler/xla/service/slice_sinker.h"
@@ -151,6 +152,8 @@ Status GpuCompiler::OptimizeHloModule(
   pipeline.AddPass();
   pipeline.AddPass<CholeskyExpander>();
+  // TODO(phawkins): replace QR decompositions with calls to cuSOLVER.
+  pipeline.AddPass<QrExpander>();
   pipeline.AddPass<TriangularSolveExpander>();
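As the TODO above notes, expanding QR on GPU is a stopgap until a cuSOLVER
lowering exists. Because QrExpander only rewrites CustomCalls that still
match, a backend can claim the op by running its own pass first. A
hypothetical sketch (GpuQrRewriter is not part of this patch; the match
condition is the one QrExpander itself uses, defined in qr_expander.cc
below):

    // Hypothetical backend-specific pass, e.g. one that calls into cuSOLVER.
    bool GpuQrRewriter::InstructionMatchesPattern(HloInstruction* instr) {
      return instr->opcode() == HloOpcode::kCustomCall &&
             instr->custom_call_target() == "QrDecomposition";
    }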
#include "tensorflow/compiler/xla/service/interpreter/executable.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/compiler/xla/service/map_inliner.h" +#include "tensorflow/compiler/xla/service/qr_expander.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/triangular_solve_expander.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" @@ -82,6 +83,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { pipeline.AddPass(); pipeline.AddPass(); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass( diff --git a/tensorflow/compiler/xla/service/qr_expander.cc b/tensorflow/compiler/xla/service/qr_expander.cc new file mode 100644 index 00000000000..8bc06f95722 --- /dev/null +++ b/tensorflow/compiler/xla/service/qr_expander.cc @@ -0,0 +1,434 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/qr_expander.h" + +#include +#include + +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" +#include "tensorflow/compiler/xla/client/lib/loops.h" +#include "tensorflow/compiler/xla/client/lib/math.h" +#include "tensorflow/compiler/xla/client/lib/matrix.h" +#include "tensorflow/compiler/xla/client/lib/slicing.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +namespace { + +std::vector ConcatVectors(absl::Span xs, + absl::Span ys) { + std::vector output; + output.reserve(xs.size() + ys.size()); + std::copy(xs.begin(), xs.end(), std::back_inserter(output)); + std::copy(ys.begin(), ys.end(), std::back_inserter(output)); + return output; +} + +// Computes a Householder reflection of the form: +// H = I - tau v v.T. +// such that +// H . ( x1 ) = ( x1 ) +// ( x2 ) = ( x2 ) +// ( ... ) = ( ... ) +// ( xk ) = ( beta ) +// ( ... ) ( 0 ) +// ( ... ) ( 0 ) +// Unlike the usual formulation, we allow the caller to supply 'k' rather than +// only providing the relevant part of 'x' to maintain XLA's static shape +// invariant. In addition, the implementation supports batching. 
+// Pseudo-code, without batching:
+//   alpha = x[k]
+//   x_copy = np.copy(x)
+//   x_copy[:k+1] = 0
+//   xnorm = norm2(x_copy)
+//   if xnorm == 0:
+//     beta = alpha
+//     tau = 0
+//     v = np.zeros_like(x)
+//   else:
+//     beta = - np.sign(alpha) * dlapy2(alpha, xnorm)
+//     tau = (beta - alpha) / beta
+//     v = x / (alpha - beta)
+//     v[k] = 1
+//   return (v, tau, beta)
+// TODO(phawkins): LAPACK's xLARFG implementation has code for handling
+// overflows in the norm/beta calculations. Perhaps do the same here.
+Status House(XlaOp x, XlaOp k, absl::Span<const int64> batch_dims,
+             const int64 m, XlaOp* v, XlaOp* tau, XlaOp* beta) {
+  XlaBuilder* const builder = x.builder();
+  TF_ASSIGN_OR_RETURN(Shape x_shape, builder->GetShape(x));
+  const PrimitiveType type = x_shape.element_type();
+
+  std::vector<int64> batch_dim_ids(batch_dims.size());
+  std::iota(batch_dim_ids.begin(), batch_dim_ids.end(), 0);
+  const int64 minor_dim = batch_dims.size();
+
+  XlaOp zero = ScalarLike(x, 0.0);
+  XlaOp one = ScalarLike(x, 1.0);
+
+  // alpha = x[k]
+  XlaOp alpha = Reshape(DynamicSliceInMinorDims(x, {k}, {1}), batch_dims);
+
+  // Compute x[k+1:] (padded with zeros in elements 0..k)
+  XlaOp iota = Iota(builder, S32, m);
+  XlaOp x_after_k = Mul(x, ConvertElementType(Gt(iota, k), type),
+                        /*broadcast_dimensions=*/{minor_dim});
+
+  // sigma = np.dot(x[k+1:], x[k+1:])
+  // TODO(phawkins): this calculation may be numerically unstable.
+  auto sigma = Reduce(x_after_k * x_after_k, zero,
+                      CreateScalarAddComputation(type, builder), {minor_dim});
+  // mu = np.sqrt(x[k]*x[k] + sigma)
+  auto mu = Sqrt(Square(alpha) + sigma);
+
+  auto sigma_is_zero = Eq(sigma, zero);
+
+  *beta = Select(sigma_is_zero, alpha, Select(Lt(alpha, zero), one, -one) * mu);
+  *tau = Select(sigma_is_zero, Broadcast(zero, batch_dims),
+                (*beta - alpha) / *beta);
+  auto divisor =
+      Select(sigma_is_zero, Broadcast(one, batch_dims), alpha - *beta);
+
+  auto e_k = Broadcast(ConvertElementType(Eq(iota, k), type),
+                       std::vector<int64>(batch_dims.size(), 1));
+
+  // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor
+  // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor.
+  *v = e_k + Div(x_after_k, divisor, /*broadcast_dimensions=*/batch_dim_ids);
+  return Status::OK();
+}
+
+}  // namespace
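As a cross-check of the masked, batched XLA implementation against the
pseudo-code above, here is a straight scalar transcription (plain C++, no
batching and no XLA ops; illustration only, not part of the patch):

    #include <cmath>
    #include <vector>

    // Computes (v, tau, beta) for the reflector that zeroes x[k+1:].
    void HouseScalar(const std::vector<double>& x, int k,
                     std::vector<double>* v, double* tau, double* beta) {
      const double alpha = x[k];
      double sigma = 0.0;  // == np.dot(x[k+1:], x[k+1:])
      for (size_t i = k + 1; i < x.size(); ++i) sigma += x[i] * x[i];
      v->assign(x.size(), 0.0);
      if (sigma == 0.0) {
        *beta = alpha;
        *tau = 0.0;
      } else {
        const double mu = std::sqrt(alpha * alpha + sigma);
        *beta = alpha < 0.0 ? mu : -mu;  // beta = -sign(alpha) * mu
        *tau = (*beta - alpha) / *beta;
        for (size_t i = k + 1; i < x.size(); ++i)
          (*v)[i] = x[i] / (alpha - *beta);
      }
      (*v)[k] = 1.0;
    }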
+
+// Householder QR decomposition. Algorithm 5.2.1 from Golub and Van
+// Loan "Matrix Computations", 4th Edition. This is an unblocked implementation
+// used as an inner routine of the blocked implementation.
+// The algorithm is adapted slightly so the shapes inside the loop are static,
+// at the cost of some redundant computation. Since this is used as an inner
+// block kernel, it accumulates the Householder transformations (vs, taus)
+// rather than the matrix q.
+// Equivalent Python code, without batching:
+// def qr(a):
+//   m = a.shape[0]
+//   n = a.shape[1]
+//   taus = np.zeros([min(m, n)])
+//   for j in xrange(min(m, n)):
+//     v, tau, beta = house(a[:, j], j)
+//     a[:, j+1:] -= tau * np.dot(v[:, np.newaxis],
+//                                np.dot(v[np.newaxis, :], a[:, j+1:]))
+//     # Form column j explicitly rather than relying on the precision of the
+//     # Householder update.
+//     a[j, j] = beta
+//     a[j+1:, j] = v[j+1:]
+//     taus[j] = tau
+//   return (a, taus)
+StatusOr<QrExpander::QrResult> QrExpander::QrBlock(
+    XlaOp a, PrecisionConfig::Precision precision) {
+  XlaBuilder* builder = a.builder();
+  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
+  const int num_dims = a_shape.rank();
+  if (num_dims < 2) {
+    return InvalidArgument("Argument to QR must have rank >= 2; got shape %s",
+                           a_shape.ToString());
+  }
+  PrimitiveType type = a_shape.element_type();
+
+  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
+  const int64 n = ShapeUtil::GetDimension(a_shape, -1);
+
+  const int64 num_batch_dims = num_dims - 2;
+  std::vector<int64> batch_dims(num_batch_dims);
+  for (int i = 0; i < num_batch_dims; ++i) {
+    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
+  }
+
+  std::vector<int64> batch_dim_indices(num_batch_dims);
+  std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
+
+  auto qr_body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
+                        XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
+    auto a = values[0];
+    auto taus = values[1];
+
+    // v, tau, beta = house(a[:, j], j)
+    auto x = DynamicSliceInMinorDims(a, {j}, {1});
+    XlaOp v, tau, beta;
+    TF_RETURN_IF_ERROR(House(Collapse(x, {num_dims - 2, num_dims - 1}), j,
+                             batch_dims, m, &v, &tau, &beta));
+
+    const int64 minor_dim = batch_dims.size();
+    auto iota_mn = Iota(
+        builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {m, n})),
+        minor_dim + 1);
+
+    std::vector<int64> shape = batch_dims;
+    shape.push_back(1);
+    shape.push_back(m);
+    auto v_broadcast = Reshape(v, shape);
+    // a[:, j+1:] -= tau * (v[:, np.newaxis] @ (v[np.newaxis, :] @ a[:, j+1:]))
+    // We use masking rather than a loop-variant shape to handle the j+1:
+    // indexing.
+    auto vva = BatchDot(v_broadcast, Select(Lt(j, iota_mn), a, ZerosLike(a)),
+                        precision);
+    vva = BatchDot(v_broadcast, true, vva, false, precision);
+    a = a - Mul(tau, vva,
+                /*broadcast_dimensions=*/batch_dim_indices);
+
+    // a[j, j] = beta
+    // a[j+1:,j] = v[j+1:]
+    auto iota = Reshape(Iota(a.builder(), S32, m), {m, 1});
+    auto predecessor_mask = ConvertElementType(Lt(iota, j), type);
+    auto mask = Broadcast(ConvertElementType(Eq(iota, j), type),
+                          std::vector<int64>(batch_dims.size(), 1));
+    auto successor_mask = Gt(Iota(a.builder(), S32, m), j);
+    auto new_x = Mul(x, predecessor_mask,
+                     /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1}) +
+                 Mul(beta, mask, /*broadcast_dimensions=*/batch_dim_indices);
+    new_x = Add(
+        new_x, Select(Broadcast(successor_mask, batch_dims), v, ZerosLike(v)),
+        /*broadcast_dimensions=*/ConcatVectors(batch_dim_indices, {minor_dim}));
+    // Update a[:,j]
+    std::vector<int64> dim_ids(num_dims);
+    std::iota(dim_ids.begin(), dim_ids.end(), 0);
+    new_x = BroadcastInDim(new_x, ConcatVectors(batch_dims, {m, n}),
+                           /*broadcast_dimensions=*/dim_ids);
+    a = Select(Eq(iota_mn, j), new_x, a);
+
+    // taus[j] = tau
+    std::vector<int64> tau_broadcast_dims(batch_dims.size());
+    std::iota(tau_broadcast_dims.begin(), tau_broadcast_dims.end(), 0);
+
+    auto iota_n =
+        Iota(builder, ShapeUtil::MakeShape(S32, ConcatVectors(batch_dims, {n})),
+             minor_dim);
+    auto taus_zeros = ZerosLike(taus);
+    auto taus_update = Select(
+        Eq(iota_n, j),
+        Add(taus_zeros, tau, /*broadcast_dimensions=*/tau_broadcast_dims),
+        taus_zeros);
+    taus = taus + taus_update;
+    return std::vector<XlaOp>{a, taus};
+  };
+
+  auto taus = Zeros(
+      builder,
+      ShapeUtil::MakeShape(type, ConcatVectors(batch_dims, {std::min(m, n)})));
+
+  TF_ASSIGN_OR_RETURN(auto values, ForEachIndex(std::min(m, n), S32, qr_body_fn,
+                                                {a, taus}, "qr", builder));
+
+  QrResult result;
+  result.a = values[0];
+  result.taus = values[1];
+  return result;
+}
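Callers consume the packed QrResult by taking R from the upper triangle and
rebuilding the unit-lower block Y of reflectors. A sketch of that recipe,
mirroring what BuildQrDecomposition does per panel further down:

    // Given QrResult `result` for an [..., m, n] block:
    XlaOp r = UpperTriangle(result.a);  // R, including the beta diagonal.
    // Y: implicit ones on the diagonal, reflectors strictly below it.
    XlaOp y = Add(
        IdentityMatrix(builder, type, m, n),
        Select(TriangleMask(result.a, -1), result.a, ZerosLike(result.a)),
        /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1});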
+
+// Computes an upper triangular matrix T such that (I - Y @ T @ Y^t) is a
+// product of the elementary Householder reflectors given by `vs` and `taus`.
+//
+// Schreiber, Robert, and Charles Van Loan. "A storage-efficient WY
+// representation for products of Householder transformations." SIAM Journal on
+// Scientific and Statistical Computing 10.1 (1989): 53-57.
+//
+// def compact_wy(vs, taus):
+//   m, n = vs.shape[-2:]
+//   t = np.eye(n)
+//   # We premultiply Y.T @ vs, since we would prefer to compute a single matrix
+//   # multiplication to many matrix-vector products.
+//   vtv = -taus[None, :] * (np.triu(vs.T @ vs, 1) + np.eye(n))
+//   for i in range(n):
+//     t[:, i] = np.dot(t, vtv[:, i])
+//   return t
+StatusOr<XlaOp> QrExpander::CompactWYRepresentation(
+    PrimitiveType type, absl::Span<const int64> batch_dims, XlaOp vs,
+    XlaOp taus, int64 m, int64 n, PrecisionConfig::Precision precision) {
+  XlaBuilder* builder = vs.builder();
+
+  std::vector<int64> batch_dim_indices(batch_dims.size());
+  std::iota(batch_dim_indices.begin(), batch_dim_indices.end(), 0);
+  int64 n_index = batch_dims.size() + 1;
+
+  auto body_fn = [&](XlaOp j, absl::Span<const XlaOp> values,
+                     XlaBuilder* builder) -> StatusOr<std::vector<XlaOp>> {
+    // t has shape [..., n, n]
+    auto t = values[0];
+    const auto vtv = values[1];
+
+    // yv has shape [..., n, 1]
+    auto yv = DynamicSliceInMinorDims(vtv, {j}, {1});
+
+    // z has shape [..., n, 1]
+    auto z = BatchDot(t, yv, precision);
+
+    t = DynamicUpdateSliceInMinorDims(t, z, {j});
+
+    return std::vector<XlaOp>{t, vtv};
+  };
+
+  auto tau_scale = BroadcastInDim(-taus, ConcatVectors(batch_dims, {1, n}),
+                                  ConcatVectors(batch_dim_indices, {n_index}));
+
+  auto eye = Broadcast(IdentityMatrix(builder, type, n, n), batch_dims);
+  auto t = eye;
+
+  auto vtv =
+      BatchDot(vs, /*transpose_x=*/true, vs, /*transpose_y=*/false, precision);
+  vtv = Select(TriangleMask(vtv, 0), ZerosLike(vtv), vtv);
+  vtv = (vtv + eye) * tau_scale;
+
+  TF_ASSIGN_OR_RETURN(auto values,
+                      ForEachIndex(n, S32, body_fn, {t, vtv}, "wy", builder));
+  return values[0];
+}
+
+// Block Householder QR Factorization. Algorithm 5.2.2 of Golub and van Loan.
+// def qr_blocked(a, block_size):
+//   m = a.shape[0]
+//   n = a.shape[1]
+//   q = np.eye(m)
+//   for i in xrange(0, min(m, n), block_size):
+//     k = min(block_size, min(m, n) - i)
+//     (a, taus) = qr(a[i:, i:i+k])
+//     y = np.eye(m - i, k) + np.tril(a, -1)
+//     t = CompactWYRepresentation(y, taus, m-i, k)
+//     a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
+//     q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
+//   return (q, a)
+StatusOr<XlaOp> QrExpander::BuildQrDecomposition(
+    XlaOp a, int64 block_size, PrecisionConfig::Precision precision) {
+  XlaBuilder* builder = a.builder();
+  TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
+  const int num_dims = a_shape.rank();
+  if (num_dims < 2) {
+    return InvalidArgument("Arguments to QR must have rank >= 2: got shape %s",
+                           a_shape.ToString());
+  }
+  PrimitiveType type = a_shape.element_type();
+
+  const int64 m = ShapeUtil::GetDimension(a_shape, -2);
+  const int64 n = ShapeUtil::GetDimension(a_shape, -1);
+  const int64 p = std::min(m, n);
+
+  if (block_size < 1) {
+    return InvalidArgument("block_size argument to QR must be >= 1; got %d",
+                           block_size);
+  }
+
+  const int64 num_batch_dims = num_dims - 2;
+  std::vector<int64> batch_dims(num_batch_dims);
+  for (int i = 0; i < num_batch_dims; ++i) {
+    batch_dims[i] = ShapeUtil::GetDimension(a_shape, i);
+  }
+
+  auto q = Broadcast(IdentityMatrix(builder, type, m, m), batch_dims);
+  for (int64 i = 0; i < p; i += block_size) {
+    int64 k = std::min(block_size, p - i);
+
+    auto a_block = SliceInMinorDims(a, {i, i}, {m, i + k});
+    TF_ASSIGN_OR_RETURN(auto qr_block, QrBlock(a_block, precision));
+    auto y = Add(
+        IdentityMatrix(builder, type, m - i, k),
+        Select(TriangleMask(qr_block.a, -1), qr_block.a, ZerosLike(qr_block.a)),
+        /*broadcast_dimensions=*/{num_dims - 2, num_dims - 1});
+
+    a = UpdateSliceInMinorDims(a, qr_block.a, {i, i});
+
+    // Compute the I + Y @ T @ Y^t block representation of a product of
+    // Householder matrices.
+    TF_ASSIGN_OR_RETURN(
+        auto t, CompactWYRepresentation(type, batch_dims, y, qr_block.taus,
+                                        m - i, k, precision));
+
+    // a[i:, i+k:] += (y @ t.T) @ (y.T @ a[i:, i+k:])
+    auto yt =
+        BatchDot(y, /*transpose_x=*/false, t, /*transpose_y=*/true, precision);
+    auto a_panel = SliceInMinorDims(a, {i, i + k}, {m, n});
+    auto a_update = BatchDot(y, /*transpose_x=*/true, a_panel,
+                             /*transpose_y=*/false, precision);
+    a_update = BatchDot(yt, a_update, precision);
+    a_panel = a_panel + a_update;
+    a = UpdateSliceInMinorDims(a, a_panel, {i, i + k});
+
+    // q[:, i:] += (q[:, i:] @ y) @ (y @ t.T).T
+    auto q_panel = SliceInMinorDims(q, {0, i}, {m, m});
+    auto q_update = BatchDot(q_panel, y, precision);
+    q_update = BatchDot(q_update, /*transpose_x=*/false, yt,
+                        /*transpose_y=*/true, precision);
+    q_panel = q_panel + q_update;
+    q = UpdateSliceInMinorDims(q, q_panel, {0, i});
+  }
+
+  return Tuple(builder, {q, UpperTriangle(a)});
+}
+
+bool QrExpander::InstructionMatchesPattern(HloInstruction* instruction) {
+  return instruction->opcode() == HloOpcode::kCustomCall &&
+         instruction->custom_call_target() == "QrDecomposition";
+}
+
+StatusOr<HloInstruction*> QrExpander::ExpandInstruction(
+    HloInstruction* instruction) {
+  const string name =
+      absl::StrFormat("xla.qr_%s", instruction->operand(0)->shape().ToString());
+
+  HloModule* module = instruction->parent()->parent();
+
+  HloComputation*& computation =
+      computation_cache_.emplace(name, nullptr).first->second;
+  if (!computation) {
+    // Builds a new expansion.
+    //
+    // TODO(b/62327888): We do something unusual here: we build the computation
+    // using the XlaBuilder API, which is nominally an XLA client API. We do
+    // this because the external APIs for building complicated computations
+    // (XlaBuilder) are much more ergonomic than the internal ones. As it turns
+    // out, XlaBuilder isn't really a client API—what it does is build a
+    // HloModuleProto protocol buffer that we can then deserialize and clone
+    // into our HloModule. Ideally we would avoid the protocol buffer step;
+    // that is left as an exercise for future work.
+    XlaBuilder builder(name);
+    XlaOp a = Parameter(&builder, 0, instruction->operand(0)->shape(), "a");
+    TF_ASSIGN_OR_RETURN(
+        XlaOp l, BuildQrDecomposition(a,
+                                      /*block_size=*/128,
+                                      /*precision=*/PrecisionConfig::HIGHEST));
+
+    TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build(l));
+
+    TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
+                        xla_computation.GetProgramShape());
+    HloModuleConfig config(program_shape);
+    TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
+                                             xla_computation.proto(), config));
+    HloCloneContext context(module);
+    computation =
+        module->DeepCloneComputation(new_module->entry_computation(), &context);
+  }
+
+  return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
+      instruction->shape(), instruction->operands(), computation));
+}
+
+}  // namespace xla
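Backends that want this default expansion simply schedule the pass, exactly
as the CPU, GPU and interpreter hunks above do:

    pipeline.AddPass<QrExpander>();

A backend that lowers "QrDecomposition" natively instead omits the pass and
handles the CustomCall in its own emitter.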
diff --git a/tensorflow/compiler/xla/service/qr_expander.h b/tensorflow/compiler/xla/service/qr_expander.h
new file mode 100644
index 00000000000..669ace39efb
--- /dev/null
+++ b/tensorflow/compiler/xla/service/qr_expander.h
@@ -0,0 +1,61 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_QR_EXPANDER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_QR_EXPANDER_H_
+
+#include "absl/container/flat_hash_map.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/service/op_expander_pass.h"
+
+namespace xla {
+
+class QrExpander : public OpExpanderPass {
+ public:
+  absl::string_view name() const override { return "qr_expander"; }
+
+ protected:
+  bool InstructionMatchesPattern(HloInstruction* instruction) override;
+
+  StatusOr<HloInstruction*> ExpandInstruction(
+      HloInstruction* instruction) override;
+
+  struct QrResult {
+    // The upper-triangular matrix R, packed together with the lower-triangular
+    // elementary Householder reflectors `vs` below the diagonal.
+    XlaOp a;
+
+    // Representation of the Householder matrices I - tau v v.T
+    XlaOp taus;  // Shape: [..., min(m, n)]
+  };
+
+  virtual StatusOr<QrResult> QrBlock(XlaOp a,
+                                     PrecisionConfig::Precision precision);
+
+  virtual StatusOr<XlaOp> CompactWYRepresentation(
+      PrimitiveType type, absl::Span<const int64> batch_dims, XlaOp vs,
+      XlaOp taus, int64 m, int64 n, PrecisionConfig::Precision precision);
+
+ private:
+  StatusOr<XlaOp> BuildQrDecomposition(XlaOp a, int64 block_size,
+                                       PrecisionConfig::Precision precision);
+
+  // Mapping from op signatures to existing computations.
+  absl::flat_hash_map<string, HloComputation*> computation_cache_;
+};
+
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_QR_EXPANDER_H_
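A pass-level smoke test can drive the expander over the same HLO the client
now emits. A hypothetical sketch using the stock HloTestBase helpers (no such
test file is added by this patch):

    const char* const kQrHlo = R"(
      HloModule qr
      ENTRY main {
        a = f32[4,4] parameter(0)
        ROOT qr = (f32[4,4], f32[4,4]) custom-call(a),
            custom_call_target="QrDecomposition"
      }
    )";
    TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kQrHlo));
    QrExpander pass;
    TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get()));
    EXPECT_TRUE(changed);  // The CustomCall is now a Call to the expansion.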