[XLA] Refactor triangular solve expander to make InvertDiagBlocks overridable.

Pad with an identity matrix instead of zeros to avoid NaN problems.

PiperOrigin-RevId: 328865089
Change-Id: I4881650b35c5ee14236e73c8514d4b5c03e31eee
Peter Hawkins 2020-08-27 19:41:25 -07:00 committed by TensorFlower Gardener
parent 4de50499e2
commit e47e7057cf
2 changed files with 158 additions and 138 deletions
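To illustrate the NaN issue the commit message refers to, here is a standalone sketch (plain C++, not XLA code; the block sizes and values are made up) of why the trailing diagonal block is padded with an identity matrix: zero padding makes the padded block singular, so inverting it divides by zero on the padded diagonal and yields inf/NaN, whereas identity padding keeps the block triangular and invertible, with the true inverse in its top-left corner.

// Standalone illustration, not part of the change itself.
#include <cmath>
#include <cstdio>

constexpr int kBlock = 4;  // padded block size
constexpr int kN = 2;      // actual size of the last (smaller) block

// Invert a lower-triangular kBlock x kBlock matrix by forward substitution,
// column by column: solve L * x = e_j for each unit vector e_j.
void InvertLowerTriangular(const double L[kBlock][kBlock],
                           double inv[kBlock][kBlock]) {
  for (int j = 0; j < kBlock; ++j) {
    for (int i = 0; i < kBlock; ++i) {
      double rhs = (i == j) ? 1.0 : 0.0;
      for (int k = 0; k < i; ++k) rhs -= L[i][k] * inv[k][j];
      // Under IEEE-754 semantics a zero-padded diagonal entry turns this
      // division into inf/NaN, which then propagates through later products.
      inv[i][j] = rhs / L[i][i];
    }
  }
}

int main() {
  // A 2x2 lower-triangular "last block", smaller than the block size.
  const double last[kN][kN] = {{2.0, 0.0}, {1.0, 4.0}};

  // Zero padding leaves zeros on the padded diagonal entries.
  double zero_padded[kBlock][kBlock] = {};
  // Identity padding keeps the padded block triangular and nonsingular.
  double eye_padded[kBlock][kBlock] = {};
  for (int i = 0; i < kBlock; ++i) eye_padded[i][i] = 1.0;
  for (int i = 0; i < kN; ++i)
    for (int j = 0; j < kN; ++j) {
      zero_padded[i][j] = last[i][j];
      eye_padded[i][j] = last[i][j];
    }

  double inv_zero[kBlock][kBlock], inv_eye[kBlock][kBlock];
  InvertLowerTriangular(zero_padded, inv_zero);
  InvertLowerTriangular(eye_padded, inv_eye);

  // inv_zero contains inf/NaN in the padded rows; inv_eye is finite and its
  // top-left kN x kN corner equals inv(last).
  std::printf("zero-padded inverse [2][2] = %f (non-finite: %d)\n",
              inv_zero[2][2], !std::isfinite(inv_zero[2][2]));
  std::printf("identity-padded inverse corner: %f %f %f\n",
              inv_eye[0][0], inv_eye[1][0], inv_eye[1][1]);
  return 0;
}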


@@ -89,16 +89,23 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) {
// The last block might be smaller than the block size,
// so we will need to pad it
if (n % block_size != 0) {
// Pad with zeros
// Pad with identity matrix.
auto last_blocks =
SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n});
PaddingConfig config = MakeNoPaddingConfig(ndims);
int64 padding = block_size - n % block_size;
config.mutable_dimensions(ndims - 1)->set_edge_padding_high(padding);
config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding);
last_blocks =
Pad(last_blocks, Zero(builder, shape.element_type()), config);
auto eye =
IdentityMatrix(builder, shape.element_type(), padding, padding);
config = MakeNoPaddingConfig(ndims);
config.mutable_dimensions(ndims - 2)->set_edge_padding_low(n %
block_size);
eye = Pad(eye, Zero(builder, shape.element_type()), config);
last_blocks = ConcatInDim(builder, {last_blocks, eye}, ndims - 1);
// Add a singleton dimension
// i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size]
TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks));
@@ -121,134 +128,6 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) {
});
}
XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a,
bool conjugate_a,
PrecisionConfig::Precision precision) {
XlaBuilder* builder = diag_blocks.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
// Input is a batch of square triangular matrices. Its shape is
// (..., size, size). We reshape this to (num_blocks, size, size).
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks));
int64 block_size = ShapeUtil::GetDimension(shape, -1);
int64 num_blocks = ShapeUtil::ElementsIn(shape) /
tensorflow::MathUtil::IPow(block_size, 2);
diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size});
// The input must be triangular because we rely on that when doing
// multiplications later on
diag_blocks = Triangle(diag_blocks, /*lower=*/lower);
// Rescale blocks to be unit triangular, but avoid dividing by
// zero (which can happen if the last block was padded); otherwise it would
// introduce NaNs which would propagate
auto diags = GetMatrixDiagonal(diag_blocks);
auto ones = FullLike(diags, 1);
diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags);
auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2});
// We can now use the fact that for a lower triangular matrix
// [[L11, 0], [L21, L22]], given the inverses L11' and L22', the inverse is
// [[L11', 0], [-L22' * L21 * L11', L22']]. In our case, L21 is a vector and
// our blocks have been rescaled to be unit triangular, so L22 = L22' = 1.
// Initialize the output matrix with -1s on the diagonal. We use -1 instead
// of 1 because we cannot do matrix-vector multiplies with variable shapes
// inside of a loop, or do irregularly shaped in-place updates. Hence,
// L21' <- -L22' * L21 * L11' cannot be done naively. Instead, we update the
// entire row, i.e. we calculate
// [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I])
// which means [L21 L22 0] <- [-L21 * L11', L22, 0].
auto identity =
IdentityMatrix(builder, shape.element_type(), block_size, block_size);
auto neg_identity = -identity;
// The first or last diagonal element should be set to 1 instead of -1
// though, since we never update it
auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1});
auto start_index = ConstantR0<int>(builder, (lower) ? 0 : block_size - 1);
auto output_block =
DynamicUpdateSlice(neg_identity, pos_one,
/*start_indices=*/{start_index, start_index});
// Broadcast diag([1, -1, -1, ...]) to every block
XlaOp output = Broadcast(output_block,
/*broadcast_sizes=*/{num_blocks});
// Now we construct a loop that performs matrix-vector multiplications
// inverting the blocks one row at a time
std::vector<Shape> tuple_shapes = {
// The loop iteration counter is a scalar, incremented each iteration.
ShapeUtil::MakeShape(S32, {}),
// The output has the shape of A, with one row updated each iteration.
ShapeUtil::MakeShape(shape.element_type(),
{num_blocks, block_size, block_size}),
// The input is a loop invariant.
ShapeUtil::MakeShape(shape.element_type(),
{num_blocks, block_size, block_size})};
Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes);
auto init_i = One(builder, S32);
auto init = Tuple(builder, {init_i, output, scaled_diag_blocks});
// Construct the loop condition function.
std::unique_ptr<XlaBuilder> condb =
builder->CreateSubBuilder("InvertDiagCond");
{
auto i = GetTupleElement(
Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0);
Lt(i, ConstantR0<int32>(condb.get(), block_size));
}
TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
// Construct the loop body function.
std::unique_ptr<XlaBuilder> bodyb =
builder->CreateSubBuilder("InvertDiagBody");
{
auto input_tuple =
Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple");
auto i = GetTupleElement(input_tuple, 0);
auto body_out = GetTupleElement(input_tuple, 1);
auto body_input = GetTupleElement(input_tuple, 2);
auto zero = ConstantR0<int32>(bodyb.get(), 0);
auto j = (lower) ? i : ScalarLike(i, block_size - 1) - i;
auto input_row =
DynamicSlice(body_input, {zero, j, zero},
/*slice_sizes=*/{num_blocks, 1, block_size});
// We want -L21 L11^{-1}
DotDimensionNumbers dnums;
dnums.add_lhs_batch_dimensions(0);
dnums.add_rhs_batch_dimensions(0);
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
PrecisionConfig precision_proto;
precision_proto.add_operand_precision(precision);
precision_proto.add_operand_precision(precision);
auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto);
body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero});
auto next_i = i + ScalarLike(i, 1);
Tuple(bodyb.get(), {next_i, body_out, body_input});
}
TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
// Construct the While loop and return the result,
// return while_loop(cond_fun, body_fun, init)[1]
auto invert_while = While(cond, body, init);
auto inv_diag_blocks = GetTupleElement(invert_while, 1);
// Undo the scaling
inv_diag_blocks = Div(inv_diag_blocks, diags,
/*broadcast_dimensions=*/{0, 1});
// Reshape back to original batch major dimensions
return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions()));
});
}
XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
bool left_side, bool lower,
bool transpose_a, bool conjugate_a,
@@ -357,10 +236,140 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
});
}
XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
bool transpose_a, bool conjugate_a,
bool unit_diagonal, int64 block_size,
PrecisionConfig::Precision precision) {
} // namespace
XlaOp TriangularSolveExpander::InvertDiagonalBlocks(
XlaOp diag_blocks, bool lower_triangular,
PrecisionConfig::Precision precision) {
XlaBuilder* builder = diag_blocks.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
// Input is a batch of square triangular matrices. Its shape is
// (..., size, size). We reshape this to (num_blocks, size, size).
TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks));
int64 block_size = ShapeUtil::GetDimension(shape, -1);
int64 num_blocks = ShapeUtil::ElementsIn(shape) /
tensorflow::MathUtil::IPow(block_size, 2);
diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size});
// The input must be triangular because we rely on that when doing
// multiplications later on
diag_blocks = Triangle(diag_blocks, /*lower=*/lower_triangular);
// Rescale blocks to be unit triangular, but avoid dividing by
// zero (which can happen if the last block was padded); otherwise it would
// introduce NaNs which would propagate
auto diags = GetMatrixDiagonal(diag_blocks);
auto ones = FullLike(diags, 1);
diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags);
auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2});
// We can now use the fact that for a lower triangular matrix
// [[L11, 0], [L21, L22]], given the inverses L11' and L22', the inverse is
// [[L11', 0], [-L22' * L21 * L11', L22']]. In our case, L21 is a vector and
// our blocks have been rescaled to be unit triangular, so L22 = L22' = 1.
// Initialize the output matrix with -1s on the diagonal. We use -1 instead
// of 1 because we cannot do matrix-vector multiplies with variable shapes
// inside of a loop, or do irregularly shaped in-place updates. Hence,
// L21' <- -L22' * L21 * L11' cannot be done naively. Instead, we update the
// entire row, i.e. we calculate
// [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I])
// which means [L21 L22 0] <- [-L21 * L11', L22, 0].
auto identity =
IdentityMatrix(builder, shape.element_type(), block_size, block_size);
auto neg_identity = -identity;
// The first or last diagonal element should be set to 1 instead of -1
// though, since we never update it
auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1});
auto start_index =
ConstantR0<int>(builder, lower_triangular ? 0 : block_size - 1);
auto output_block =
DynamicUpdateSlice(neg_identity, pos_one,
/*start_indices=*/{start_index, start_index});
// Broadcast diag([1, -1, -1, ...]) to every block
XlaOp output = Broadcast(output_block,
/*broadcast_sizes=*/{num_blocks});
// Now we construct a loop that performs matrix-vector multiplications
// inverting the blocks one row at a time
std::vector<Shape> tuple_shapes = {
// The loop iteration counter is a scalar, incremented each iteration.
ShapeUtil::MakeShape(S32, {}),
// The output has the shape of A, with one row updated each iteration.
ShapeUtil::MakeShape(shape.element_type(),
{num_blocks, block_size, block_size}),
// The input is a loop invariant.
ShapeUtil::MakeShape(shape.element_type(),
{num_blocks, block_size, block_size})};
Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes);
auto init_i = One(builder, S32);
auto init = Tuple(builder, {init_i, output, scaled_diag_blocks});
// Construct the loop condition function.
std::unique_ptr<XlaBuilder> condb =
builder->CreateSubBuilder("InvertDiagCond");
{
auto i = GetTupleElement(
Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0);
Lt(i, ConstantR0<int32>(condb.get(), block_size));
}
TF_ASSIGN_OR_RETURN(auto cond, condb->Build());
// Construct the loop body function.
std::unique_ptr<XlaBuilder> bodyb =
builder->CreateSubBuilder("InvertDiagBody");
{
auto input_tuple =
Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple");
auto i = GetTupleElement(input_tuple, 0);
auto body_out = GetTupleElement(input_tuple, 1);
auto body_input = GetTupleElement(input_tuple, 2);
auto zero = ConstantR0<int32>(bodyb.get(), 0);
auto j = lower_triangular ? i : ScalarLike(i, block_size - 1) - i;
auto input_row =
DynamicSlice(body_input, {zero, j, zero},
/*slice_sizes=*/{num_blocks, 1, block_size});
// We want -L21 L11^{-1}
DotDimensionNumbers dnums;
dnums.add_lhs_batch_dimensions(0);
dnums.add_rhs_batch_dimensions(0);
dnums.add_lhs_contracting_dimensions(2);
dnums.add_rhs_contracting_dimensions(1);
PrecisionConfig precision_proto;
precision_proto.add_operand_precision(precision);
precision_proto.add_operand_precision(precision);
auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto);
body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero});
auto next_i = i + ScalarLike(i, 1);
Tuple(bodyb.get(), {next_i, body_out, body_input});
}
TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());
// Construct the While loop and return the result,
// return while_loop(cond_fun, body_fun, init)[1]
auto invert_while = While(cond, body, init);
auto inv_diag_blocks = GetTupleElement(invert_while, 1);
// Undo the scaling
inv_diag_blocks = Div(inv_diag_blocks, diags,
/*broadcast_dimensions=*/{0, 1});
// Reshape back to original batch major dimensions
return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions()));
});
}
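As a standalone illustration of the row-at-a-time update described in the comments above (plain C++, not XLA code; the 3x3 block values are made up), the loop below seeds the output with +1 in the corner and -1 on the rest of the diagonal, then overwrites one full row per iteration with -row_i(L) * output, reproducing the inverse of a unit lower-triangular block.

// Standalone illustration, not part of the change itself.
#include <cstdio>

int main() {
  constexpr int kN = 3;
  // A unit lower-triangular block (already rescaled to have 1s on the diagonal).
  const double L[kN][kN] = {{1, 0, 0}, {2, 1, 0}, {3, 4, 1}};

  // Seed the output with diag(+1, -1, -1, ...), matching the expander's
  // initialization for the lower-triangular case.
  double out[kN][kN] = {};
  out[0][0] = 1.0;
  for (int i = 1; i < kN; ++i) out[i][i] = -1.0;

  // Row-at-a-time loop: out[i] <- -(L[i] * out). Because out[i][i] is -1 when
  // row i is processed, the row's own diagonal entry comes out as +1, which is
  // why the seed uses -1 rather than +1.
  for (int i = 1; i < kN; ++i) {
    double row[kN] = {};
    for (int j = 0; j < kN; ++j)
      for (int k = 0; k < kN; ++k) row[j] -= L[i][k] * out[k][j];
    for (int j = 0; j < kN; ++j) out[i][j] = row[j];
  }

  // out now equals inv(L) = [[1, 0, 0], [-2, 1, 0], [5, -4, 1]].
  for (int i = 0; i < kN; ++i)
    std::printf("%6.1f %6.1f %6.1f\n", out[i][0], out[i][1], out[i][2]);
  return 0;
}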
XlaOp TriangularSolveExpander::BuildTriangularSolve(
XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
bool conjugate_a, bool unit_diagonal, int64 block_size,
PrecisionConfig::Precision precision) {
XlaBuilder* builder = a.builder();
return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
@@ -422,6 +431,11 @@ XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
return b;
}
// Degenerate case: 1x1 matrices.
if (ShapeUtil::GetDimension(a_shape, -1) == 1) {
return unit_diagonal ? b : Div(b, MaybeConjugate(a, conjugate_a));
}
// TODO(phawkins): consider pushing triangle masking into
// InvertDiagonalBlocks.
if (unit_diagonal) {
@@ -440,8 +454,7 @@ XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
auto diag_blocks = DiagonalBlocks(a, block_size);
// We invert these blocks in parallel using batched matrix-vector products
auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a,
conjugate_a, precision);
auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, precision);
// We now find the solution using GEMMs
auto x =
@@ -452,8 +465,6 @@ XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
});
}
} // namespace
TriangularSolveExpander::TriangularSolveExpander(int64 block_size)
: block_size_(block_size) {}


@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_
#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/op_expander_pass.h"
namespace xla {
@@ -35,6 +36,14 @@ class TriangularSolveExpander : public OpExpanderPass {
StatusOr<HloInstruction*> ExpandInstruction(
HloInstruction* instruction) override;
virtual XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower_triangular,
PrecisionConfig::Precision precision);
XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
bool transpose_a, bool conjugate_a,
bool unit_diagonal, int64 block_size,
PrecisionConfig::Precision precision);
private:
// Block size for BuildTriangularSolve
const int64 block_size_;
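Since the point of the refactor is that InvertDiagonalBlocks is now a virtual method on TriangularSolveExpander, a hypothetical subclass (illustrative only, not part of this change; the class name and constructor argument below are made up) could plug in its own block-inversion strategy while reusing the rest of BuildTriangularSolve:

// Hypothetical example of overriding the virtual hook added by this change.
#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"

namespace xla {

class CustomTriangularSolveExpander : public TriangularSolveExpander {
 public:
  explicit CustomTriangularSolveExpander(int64 block_size)
      : TriangularSolveExpander(block_size) {}

 protected:
  XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower_triangular,
                             PrecisionConfig::Precision precision) override {
    // Placeholder: a backend could emit a custom call or a different
    // inversion algorithm here; this sketch just defers to the default
    // batched-loop implementation.
    return TriangularSolveExpander::InvertDiagonalBlocks(
        diag_blocks, lower_triangular, precision);
  }
};

}  // namespace xla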