STT-tensorflow/tensorflow/compiler/xla/service/triangular_solve_expander.cc
Peter Hawkins 7a050b85d8 [XLA] Fix bug in triangular solve expander.
For batched matrices whose size was not evenly divided by the number of blocks, the HLO being produced was not shape-correct. Improves test coverage to catch the problem.

Will fix https://github.com/google/jax/issues/4773 when incorporated into a jaxlib.

PiperOrigin-RevId: 340678796
Change-Id: Ifc230aa0bae9aec4902556c5e7829410c60c587f
2020-11-04 10:32:06 -08:00

617 lines
25 KiB
C++

/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/triangular_solve_expander.h"

#include <algorithm>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/lib/matrix.h"
#include "tensorflow/compiler/xla/client/lib/slicing.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/math/math_util.h"
namespace xla {
namespace {
// Get the diagonal blocks of the coefficient matrix.
//
// `a` is treated as a batched square matrix [..., n, n]: the last two
// dimensions are the matrix, all leading dimensions are batch dimensions.
// The result has shape [..., num_blocks, block_size, block_size] with
// num_blocks = ceil(n / block_size); block i is the block_size x block_size
// submatrix of `a` whose top-left corner is (i * block_size, i * block_size).
// When n is not a multiple of block_size, the trailing partial block is padded
// to full size so that it becomes [[P, 0], [0, I]] (P = the partial block), which
// keeps it triangular and gives the padding a unit diagonal.
// NOTE(review): with n == 0 none of the branches below assigns diag_blocks, so
// a default-constructed XlaOp is returned; callers appear to avoid n == 0 —
// confirm.
XlaOp DiagonalBlocks(XlaOp a, int64 block_size) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(a));
    int ndims = shape.rank();
    int64 n = ShapeUtil::GetDimension(shape, -1);
    // Number of *full* blocks; a trailing partial block is handled below.
    int64 num_blocks = n / block_size;
    // The leading (batch) dimensions of `a`.
    absl::Span<int64 const> batch_dims = absl::MakeConstSpan(
        shape.dimensions().begin(), shape.dimensions().begin() + (ndims - 2));
    XlaOp diag_blocks;

    // If the coefficient matrix is exactly the block size, we just add a
    // singleton dimension i.e. [..., n, n] -> [..., 1, n, n]
    if (n == block_size) {
      // Broadcast prepends a size-1 dimension; the permutation
      // [1, ..., ndims - 2, 0, ndims - 1, ndims] moves it just before the two
      // matrix dimensions.
      std::vector<int64> permutation(ndims);
      std::iota(permutation.begin(), permutation.end(), 1);
      permutation.insert(permutation.end() - 2, 0);
      return Transpose(Broadcast(a, /*broadcast_sizes=*/{1}), permutation);
    }

    // We can grab entire blocks using gather
    if (n > block_size) {
      // Construct the starting indices of the diagonal blocks:
      // row i is (i * block_size, i * block_size); shape [num_blocks, 2].
      auto start_indices =
          Transpose(Broadcast(Mul(Iota(builder, S32, num_blocks),
                                  ConstantR0<int32>(builder, block_size)),
                              /*broadcast_sizes=*/{2}),
                    /*permutation=*/{1, 0});
      // Prepend ndims - 2 zero columns so each index also addresses the batch
      // dimensions (always starting at 0): shape [num_blocks, ndims].
      PaddingConfig padding_config =
          MakeEdgePaddingConfig({{0, 0}, {ndims - 2, 0}});
      start_indices =
          Pad(start_indices, ConstantR0<int32>(builder, 0), padding_config);

      // Gather the diagonal blocks. Each slice covers every batch dimension in
      // full plus one block_size x block_size block. The order of the
      // add_offset_dims / add_start_index_map calls below is significant.
      // The index batch dimension (num_blocks) lands at output dimension
      // ndims - 2, the one position not listed in offset_dims, giving
      // [..., num_blocks, block_size, block_size].
      std::vector<int64> slice_sizes(ndims);
      GatherDimensionNumbers dim_numbers;
      for (int i = 0; i < ndims - 2; ++i) {
        dim_numbers.add_offset_dims(i);
        dim_numbers.add_start_index_map(i);
        slice_sizes[i] = ShapeUtil::GetDimension(shape, i);
      }
      slice_sizes[ndims - 2] = slice_sizes[ndims - 1] = block_size;
      dim_numbers.add_offset_dims(ndims - 1);
      dim_numbers.add_offset_dims(ndims);
      dim_numbers.add_start_index_map(ndims - 2);
      dim_numbers.add_start_index_map(ndims - 1);
      dim_numbers.set_index_vector_dim(1);
      diag_blocks = Gather(a, start_indices, dim_numbers, slice_sizes);
    }

