[XLA:GPU] Migrate TriangularSolve thunk emission to use MLIR

PiperOrigin-RevId: 355887905
Change-Id: If996edcb622d92f425947c88f74290c44a792bb0
Rahul Joshi 2021-02-05 11:13:27 -08:00 committed by TensorFlower Gardener
parent 3080eef5b0
commit 0446d42b70
8 changed files with 110 additions and 71 deletions


@@ -152,4 +152,25 @@ StatusOr<FftType> ConvertFftType(llvm::StringRef type_string) {
}
}
StatusOr<TriangularSolveOptions::Transpose> ConvertTranspose(
llvm::StringRef transpose_string) {
llvm::Optional<mlir::mhlo::Transpose> transpose =
mlir::mhlo::symbolizeTranspose(transpose_string);
if (!transpose)
return InvalidArgument("Unknown transpose type %s", transpose_string.str());
switch (*transpose) {
case mlir::mhlo::Transpose::NO_TRANSPOSE:
return TriangularSolveOptions::NO_TRANSPOSE;
case mlir::mhlo::Transpose::TRANSPOSE:
return TriangularSolveOptions::TRANSPOSE;
case mlir::mhlo::Transpose::ADJOINT:
return TriangularSolveOptions::ADJOINT;
case mlir::mhlo::Transpose::TRANSPOSE_INVALID:
return TriangularSolveOptions::TRANSPOSE_INVALID;
default:
return InvalidArgument("Unknown transpose enum value #%d", *transpose);
}
}
} // namespace xla
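A minimal usage sketch of the new ConvertTranspose helper (not part of this change; the include path is taken from the header guard below, and the fallback handling is an assumption for illustration only):

#include "tensorflow/compiler/mlir/xla/attribute_exporter.h"

// Hypothetical caller: convert an MHLO transpose attribute string (e.g.
// "ADJOINT") to the XLA proto enum, falling back to NO_TRANSPOSE when the
// string is not recognized.
xla::TriangularSolveOptions::Transpose GetTransposeOrDefault(
    llvm::StringRef transpose_str) {
  auto converted = xla::ConvertTranspose(transpose_str);
  return converted.ok() ? converted.ValueOrDie()
                        : xla::TriangularSolveOptions::NO_TRANSPOSE;
}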


@@ -41,5 +41,8 @@ StatusOr<std::vector<std::pair<int64, int64>>> ConvertNx2Attribute(
llvm::Optional<mlir::DenseIntElementsAttr> optional_attr);
StatusOr<FftType> ConvertFftType(llvm::StringRef type_string);
StatusOr<TriangularSolveOptions::Transpose> ConvertTranspose(
llvm::StringRef transpose_string);
} // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_


@@ -181,13 +181,7 @@ static std::vector<xla::ReplicaGroup> Convert_replica_groups(
// Converts StringRef to xla Transpose enum.
static xla::TriangularSolveOptions::Transpose Convert_transpose_a(
llvm::StringRef transpose_str) {
xla::TriangularSolveOptions::Transpose transpose_enum;
// Illegal transpose string would be caught by the verifier, so
// 'Transpose_Parse' call below should never return false.
if (!xla::TriangularSolveOptions::Transpose_Parse(std::string(transpose_str),
&transpose_enum))
return xla::TriangularSolveOptions::NO_TRANSPOSE;
return transpose_enum;
return xla::ConvertTranspose(transpose_str).ValueOrDie();
}
#define I64_ELEMENTS_ATTR_TO_VECTOR(attribute) \


@@ -1183,3 +1183,14 @@ func @main(%arg: tensor<3x4xf32>) -> tensor<3x4x1xf32> {
%0 = "mhlo.bitcast"(%arg) : (tensor<3x4xf32>) -> tensor<3x4x1xf32>
return %0 : tensor<3x4x1xf32>
}
// -----
// CHECK: HloModule
func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<3x4xf32>) -> tensor<3x4xf32> {
// CHECK: %[[ARG0:.*]] = f32[4,4] parameter(0)
// CHECK: %[[ARG1:.*]] = f32[3,4] parameter(1)
// CHECK: ROOT %[[RESULT:.*]] = f32[3,4] triangular-solve(f32[4,4] %[[ARG0]], f32[3,4] %[[ARG1]]), lower=true, transpose_a=NO_TRANSPOSE
%0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = false, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = false} : (tensor<4x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
return %0: tensor<3x4xf32>
}


@@ -43,6 +43,7 @@ limitations under the License.
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/Verifier.h" // from @llvm-project
@@ -90,6 +91,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -1585,7 +1587,77 @@ Status IrEmitterUnnested::EmitFftThunkFromMlir(MlirEmitterInput input) {
}
Status IrEmitterUnnested::HandleTriangularSolve(HloInstruction* hlo) {
return ThunkEmitter(this).HandleTriangularSolve(hlo);
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
return EmitTriangularSolveFromMlir(input);
}
Status IrEmitterUnnested::EmitTriangularSolveFromMlir(MlirEmitterInput input) {
auto triangular_solve_op =
mlir::cast<mlir::lmhlo::TriangularSolveOp>(input.op);
auto has_fortran_layout = [](mlir::DenseIntElementsAttr layout_attr) {
int64_t n = layout_attr.getNumElements();
return layout_attr.getValue<int64_t>({0}) == n - 2 &&
layout_attr.getValue<int64_t>({1}) == n - 1;
};
TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_a()));
TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_b()));
TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_output()));
const Shape b_shape = TypeToShape(triangular_solve_op.b().getType());
const Shape output_shape =
TypeToShape(triangular_solve_op.output().getType());
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a_slice,
GetAllocationSliceForMlir(triangular_solve_op.a()));
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice b_slice,
GetAllocationSliceForMlir(triangular_solve_op.b()));
TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
GetAllocationSliceForMlir(triangular_solve_op.output()));
TF_ASSIGN_OR_RETURN(TriangularSolveOptions_Transpose transpose_a,
ConvertTranspose(triangular_solve_op.transpose_a()));
std::vector<std::unique_ptr<Thunk>> thunks;
// Triangular solve is in-place on 'b', so copy 'b' to the output if they
// aren't the same buffer.
if (b_slice != output_slice) {
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
Thunk::ThunkInfo(),
/*source_address=*/b_slice,
/*destination_buffer=*/output_slice,
/*mem_size=*/ShapeUtil::ByteSizeOf(b_shape)));
}
int64 m = b_shape.dimensions(b_shape.rank() - 2);
int64 n = b_shape.dimensions(b_shape.rank() - 1);
int64 batch_size = std::accumulate(b_shape.dimensions().begin(),
b_shape.dimensions().end() - 2, int64{1},
[](int64 a, int64 b) { return a * b; });
int64 elem_size =
ShapeUtil::ByteSizeOfPrimitiveType(output_shape.element_type());
int64 a_batch_stride =
triangular_solve_op.left_side() ? m * m * elem_size : n * n * elem_size;
int64 b_batch_stride = m * n * elem_size;
TriangularSolveOptions options;
options.set_left_side(triangular_solve_op.left_side());
options.set_lower(triangular_solve_op.lower());
options.set_unit_diagonal(triangular_solve_op.unit_diagonal());
options.set_transpose_a(transpose_a);
thunks.push_back(absl::make_unique<TriangularSolveThunk>(
input.thunk_info, options,
/*a_input_buffer=*/a_slice,
/*b_input_buffer=*/output_slice, output_shape.element_type(), batch_size,
m, n, a_batch_stride, b_batch_stride));
// Elide the sequential thunk if there's no copy.
if (thunks.size() == 1) {
AddThunkToThunkSequence(std::move(thunks[0]));
} else {
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
input.thunk_info, std::move(thunks)));
}
return Status::OK();
}
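For reference, a self-contained sketch of the batch-stride arithmetic performed above, under illustrative assumptions (f32 elements, left_side = true, 'a' of shape [2, 4, 4], 'b' of shape [2, 4, 3]); the concrete values are not from this change:

#include <cstdint>
#include <iostream>

int main() {
  // Hypothetical solve op(a) * x = b with a: [2, 4, 4], b: [2, 4, 3].
  const int64_t batch_size = 2, m = 4, n = 3;
  const int64_t elem_size = 4;  // f32 occupies 4 bytes
  const bool left_side = true;
  // Solving from the left, each batch of 'a' is an m x m matrix; solving
  // from the right (x * op(a) = b), it would be n x n instead.
  const int64_t a_batch_stride =
      left_side ? m * m * elem_size : n * n * elem_size;
  const int64_t b_batch_stride = m * n * elem_size;
  std::cout << "a_batch_stride=" << a_batch_stride    // 64 bytes
            << " b_batch_stride=" << b_batch_stride   // 48 bytes
            << " batch_size=" << batch_size << "\n";
  return 0;
}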
// Convert the following form of fusion region:


@@ -201,6 +201,7 @@ class IrEmitterUnnested : public IrEmitter,
Status HandleSort(HloInstruction* sort) override;
Status EmitSortFromMlir(MlirEmitterInput mlir_input);
Status HandleTriangularSolve(HloInstruction* hlo) override;
Status EmitTriangularSolveFromMlir(MlirEmitterInput mlir_input);
template <typename NcclThunkType, typename OpTy>
Status EmitNcclThunkFromMlir(MlirEmitterInput mlir_input);


@@ -21,35 +21,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
namespace xla {
namespace gpu {
std::unique_ptr<Thunk> ThunkEmitter::BuildTriangularSolveThunk(
const HloInstruction* inst) {
const HloInstruction* a = inst->operand(0);
const HloInstruction* b = inst->operand(1);
int64 m = b->shape().dimensions(b->shape().rank() - 2);
int64 n = b->shape().dimensions(b->shape().rank() - 1);
int64 batch_size = std::accumulate(
b->shape().dimensions().begin(), b->shape().dimensions().end() - 2,
int64{1}, [](int64 a, int64 b) { return a * b; });
int64 elem_size =
ShapeUtil::ByteSizeOfPrimitiveType(inst->shape().element_type());
int64 a_batch_stride = inst->triangular_solve_options().left_side()
? m * m * elem_size
: n * n * elem_size;
int64 b_batch_stride = m * n * elem_size;
return absl::make_unique<TriangularSolveThunk>(
context_->GetThunkInfo(inst), inst->triangular_solve_options(),
/*a_input_buffer=*/GetAllocationSlice(*a),
/*b_input_buffer=*/GetAllocationSlice(*inst),
inst->shape().element_type(), batch_size, m, n, a_batch_stride,
b_batch_stride);
}
std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
const HloInstruction* inst) {
GpuGemmConfig config = GetGpuGemmConfig(inst);
@@ -88,42 +64,6 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
/*implements_whole_instruction=*/true);
}
Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) {
auto has_fortran_layout = [](const Layout& layout) {
int n = layout.minor_to_major_size();
return layout.minor_to_major(0) == n - 2 &&
layout.minor_to_major(1) == n - 1;
};
TF_RET_CHECK(has_fortran_layout(hlo->operand(0)->shape().layout()));
TF_RET_CHECK(has_fortran_layout(hlo->operand(1)->shape().layout()));
TF_RET_CHECK(has_fortran_layout(hlo->shape().layout()));
std::vector<std::unique_ptr<Thunk>> thunks;
// Triangular solve is in-place on 'b', so copy 'b' to the output if they
// aren't the same buffer.
auto operand_buffer = GetAllocationSlice(*hlo->operand(1));
auto destination_buffer = GetAllocationSlice(*hlo);
if (operand_buffer != destination_buffer) {
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
context_->GetThunkInfo(hlo),
/*source_address=*/operand_buffer,
/*destination_buffer=*/destination_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape())));
}
thunks.push_back(BuildTriangularSolveThunk(hlo));
// Elide the sequential thunk if there's no copy.
if (thunks.size() == 1) {
AddThunkToThunkSequence(std::move(thunks[0]));
} else {
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
context_->GetThunkInfo(hlo), std::move(thunks)));
}
return Status::OK();
}
Thunk::ThunkInfo ThunkEmitter::EmissionContext::GetThunkInfo(
const HloInstruction* hlo) const {
CHECK(hlo);


@@ -69,9 +69,6 @@ class ThunkEmitter {
// Returns a CholeskyThunk that calls cuSolver to implement `inst`.
std::unique_ptr<Thunk> BuildCholeskyThunk(const HloInstruction* inst);
// Returns a TriangularSolveThunk that calls cuBlas to implement `inst`.
std::unique_ptr<Thunk> BuildTriangularSolveThunk(const HloInstruction* inst);
// Returns a GemmThunk that calls gemm to implement `inst`. The caller needs
// to make sure `inst` outlives the lifetime of the returned Thunk object.
std::unique_ptr<Thunk> BuildGemmThunk(const HloInstruction* inst);