[XLA] Refactor triangular solve expander to make InvertDiagBlocks overridable.
Pad with an identity matrix instead of zeros to avoid NaN problems.

PiperOrigin-RevId: 328865089
Change-Id: I4881650b35c5ee14236e73c8514d4b5c03e31eee
parent 4de50499e2
commit e47e7057cf
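Why the padding change matters: when the trailing diagonal block is smaller than block_size, the old code padded it out with zeros, which makes the padded block singular, and inverting a singular block produces NaN/Inf values that propagate through the solve. Padding the extra rows and columns with an identity block keeps every padded block invertible. Schematically, for a k x k trailing block A padded up to block_size:

```latex
\underbrace{\begin{pmatrix} A & 0 \\ 0 & 0 \end{pmatrix}}_{\text{zero padding (singular)}}
\qquad\longrightarrow\qquad
\underbrace{\begin{pmatrix} A & 0 \\ 0 & I \end{pmatrix}}_{\text{identity padding (invertible)}}
```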
@@ -89,16 +89,23 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) {
    // The last block might be smaller than the block size,
    // so we will need to pad it
    if (n % block_size != 0) {
      // Pad with zeros
      // Pad with identity matrix.
      auto last_blocks =
          SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n});
      PaddingConfig config = MakeNoPaddingConfig(ndims);
      int64 padding = block_size - n % block_size;
      config.mutable_dimensions(ndims - 1)->set_edge_padding_high(padding);
      config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding);
      last_blocks =
          Pad(last_blocks, Zero(builder, shape.element_type()), config);

      auto eye =
          IdentityMatrix(builder, shape.element_type(), padding, padding);
      config = MakeNoPaddingConfig(ndims);
      config.mutable_dimensions(ndims - 2)->set_edge_padding_low(n %
                                                                 block_size);
      eye = Pad(eye, Zero(builder, shape.element_type()), config);
      last_blocks = ConcatInDim(builder, {last_blocks, eye}, ndims - 1);

      // Add a singleton dimension
      // i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size]
      TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks));
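For illustration only, here is a minimal stand-alone sketch in plain C++ of the padding idea that the XLA ops above implement (the Matrix type, PadLastBlock, and the concrete sizes are hypothetical, not part of this change): zero-pad the k x k trailing block up to block_size and place an identity block in the padded bottom-right corner.

```cpp
#include <cstdio>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Pads the k x k trailing block `last` up to block_size x block_size:
// zeros everywhere in the padded region except an identity block in the
// bottom-right corner, which keeps the padded block invertible.
Matrix PadLastBlock(const Matrix& last, int block_size) {
  const int k = static_cast<int>(last.size());
  Matrix out(block_size, std::vector<double>(block_size, 0.0));  // zero pad
  for (int i = 0; i < k; ++i)
    for (int j = 0; j < k; ++j) out[i][j] = last[i][j];
  for (int i = k; i < block_size; ++i) out[i][i] = 1.0;  // identity corner
  return out;
}

int main() {
  // n = 5, block_size = 4: the trailing block is 1 x 1 and needs 3 rows and
  // columns of padding.
  Matrix last = {{2.0}};
  Matrix padded = PadLastBlock(last, 4);
  for (const auto& row : padded) {
    for (double v : row) std::printf("%5.1f ", v);
    std::printf("\n");
  }
}
```

The printed 4 x 4 block has the original value in the top-left entry and ones on the padded part of the diagonal, so inverting it is well defined.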
@@ -121,134 +128,6 @@ XlaOp DiagonalBlocks(XlaOp a, int64 block_size) {
  });
}

XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower, bool transpose_a,
                           bool conjugate_a,
                           PrecisionConfig::Precision precision) {
  XlaBuilder* builder = diag_blocks.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Input is a batch of square lower triangular square matrices. Its shape is
    // (..., size, size). We resize this to (num_blocks, size, size).
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks));
    int64 block_size = ShapeUtil::GetDimension(shape, -1);
    int64 num_blocks = ShapeUtil::ElementsIn(shape) /
                       tensorflow::MathUtil::IPow(block_size, 2);
    diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size});

    // The input must be triangular because we rely on that when doing
    // multiplications later on
    diag_blocks = Triangle(diag_blocks, /*lower=*/lower);

    // Rescale blocks to be unit triangular, but avoid dividing by
    // zero (which can happen if the last block was padded) otherwise it will
    // introduce nans which will propagate
    auto diags = GetMatrixDiagonal(diag_blocks);
    auto ones = FullLike(diags, 1);
    diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags);
    auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2});

    // We can now use the fact that for an upper triangular matrix
    // [[L11, 0], [L21, L22]], given the inverses L11' and L22', we have
    // L22' = -L22' * L21 * L11'. In our case, L21 is a vector and our blocks
    // have been rescaled to be unit triangular, so L22 = L22' = 1.

    // Initialize the output matrix with -1s on the diagonal. We use -1 instead
    // of 1 because we cannot do matrix-vector multiplies with variable shapes
    // inside of a loop, or do irregularly shaped in-place updates. Hence,
    // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the
    // entire row i.e. we calculate
    // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I])
    // which means [L21 L22 0] <- [-L21 * L11', L22, 0].
    auto identity =
        IdentityMatrix(builder, shape.element_type(), block_size, block_size);
    auto neg_identity = -identity;

    // The first or last diagonal element should be set to 1 instead of -1
    // though, since we never update it
    auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1});
    auto start_index = ConstantR0<int>(builder, (lower) ? 0 : block_size - 1);
    auto output_block =
        DynamicUpdateSlice(neg_identity, pos_one,
                           /*start_indices=*/{start_index, start_index});

    // Broadcast diag([1, -1, -1, ...]) to every block
    XlaOp output = Broadcast(output_block,
                             /*broadcast_sizes=*/{num_blocks});

    // Now we construct a loop that performs matrix-vector multiplications
    // inverting the blocks one row at a time
    std::vector<Shape> tuple_shapes = {
        // The loop iteration counter is a scalar, incremented each iteration.
        ShapeUtil::MakeShape(S32, {}),
        // The output has the shape of A, with one row updated each iteration.
        ShapeUtil::MakeShape(shape.element_type(),
                             {num_blocks, block_size, block_size}),
        // The input is a loop invariant.
        ShapeUtil::MakeShape(shape.element_type(),
                             {num_blocks, block_size, block_size})};
    Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes);

    auto init_i = One(builder, S32);
    auto init = Tuple(builder, {init_i, output, scaled_diag_blocks});

    // Construct the loop condition function.
    std::unique_ptr<XlaBuilder> condb =
        builder->CreateSubBuilder("InvertDiagCond");
    {
      auto i = GetTupleElement(
          Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0);
      Lt(i, ConstantR0<int32>(condb.get(), block_size));
    }
    TF_ASSIGN_OR_RETURN(auto cond, condb->Build());

    // Construct the loop body function.
    std::unique_ptr<XlaBuilder> bodyb =
        builder->CreateSubBuilder("InvertDiagBody");
    {
      auto input_tuple =
          Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple");

      auto i = GetTupleElement(input_tuple, 0);
      auto body_out = GetTupleElement(input_tuple, 1);
      auto body_input = GetTupleElement(input_tuple, 2);

      auto zero = ConstantR0<int32>(bodyb.get(), 0);
      auto j = (lower) ? i : ScalarLike(i, block_size - 1) - i;
      auto input_row =
          DynamicSlice(body_input, {zero, j, zero},
                       /*slice_sizes=*/{num_blocks, 1, block_size});

      // We want -L21 L11^{-1}
      DotDimensionNumbers dnums;
      dnums.add_lhs_batch_dimensions(0);
      dnums.add_rhs_batch_dimensions(0);
      dnums.add_lhs_contracting_dimensions(2);
      dnums.add_rhs_contracting_dimensions(1);
      PrecisionConfig precision_proto;
      precision_proto.add_operand_precision(precision);
      precision_proto.add_operand_precision(precision);
      auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto);

      body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero});

      auto next_i = i + ScalarLike(i, 1);
      Tuple(bodyb.get(), {next_i, body_out, body_input});
    }
    TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());

    // Construct the While loop and return the result,
    // return while_loop(cond_fun, body_fun, init)[1]
    auto invert_while = While(cond, body, init);
    auto inv_diag_blocks = GetTupleElement(invert_while, 1);

    // Undo the scaling
    inv_diag_blocks = Div(inv_diag_blocks, diags,
                          /*broadcast_dimensions=*/{0, 1});

    // Reshape back to original batch major dimensions
    return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions()));
  });
}

XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
                                      bool left_side, bool lower,
                                      bool transpose_a, bool conjugate_a,
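As a reference for the "Rescale blocks to be unit triangular" and "Undo the scaling" steps (unchanged by this commit, only moved into the class in the next hunk), the code relies on a simple factorization of each diagonal block B with diagonal D:

```latex
U = B D^{-1},\qquad D = \operatorname{diag}(b_{11},\dots,b_{kk})
\quad\Longrightarrow\quad
B = U D,\qquad B^{-1} = D^{-1} U^{-1}.
```

Div(diag_blocks, diags, {0, 2}) divides each column by its diagonal entry to form the unit-triangular U, the loop inverts U, and Div(inv_diag_blocks, diags, {0, 1}) divides each row of U^{-1} by the same diagonal to recover B^{-1}. The Select that swaps zero diagonal entries for ones is what previously kept the zero-padded blocks from turning this into a division by zero.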
@@ -357,10 +236,140 @@ XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
  });
}

XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                           bool transpose_a, bool conjugate_a,
                           bool unit_diagonal, int64 block_size,
                           PrecisionConfig::Precision precision) {
}  // namespace

XlaOp TriangularSolveExpander::InvertDiagonalBlocks(
    XlaOp diag_blocks, bool lower_triangular,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = diag_blocks.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Input is a batch of square lower triangular square matrices. Its shape is
    // (..., size, size). We resize this to (num_blocks, size, size).
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks));
    int64 block_size = ShapeUtil::GetDimension(shape, -1);
    int64 num_blocks = ShapeUtil::ElementsIn(shape) /
                       tensorflow::MathUtil::IPow(block_size, 2);
    diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size});

    // The input must be triangular because we rely on that when doing
    // multiplications later on
    diag_blocks = Triangle(diag_blocks, /*lower=*/lower_triangular);

    // Rescale blocks to be unit triangular, but avoid dividing by
    // zero (which can happen if the last block was padded) otherwise it will
    // introduce nans which will propagate
    auto diags = GetMatrixDiagonal(diag_blocks);
    auto ones = FullLike(diags, 1);
    diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags);
    auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2});

    // We can now use the fact that for an upper triangular matrix
    // [[L11, 0], [L21, L22]], given the inverses L11' and L22', we have
    // L22' = -L22' * L21 * L11'. In our case, L21 is a vector and our blocks
    // have been rescaled to be unit triangular, so L22 = L22' = 1.

    // Initialize the output matrix with -1s on the diagonal. We use -1 instead
    // of 1 because we cannot do matrix-vector multiplies with variable shapes
    // inside of a loop, or do irregularly shaped in-place updates. Hence,
    // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the
    // entire row i.e. we calculate
    // [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I])
    // which means [L21 L22 0] <- [-L21 * L11', L22, 0].
    auto identity =
        IdentityMatrix(builder, shape.element_type(), block_size, block_size);
    auto neg_identity = -identity;

    // The first or last diagonal element should be set to 1 instead of -1
    // though, since we never update it
    auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1});
    auto start_index =
        ConstantR0<int>(builder, lower_triangular ? 0 : block_size - 1);
    auto output_block =
        DynamicUpdateSlice(neg_identity, pos_one,
                           /*start_indices=*/{start_index, start_index});

    // Broadcast diag([1, -1, -1, ...]) to every block
    XlaOp output = Broadcast(output_block,
                             /*broadcast_sizes=*/{num_blocks});

    // Now we construct a loop that performs matrix-vector multiplications
    // inverting the blocks one row at a time
    std::vector<Shape> tuple_shapes = {
        // The loop iteration counter is a scalar, incremented each iteration.
        ShapeUtil::MakeShape(S32, {}),
        // The output has the shape of A, with one row updated each iteration.
        ShapeUtil::MakeShape(shape.element_type(),
                             {num_blocks, block_size, block_size}),
        // The input is a loop invariant.
        ShapeUtil::MakeShape(shape.element_type(),
                             {num_blocks, block_size, block_size})};
    Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes);

    auto init_i = One(builder, S32);
    auto init = Tuple(builder, {init_i, output, scaled_diag_blocks});

    // Construct the loop condition function.
    std::unique_ptr<XlaBuilder> condb =
        builder->CreateSubBuilder("InvertDiagCond");
    {
      auto i = GetTupleElement(
          Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0);
      Lt(i, ConstantR0<int32>(condb.get(), block_size));
    }
    TF_ASSIGN_OR_RETURN(auto cond, condb->Build());

    // Construct the loop body function.
    std::unique_ptr<XlaBuilder> bodyb =
        builder->CreateSubBuilder("InvertDiagBody");
    {
      auto input_tuple =
          Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple");

      auto i = GetTupleElement(input_tuple, 0);
      auto body_out = GetTupleElement(input_tuple, 1);
      auto body_input = GetTupleElement(input_tuple, 2);

      auto zero = ConstantR0<int32>(bodyb.get(), 0);
      auto j = lower_triangular ? i : ScalarLike(i, block_size - 1) - i;
      auto input_row =
          DynamicSlice(body_input, {zero, j, zero},
                       /*slice_sizes=*/{num_blocks, 1, block_size});

      // We want -L21 L11^{-1}
      DotDimensionNumbers dnums;
      dnums.add_lhs_batch_dimensions(0);
      dnums.add_rhs_batch_dimensions(0);
      dnums.add_lhs_contracting_dimensions(2);
      dnums.add_rhs_contracting_dimensions(1);
      PrecisionConfig precision_proto;
      precision_proto.add_operand_precision(precision);
      precision_proto.add_operand_precision(precision);
      auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto);

      body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero});

      auto next_i = i + ScalarLike(i, 1);
      Tuple(bodyb.get(), {next_i, body_out, body_input});
    }
    TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());

    // Construct the While loop and return the result,
    // return while_loop(cond_fun, body_fun, init)[1]
    auto invert_while = While(cond, body, init);
    auto inv_diag_blocks = GetTupleElement(invert_while, 1);
    // Undo the scaling
    inv_diag_blocks = Div(inv_diag_blocks, diags,
                          /*broadcast_dimensions=*/{0, 1});

    // Reshape back to original batch major dimensions
    return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions()));
  });
}

XlaOp TriangularSolveExpander::BuildTriangularSolve(
    XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
    bool conjugate_a, bool unit_diagonal, int64 block_size,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
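To make the while loop concrete, here is a stand-alone sketch in plain C++ (illustrative names, not part of this change) of the same row-at-a-time trick applied to a single unit lower-triangular matrix: the output starts as diag(1, -1, ..., -1) and each iteration overwrites one whole row with -row_i(L) * current_output, which is exactly what the DotGeneral / DynamicUpdateSlice pair above does per block.

```cpp
#include <cstdio>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

// Inverts a unit lower-triangular matrix L (ones on the diagonal) one row at
// a time, mirroring the InvertDiagBody loop above.
Matrix InvertUnitLowerTriangular(const Matrix& l) {
  const int k = static_cast<int>(l.size());
  // Start from diag(1, -1, -1, ...): row 0 is already correct and is never
  // updated; the -1 entries make the whole-row update below come out right.
  Matrix inv(k, std::vector<double>(k, 0.0));
  inv[0][0] = 1.0;
  for (int i = 1; i < k; ++i) inv[i][i] = -1.0;
  // Each iteration rewrites one full row:  inv[i] <- -L[i] * inv.
  for (int i = 1; i < k; ++i) {
    std::vector<double> update(k, 0.0);
    for (int col = 0; col < k; ++col) {
      double acc = 0.0;
      for (int m = 0; m < k; ++m) acc += l[i][m] * inv[m][col];
      update[col] = -acc;
    }
    inv[i] = update;
  }
  return inv;
}

int main() {
  Matrix l = {{1, 0, 0}, {2, 1, 0}, {3, 4, 1}};
  Matrix inv = InvertUnitLowerTriangular(l);
  for (const auto& row : inv) {
    for (double v : row) std::printf("%6.1f ", v);
    std::printf("\n");
  }
}
```

For this L the sketch prints [[1, 0, 0], [-2, 1, 0], [5, -4, 1]], which is L^{-1}.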
@@ -422,6 +431,11 @@ XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
      return b;
    }

    // Degenerate case: 1x1 matrices.
    if (ShapeUtil::GetDimension(a_shape, -1) == 1) {
      return unit_diagonal ? b : Div(b, MaybeConjugate(a, conjugate_a));
    }

    // TODO(phawkins): consider pushing triangle masking into
    // InvertDiagonalBlocks.
    if (unit_diagonal) {
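The 1x1 fast path added here is just element-wise division: for a single diagonal entry the triangular system collapses to

```latex
A = (a),\quad A x = b \;\Longrightarrow\; x = b / a,
```

or x = b / conj(a) when conjugate_a is set, and simply x = b when the diagonal is unit.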
@@ -440,8 +454,7 @@ XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
    auto diag_blocks = DiagonalBlocks(a, block_size);

    // We invert these blocks in parallel using batched matrix-vector products
    auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, transpose_a,
                                                conjugate_a, precision);
    auto inv_diag_blocks = InvertDiagonalBlocks(diag_blocks, lower, precision);

    // We now find the solution using GEMMs
    auto x =
@@ -452,8 +465,6 @@ XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
  });
}

}  // namespace

TriangularSolveExpander::TriangularSolveExpander(int64 block_size)
    : block_size_(block_size) {}

@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_TRIANGULAR_SOLVE_EXPANDER_H_

#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/op_expander_pass.h"

namespace xla {
@@ -35,6 +36,14 @@ class TriangularSolveExpander : public OpExpanderPass {
  StatusOr<HloInstruction*> ExpandInstruction(
      HloInstruction* instruction) override;

  virtual XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower_triangular,
                                     PrecisionConfig::Precision precision);

  XlaOp BuildTriangularSolve(XlaOp a, XlaOp b, bool left_side, bool lower,
                             bool transpose_a, bool conjugate_a,
                             bool unit_diagonal, int64 block_size,
                             PrecisionConfig::Precision precision);

 private:
  // Block size for BuildTriangularSolve
  const int64 block_size_;
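What the new virtual method enables, sketched as a hypothetical subclass (MyTriangularSolveExpander and its delegation are illustrative, not part of this commit, and the sketch assumes InvertDiagonalBlocks sits in a protected section alongside ExpandInstruction): a backend can override just the diagonal-block inversion while reusing the rest of BuildTriangularSolve.

```cpp
#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"

namespace xla {

// Hypothetical expander that customizes only the diagonal-block inversion.
class MyTriangularSolveExpander : public TriangularSolveExpander {
 public:
  explicit MyTriangularSolveExpander(int64 block_size)
      : TriangularSolveExpander(block_size) {}

 protected:
  XlaOp InvertDiagonalBlocks(XlaOp diag_blocks, bool lower_triangular,
                             PrecisionConfig::Precision precision) override {
    // A real override could emit a backend-specific computation here; this
    // sketch simply falls back to the default expansion.
    return TriangularSolveExpander::InvertDiagonalBlocks(
        diag_blocks, lower_triangular, precision);
  }
};

}  // namespace xla
```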