    // The last block might be smaller than the block size,
    // so we will need to pad it
    if (n % block_size != 0) {
      // Pad with identity matrix.
      // The trailing partial block: [..., n % block_size, n % block_size].
      auto last_blocks =
          SliceInMinorDims(a, {n - n % block_size, n - n % block_size}, {n, n});
      // Add `padding` zero rows below it: [..., block_size, n % block_size].
      PaddingConfig config = MakeNoPaddingConfig(ndims);
      int64 padding = block_size - n % block_size;
      config.mutable_dimensions(ndims - 2)->set_edge_padding_high(padding);
      last_blocks =
          Pad(last_blocks, Zero(builder, shape.element_type()), config);

      // Build the right-hand columns [0; I] of shape [block_size, padding]
      // and broadcast them over the batch dimensions.
      auto eye =
          IdentityMatrix(builder, shape.element_type(), padding, padding);
      config = MakeNoPaddingConfig(2);
      config.mutable_dimensions(0)->set_edge_padding_low(n % block_size);
      eye = Pad(eye, Zero(builder, shape.element_type()), config);
      eye = Broadcast(eye, batch_dims);
      // [..., block_size, block_size]
      last_blocks = ConcatInDim(builder, {last_blocks, eye}, ndims - 1);

      // Add a singleton dimension
      // i.e. [..., block_size, block_size] -> [..., 1, block_size, block_size]
      TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(last_blocks));
      auto shape_dims = AsInt64Slice(blocks_shape.dimensions());
      auto last_blocks_dims = std::vector<int64>(ndims);
      std::copy(shape_dims.begin(), shape_dims.end(), last_blocks_dims.begin());
      last_blocks_dims.insert(last_blocks_dims.end() - 2, 1);
      last_blocks = Reshape(last_blocks, last_blocks_dims);

      // Concatenate with the other blocks if necessary
      if (n > block_size) {
        diag_blocks =
            ConcatInDim(builder, {diag_blocks, last_blocks}, ndims - 2);
      } else {
        diag_blocks = last_blocks;
      }
    }

    return diag_blocks;
  });
}
// Solves op(a) @ x = b (left_side) or x @ op(a) = b (!left_side) by blocked
// substitution, given the inverted diagonal blocks of `a`.
//
// op(a) is a, transpose(a) or adjoint(a) according to `transpose_a` and
// `conjugate_a`. `inv_diag_blocks` is expected to have shape
// [..., num_blocks, block_size, block_size] (as built by DiagonalBlocks and
// InvertDiagonalBlocks). When n is not a multiple of block_size, only the
// leading n % block_size rows/columns of the last block are used. The
// solution is assembled one block row (left_side) or block column
// (!left_side) at a time.
XlaOp SolveWithInvertedDiagonalBlocks(XlaOp a, XlaOp b, XlaOp inv_diag_blocks,
                                      bool left_side, bool lower,
                                      bool transpose_a, bool conjugate_a,
                                      PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape blocks_shape, builder->GetShape(inv_diag_blocks));
    TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
    int64 block_size = ShapeUtil::GetDimension(blocks_shape, -1);

    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    int64 ndims = a_shape.rank();
    int64 n = ShapeUtil::GetDimension(a_shape, -1);
    // Number of diagonal blocks, counting a trailing partial block.
    int64 num_blocks = n / block_size + (n % block_size != 0);
    // Extent of b along the dimension that is not contracted with a.
    int64 m_dim = (left_side) ? -1 : -2;
    int64 m = ShapeUtil::GetDimension(b_shape, m_dim);

    int bdims = b_shape.rank();
    // Dimension of x along which newly solved blocks are concatenated.
    int64 block_dim = (left_side) ? bdims - 2 : bdims - 1;

    // Decide whether we go from the last block to the first (backward) or
    // from the first to the last. This depends only on the orientation of the
    // problem, not on the iteration, so it is computed once.
    const bool backward = left_side ^ lower ^ transpose_a;

    // The blocks of the solution solved so far, kept in index order.
    XlaOp x;
    // This loop is unrolled for performance reasons, but it could be expressed
    // rolled as well since the matrices are of the same size each iteration.
    for (int64 i = 0; i < num_blocks; ++i) {
      // High-level intuition (left side, lower, no transpose): B[i] = L[i] @ X.
      // Since L is lower triangular, B[i] = L[i, :i + 1] @ X[:i + 1]. We can
      // split this into B[i] = L[i, :i] @ X[:i] + L[i, i] @ X[i], which can be
      // solved for X[i] as X[i] = inv(L[i, i]) @ (B[i] - L[i, :i] @ X[:i]).
      // The other cases are the transposed and/or reversed analogues.
      int64 j = backward ? num_blocks - 1 - i : i;

      // Get the size of the inverse block (the last one might be smaller).
      int64 block = (n % block_size != 0 && j + 1 == num_blocks)
                        ? n % block_size
                        : block_size;
      auto inv_block =
          MaybeConjugate(Collapse(SliceInMinorDims(inv_diag_blocks, {j, 0, 0},
                                                   {j + 1, block, block}),
                                  /*dimensions=*/{ndims - 2, ndims - 1}),
                         conjugate_a);

      // Get the corresponding row (or column) block of B.
      int64 k = std::min((j + 1) * block_size, n);
      std::vector<int64> start = {j * block_size, 0};
      std::vector<int64> end = {k, m};
      if (!left_side) {
        std::swap(start[0], start[1]);
        std::swap(end[0], end[1]);
      }
      auto b_row = SliceInMinorDims(b, start, end);

      XlaOp remainder;
      if (i == 0) {
        remainder = b_row;
      } else {
        // Multiply only against the already-solved part of X (the rest of X
        // would be zero), i.e. compute L[i, :i] @ X[:i].
        if (backward) {
          start = {j * block_size,
                   std::max(int64{0}, (num_blocks - i) * block_size)};
          end = {k, n};
        } else {
          start = {j * block_size, 0};
          end = {k, std::min(i * block_size, n)};
        }
        if (!left_side ^ transpose_a) {
          std::swap(start[0], start[1]);
          std::swap(end[0], end[1]);
        }
        auto a_row =
            MaybeConjugate(SliceInMinorDims(a, start, end), conjugate_a);
        if (left_side) {
          remainder = b_row - BatchDot(a_row, transpose_a, x, false, precision);
        } else {
          remainder = b_row - BatchDot(x, false, a_row, transpose_a, precision);
        }
      }

      XlaOp x_update;
      if (left_side) {
        x_update =
            BatchDot(inv_block, transpose_a, remainder, false, precision);
      } else {
        x_update =
            BatchDot(remainder, false, inv_block, transpose_a, precision);
      }

      if (i == 0) {
        x = x_update;
      } else if (backward) {
        x = ConcatInDim(builder, {x_update, x}, block_dim);
      } else {
        x = ConcatInDim(builder, {x, x_update}, block_dim);
      }
    }

    return x;
  });
}
} // namespace
// Inverts a batch of triangular blocks.
//
// `diag_blocks` has shape [..., block_size, block_size]. All leading
// dimensions are flattened into one batch of blocks. Each block is masked to
// the triangle selected by `lower_triangular` and inverted, and the result is
// reshaped back to the input shape. Zero diagonal entries are replaced by one
// while rescaling so that they do not introduce NaNs. Zero-sized blocks are
// returned unchanged.
XlaOp TriangularSolveExpander::InvertDiagonalBlocks(
    XlaOp diag_blocks, bool lower_triangular,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = diag_blocks.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    // Input is a batch of square triangular matrices. Its shape is
    // (..., size, size). We resize this to (num_blocks, size, size).
    TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(diag_blocks));
    int64 block_size = ShapeUtil::GetDimension(shape, -1);
    // Empty blocks have nothing to invert. Returning here also avoids the
    // integer division by block_size^2 == 0 below.
    if (block_size == 0) {
      return diag_blocks;
    }
    int64 num_blocks = ShapeUtil::ElementsIn(shape) /
                       tensorflow::MathUtil::IPow(block_size, 2);
    diag_blocks = Reshape(diag_blocks, {num_blocks, block_size, block_size});

    // The input must be triangular because we rely on that when doing
    // multiplications later on.
    diag_blocks = Triangle(diag_blocks, /*lower=*/lower_triangular);

    // Rescale blocks to be unit triangular by dividing each column by its
    // diagonal entry. Zero diagonal entries are replaced by one first;
    // otherwise the division would introduce NaNs that propagate.
    auto diags = GetMatrixDiagonal(diag_blocks);
    auto ones = FullLike(diags, 1);
    diags = Select(Eq(diags, Zero(builder, shape.element_type())), ones, diags);
    auto scaled_diag_blocks = Div(diag_blocks, diags, {0, 2});

    // For a lower triangular matrix [[L11, 0], [L21, L22]] with known
    // inverses L11' and L22', the off-diagonal block of the inverse is
    // L21' = -L22' * L21 * L11'. Working one row at a time, L21 is a row
    // vector. Because the blocks have been rescaled to be unit triangular,
    // L22 = L22' = 1. (The upper triangular case is the mirror image.)
    //
    // Initialize the output matrix with -1s on the diagonal. We use -1 instead
    // of 1 because we cannot do matrix-vector multiplies with variable shapes
    // inside of a loop, or do irregularly shaped in-place updates. Hence,
    // L21 <- -L22 * L21 * L11 cannot be done naively. Instead, we update the
    // entire row, i.e. we calculate
    //   [L21 L22 0] <- -[L21 L22 0] @ diag_blocks([L11', -I, -I])
    // which means [L21 L22 0] <- [-L21 * L11', L22, 0].
    auto identity =
        IdentityMatrix(builder, shape.element_type(), block_size, block_size);
    auto neg_identity = -identity;

    // The first (lower) or last (upper) diagonal element is never updated by
    // the loop, so it is set to 1 instead of -1.
    auto pos_one = Reshape(One(builder, shape.element_type()), {1, 1});
    auto start_index =
        ConstantR0<int32>(builder, lower_triangular ? 0 : block_size - 1);
    auto output_block =
        DynamicUpdateSlice(neg_identity, pos_one,
                           /*start_indices=*/{start_index, start_index});

    // Broadcast diag([1, -1, -1, ...]) to every block
    XlaOp output = Broadcast(output_block,
                             /*broadcast_sizes=*/{num_blocks});

    // Now we construct a loop that performs matrix-vector multiplications
    // inverting the blocks one row at a time.
    std::vector<Shape> tuple_shapes = {
        // The loop iteration counter is a scalar, incremented each iteration.
        ShapeUtil::MakeShape(S32, {}),
        // The output has the shape of A, with one row updated each iteration.
        ShapeUtil::MakeShape(shape.element_type(),
                             {num_blocks, block_size, block_size}),
        // The input is a loop invariant.
        ShapeUtil::MakeShape(shape.element_type(),
                             {num_blocks, block_size, block_size})};
    Shape tuple_shape = ShapeUtil::MakeTupleShape(tuple_shapes);

    // The counter starts at 1 and the loop runs while i < block_size. Rows
    // 1..block_size-1 (lower) or block_size-2..0 (upper) are updated; the
    // remaining row was already initialized correctly above.
    auto init_i = One(builder, S32);
    auto init = Tuple(builder, {init_i, output, scaled_diag_blocks});

    // Construct the loop condition function.
    std::unique_ptr<XlaBuilder> condb =
        builder->CreateSubBuilder("InvertDiagCond");
    {
      auto i = GetTupleElement(
          Parameter(condb.get(), 0, tuple_shape, "InvertDiagCondTuple"), 0);
      Lt(i, ConstantR0<int32>(condb.get(), block_size));
    }
    TF_ASSIGN_OR_RETURN(auto cond, condb->Build());

    // Construct the loop body function.
    std::unique_ptr<XlaBuilder> bodyb =
        builder->CreateSubBuilder("InvertDiagBody");
    {
      auto input_tuple =
          Parameter(bodyb.get(), 0, tuple_shape, "InvertDiagBodyTuple");

      auto i = GetTupleElement(input_tuple, 0);
      auto body_out = GetTupleElement(input_tuple, 1);
      auto body_input = GetTupleElement(input_tuple, 2);

      auto zero = ConstantR0<int32>(bodyb.get(), 0);
      // Row processed in this iteration.
      auto j = lower_triangular ? i : ScalarLike(i, block_size - 1) - i;
      auto input_row =
          DynamicSlice(body_input, {zero, j, zero},
                       /*slice_sizes=*/{num_blocks, 1, block_size});

      // We want -L21 L11^{-1}
      DotDimensionNumbers dnums;
      dnums.add_lhs_batch_dimensions(0);
      dnums.add_rhs_batch_dimensions(0);
      dnums.add_lhs_contracting_dimensions(2);
      dnums.add_rhs_contracting_dimensions(1);
      PrecisionConfig precision_proto;
      precision_proto.add_operand_precision(precision);
      precision_proto.add_operand_precision(precision);
      auto update = -DotGeneral(input_row, body_out, dnums, &precision_proto);

      body_out = DynamicUpdateSlice(body_out, update, {zero, j, zero});

      auto next_i = i + ScalarLike(i, 1);
      Tuple(bodyb.get(), {next_i, body_out, body_input});
    }
    TF_ASSIGN_OR_RETURN(auto body, bodyb->Build());

    // Construct the While loop and return the result,
    // return while_loop(cond_fun, body_fun, init)[1]
    auto invert_while = While(cond, body, init);
    auto inv_diag_blocks = GetTupleElement(invert_while, 1);

    // Undo the scaling: row r of each inverse is divided by the r-th diagonal
    // entry (the inverse of a column scaling is a row scaling).
    inv_diag_blocks = Div(inv_diag_blocks, diags,
                          /*broadcast_dimensions=*/{0, 1});

    // Reshape back to original batch major dimensions
    return Reshape(inv_diag_blocks, AsInt64Slice(shape.dimensions()));
  });
}
// Solves the triangular system by blocked substitution.
//
// The diagonal blocks of `a` are extracted and inverted in parallel. The
// solution is then assembled with matrix-matrix products against those
// inverses. Before blocking, `a` is reduced to the triangle selected by
// `lower`. With `unit_diagonal`, its diagonal is also replaced by ones. The
// block size is min(block_size_, n), where n is the matrix dimension of `a`.
XlaOp TriangularSolveExpander::SolveByInvertingDiagonalBlocks(
    XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
    bool conjugate_a, bool unit_diagonal,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    const int64 rank = a_shape.rank();
    const int64 dim = ShapeUtil::GetDimension(a_shape, -1);

    // TODO(phawkins): consider pushing triangle masking into
    // InvertDiagonalBlocks.
    XlaOp tri;
    if (!unit_diagonal) {
      // Zero out the triangle of `a` that the solve ignores (diagonal kept).
      tri = Triangle(a, lower);
    } else {
      // Keep only the strictly lower (resp. strictly upper) part of `a`, then
      // add the identity so the diagonal is all ones regardless of what `a`
      // stores there.
      XlaOp zeros = ZerosLike(a);
      XlaOp strict = lower ? Select(TriangleMask(a, -1), a, zeros)
                           : Select(TriangleMask(a, 0), zeros, a);
      tri = xla::Add(strict,
                     IdentityMatrix(builder, a_shape.element_type(), dim, dim),
                     /*broadcast_dimensions=*/{rank - 2, rank - 1});
    }

    // Never use blocks larger than the matrix itself.
    const int64 block_size = std::min(block_size_, dim);

    // Invert every diagonal block in parallel with batched matrix-vector
    // products, then finish the solve with GEMMs against those inverses.
    XlaOp inverted_blocks = InvertDiagonalBlocks(
        DiagonalBlocks(tri, block_size), lower, precision);
    return SolveWithInvertedDiagonalBlocks(tri, b, inverted_blocks, left_side,
                                           lower, transpose_a, conjugate_a,
                                           precision);
  });
}
// Solves the triangular system by plain (unblocked) substitution, one row
// (left_side) or one column (!left_side) of b at a time.
//
// Reference algorithm for the left-side, lower-triangular case:
//
// def trsm_left_lower_leftlooking(a, b):
//   n = a.shape[-1]
//   assert a.shape == (n, n)
//   b = b.copy()
//   for j in range(n):
//     b[j, :] = (b[j, :] - np.dot(a[j, :j], b[:j, :])) / a[j, j]
//   return b
XlaOp TriangularSolveExpander::SolveDirectly(
    XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
    bool conjugate_a, bool unit_diagonal,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
    const int64 rows = ShapeUtil::GetDimension(b_shape, -2);
    const int64 cols = ShapeUtil::GetDimension(b_shape, -1);
    const int64 size = ShapeUtil::GetDimension(a_shape, -1);

    // For an adjoint solve, conjugate the coefficients once up front.
    XlaOp coeffs = MaybeConjugate(a, conjugate_a);

    // True when unknowns are eliminated in increasing index order, so the
    // already-solved entries all precede the current one; false for the
    // reverse order.
    const bool forward = transpose_a ^ lower ^ !left_side;
    // Whether slices of `coeffs` must be taken with rows/columns exchanged.
    const bool swap_a_dims = transpose_a ^ !left_side;

    XlaOp x = b;
    for (int64 step = 0; step < size; ++step) {
      const int64 j = forward ? step : size - 1 - step;

      // The row (left side) or column (right side) of x being solved.
      std::vector<int64> x_start;
      std::vector<int64> x_end;
      if (!left_side) {
        x_start = {0, j};
        x_end = {rows, j + 1};
      } else {
        x_start = {j, 0};
        x_end = {j + 1, cols};
      }
      XlaOp target = SliceInMinorDims(x, x_start, x_end);

      // The part of line j of op(a) that multiplies already-solved unknowns.
      std::vector<int64> a_lo = {j, forward ? 0 : j + 1};
      std::vector<int64> a_hi = {j + 1, forward ? j : size};
      if (swap_a_dims) {
        std::swap(a_lo[0], a_lo[1]);
        std::swap(a_hi[0], a_hi[1]);
      }
      XlaOp a_part = SliceInMinorDims(coeffs, a_lo, a_hi);

      // Subtract the contribution of the already-solved unknowns.
      const int64 solved_lo = forward ? 0 : j + 1;
      if (!left_side) {
        XlaOp solved =
            SliceInMinorDims(x, {0, solved_lo}, {rows, forward ? j : cols});
        target = target - BatchDot(solved, /*transpose_x=*/false, a_part,
                                   /*transpose_y=*/transpose_a, precision);
      } else {
        XlaOp solved =
            SliceInMinorDims(x, {solved_lo, 0}, {forward ? j : rows, cols});
        target = target - BatchDot(a_part, /*transpose_x=*/transpose_a, solved,
                                   /*transpose_y=*/false, precision);
      }

      // Divide by the diagonal coefficient unless it is implicitly one.
      if (!unit_diagonal) {
        target = target / SliceInMinorDims(coeffs, {j, j}, {j + 1, j + 1});
      }
      x = UpdateSliceInMinorDims(x, target, x_start);
    }
    return x;
  });
}
// Validates the operand shapes and emits the HLO for a (batched) triangular
// solve.
//
// Solves op(a) @ x = b when `left_side` is true and x @ op(a) = b otherwise.
// op(a) is a, a^T or a^H according to `transpose_a`/`conjugate_a`, and only
// the `lower` (or upper) triangle of `a` is used. InvalidArgument is returned
// when the ranks, batch dimensions or matrix dimensions are incompatible.
// Empty and 1x1 problems are handled directly. Otherwise either unblocked
// substitution or blocked substitution with inverted diagonal blocks is
// chosen by a size heuristic.
//
// NOTE(review): the `block_size` argument is not read here. `block_size_`
// drives both the heuristic below and the blocking in
// SolveByInvertingDiagonalBlocks. The parameter is kept for interface
// compatibility.
XlaOp TriangularSolveExpander::BuildTriangularSolve(
    XlaOp a, XlaOp b, bool left_side, bool lower, bool transpose_a,
    bool conjugate_a, bool unit_diagonal, int64 block_size,
    PrecisionConfig::Precision precision) {
  XlaBuilder* builder = a.builder();
  return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(Shape a_shape, builder->GetShape(a));
    TF_ASSIGN_OR_RETURN(Shape b_shape, builder->GetShape(b));
    if (a_shape.rank() != b_shape.rank()) {
      return InvalidArgument(
          "Arguments to TriangularSolve have shapes with different ranks: "
          "%s vs. %s",
          ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
    }
    const int64 ndims = a_shape.rank();
    if (ndims < 2) {
      return InvalidArgument(
          "Arguments to TriangularSolve was rank %d but must have rank >= 2.",
          ndims);
    }

    // The batch dimensions must be equal; `batch` is their product.
    int64 batch = 1;
    for (int i = 0; i < ndims - 2; ++i) {
      int64 a_size = a_shape.dimensions(i);
      int64 b_size = b_shape.dimensions(i);
      if (a_size != b_size) {
        return InvalidArgument(
            "Batch dimensions of arguments to TriangularSolve must be equal; "
            "shapes were %s and %s.",
            ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
      }
      batch *= a_size;
    }

    if (ShapeUtil::GetDimension(a_shape, -1) !=
        ShapeUtil::GetDimension(a_shape, -2)) {
      return InvalidArgument(
          "The 'a' argument to TriangularSolve must be a batched square matrix;"
          " shape was: %s",
          ShapeUtil::HumanString(a_shape));
    }
    const int64 m = ShapeUtil::GetDimension(b_shape, -2);
    const int64 n = ShapeUtil::GetDimension(b_shape, -1);
    if ((left_side ? m : n) != ShapeUtil::GetDimension(a_shape, -1)) {
      return InvalidArgument(
          "Arguments to TriangularSolve have incompatible matrix shapes %s and "
          "%s",
          ShapeUtil::HumanString(a_shape), ShapeUtil::HumanString(b_shape));
    }

    int64 a_size = ShapeUtil::GetDimension(a_shape, -1);

    if (ShapeUtil::IsZeroElementArray(b_shape)) {
      // The output has the same shape as 'b', and since the output has zero
      // elements, any such array will do.
      return b;
    }

    // Degenerate case: 1x1 matrices.
    if (a_size == 1) {
      return unit_diagonal ? b : Div(b, MaybeConjugate(a, conjugate_a));
    }

    // Prefer the direct implementation whenever there is a nontrivial batch
    // dimension and the matrix is very small.
    if (batch > block_size_ / 16 && a_size < block_size_ / 4) {
      return SolveDirectly(a, b, left_side, lower, transpose_a, conjugate_a,
                           unit_diagonal, precision);
    } else {
      return SolveByInvertingDiagonalBlocks(a, b, left_side, lower, transpose_a,
                                            conjugate_a, unit_diagonal,
                                            precision);
    }
  });
}
// Creates the expander. `block_size` is stored in block_size_. It caps the
// size of the diagonal blocks inverted by SolveByInvertingDiagonalBlocks and
// feeds the direct-vs-blocked heuristic in BuildTriangularSolve. It must be
// at least 1; the CHECK aborts otherwise.
TriangularSolveExpander::TriangularSolveExpander(int64 block_size)
    : block_size_(block_size) {
  CHECK_GE(block_size_, 1);
}
// This pass expands exactly the kTriangularSolve instructions; everything
// else is left untouched.
bool TriangularSolveExpander::InstructionMatchesPattern(
    HloInstruction* instruction) {
  const auto opcode = instruction->opcode();
  return opcode == HloOpcode::kTriangularSolve;
}
// Replaces a kTriangularSolve instruction with a kCall to an equivalent
// computation built from simpler HLO ops.
//
// Expansions are cached in computation_cache_ under a name derived from both
// operand shapes and the solve options, so identical solves share a single
// computation. NOTE(review): the cache is keyed only by that name; if this
// pass object is reused across modules, a cached computation could belong to
// another module. Confirm the pass is instantiated per compilation.
StatusOr<HloInstruction*> TriangularSolveExpander::ExpandInstruction(
    HloInstruction* instruction) {
  const TriangularSolveOptions& options =
      instruction->triangular_solve_options();
  const auto& a_shape = instruction->operand(0)->shape();
  const auto& b_shape = instruction->operand(1)->shape();
  const string name = absl::StrFormat(
      "xla.triangular_solve_%s_%s_%s_%s_%s_%s", a_shape.ToString(),
      b_shape.ToString(), options.left_side() ? "left" : "right",
      options.lower() ? "lower" : "upper",
      TriangularSolveOptions_Transpose_Name(options.transpose_a()),
      options.unit_diagonal() ? "unit" : "nonunit");

  // On a cache miss this inserts a null placeholder. The reference is filled
  // in below once the expansion has been built.
  HloComputation*& computation =
      computation_cache_.emplace(name, nullptr).first->second;
  if (computation == nullptr) {
    // The expansion is written with XlaBuilder, which is nominally a client
    // API, because it is far more convenient for composing complex
    // computations than the internal HLO construction APIs. XlaBuilder just
    // produces an HloModuleProto. We deserialize that proto into a temporary
    // module and deep-clone its entry computation into the current module.
    // Avoiding the proto round trip is left for future work.
    const bool transpose_a =
        options.transpose_a() != TriangularSolveOptions::NO_TRANSPOSE;
    const bool conjugate_a =
        options.transpose_a() == TriangularSolveOptions::ADJOINT;

    XlaBuilder builder(name);
    XlaOp a = Parameter(&builder, 0, a_shape, "a");
    XlaOp b = Parameter(&builder, 1, b_shape, "b");
    // BuildTriangularSolve reports failures through the builder
    // (ReportErrorOrReturn), so they surface from builder.Build() below.
    BuildTriangularSolve(a, b, options.left_side(), options.lower(),
                         transpose_a, conjugate_a, options.unit_diagonal(),
                         /*block_size=*/block_size_,
                         /*precision=*/PrecisionConfig::HIGHEST);
    TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());

    TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
                        xla_computation.GetProgramShape());
    HloModuleConfig config(program_shape);
    TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
                                             xla_computation.proto(), config));

    HloModule* module = instruction->parent()->parent();
    HloCloneContext context(module);
    computation =
        module->DeepCloneComputation(new_module->entry_computation(), &context);
  }

  return instruction->parent()->AddInstruction(HloInstruction::CreateCall(
      instruction->shape(), instruction->operands(), computation));
}
} // namespace xla