From 0446d42b70c74eee48a7d7f5314196236c7967ff Mon Sep 17 00:00:00 2001
From: Rahul Joshi
Date: Fri, 5 Feb 2021 11:13:27 -0800
Subject: [PATCH] [XLA:GPU] Migrate TriangularSolve thunk emission to use MLIR

PiperOrigin-RevId: 355887905
Change-Id: If996edcb622d92f425947c88f74290c44a792bb0
---
 .../compiler/mlir/xla/attribute_exporter.cc   | 21 ++++++
 .../compiler/mlir/xla/attribute_exporter.h    |  3 +
 .../compiler/mlir/xla/mlir_hlo_to_hlo.cc      |  8 +-
 .../mlir/xla/tests/translate/export.mlir      | 11 +++
 .../xla/service/gpu/ir_emitter_unnested.cc    | 74 ++++++++++++++++++-
 .../xla/service/gpu/ir_emitter_unnested.h     |  1 +
 .../compiler/xla/service/gpu/thunk_emitter.cc | 60 ---------------
 .../compiler/xla/service/gpu/thunk_emitter.h  |  3 -
 8 files changed, 110 insertions(+), 71 deletions(-)

diff --git a/tensorflow/compiler/mlir/xla/attribute_exporter.cc b/tensorflow/compiler/mlir/xla/attribute_exporter.cc
index 4e5f24c7363..49c3562926a 100644
--- a/tensorflow/compiler/mlir/xla/attribute_exporter.cc
+++ b/tensorflow/compiler/mlir/xla/attribute_exporter.cc
@@ -152,4 +152,25 @@ StatusOr<FftType> ConvertFftType(llvm::StringRef type_string) {
   }
 }
 
+StatusOr<TriangularSolveOptions::Transpose> ConvertTranspose(
+    llvm::StringRef transpose_string) {
+  llvm::Optional<mlir::mhlo::Transpose> transpose =
+      mlir::mhlo::symbolizeTranspose(transpose_string);
+  if (!transpose)
+    return InvalidArgument("Unknown transpose type %s", transpose_string.str());
+
+  switch (*transpose) {
+    case mlir::mhlo::Transpose::NO_TRANSPOSE:
+      return TriangularSolveOptions::NO_TRANSPOSE;
+    case mlir::mhlo::Transpose::TRANSPOSE:
+      return TriangularSolveOptions::TRANSPOSE;
+    case mlir::mhlo::Transpose::ADJOINT:
+      return TriangularSolveOptions::ADJOINT;
+    case mlir::mhlo::Transpose::TRANSPOSE_INVALID:
+      return TriangularSolveOptions::TRANSPOSE_INVALID;
+    default:
+      return InvalidArgument("Unknown transpose enum value #%d", *transpose);
+  }
+}
+
 }  // namespace xla
diff --git a/tensorflow/compiler/mlir/xla/attribute_exporter.h b/tensorflow/compiler/mlir/xla/attribute_exporter.h
index db9c8b2d208..df6494551f3 100644
--- a/tensorflow/compiler/mlir/xla/attribute_exporter.h
+++ b/tensorflow/compiler/mlir/xla/attribute_exporter.h
@@ -41,5 +41,8 @@ StatusOr<std::vector<std::pair<int64, int64>>> ConvertNx2Attribute(
     llvm::Optional<mlir::DenseIntElementsAttr> optional_attr);
 
 StatusOr<FftType> ConvertFftType(llvm::StringRef type_string);
+StatusOr<TriangularSolveOptions::Transpose> ConvertTranspose(
+    llvm::StringRef transpose_string);
+
 }  // namespace xla
 #endif  // TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_
diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
index d32d949355b..fafcc643ffd 100644
--- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
+++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc
@@ -181,13 +181,7 @@ static std::vector<xla::ReplicaGroup> Convert_replica_groups(
 // Converts StringRef to xla Transpose enum.
 static xla::TriangularSolveOptions::Transpose Convert_transpose_a(
     llvm::StringRef transpose_str) {
-  xla::TriangularSolveOptions::Transpose transpose_enum;
-  // Illegal tanspose string would be caught by the verifier, so
-  // 'Transpose_Parse' call below should never return false.
-  if (!xla::TriangularSolveOptions::Transpose_Parse(std::string(transpose_str),
-                                                    &transpose_enum))
-    return xla::TriangularSolveOptions::NO_TRANSPOSE;
-  return transpose_enum;
+  return xla::ConvertTranspose(transpose_str).ValueOrDie();
 }
 
 #define I64_ELEMENTS_ATTR_TO_VECTOR(attribute) \
diff --git a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
index 11e44fa1d54..c8032732168 100644
--- a/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
+++ b/tensorflow/compiler/mlir/xla/tests/translate/export.mlir
@@ -1183,3 +1183,14 @@ func @main(%arg: tensor<3x4xf32>) -> tensor<3x4x1xf32> {
   %0 = "mhlo.bitcast"(%arg) : (tensor<3x4xf32>) -> tensor<3x4x1xf32>
   return %0 : tensor<3x4x1xf32>
 }
+
+// -----
+
+// CHECK: HloModule
+func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<3x4xf32>) -> tensor<3x4xf32> {
+// CHECK: %[[ARG0:.*]] = f32[4,4] parameter(0)
+// CHECK: %[[ARG1:.*]] = f32[3,4] parameter(1)
+// CHECK: ROOT %[[RESULT:.*]] = f32[3,4] triangular-solve(f32[4,4] %[[ARG0]], f32[3,4] %[[ARG1]]), lower=true, transpose_a=NO_TRANSPOSE
+  %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = false, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = false} : (tensor<4x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
+  return %0: tensor<3x4xf32>
+}
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index c77e247485b..6c285de686f 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -43,6 +43,7 @@ limitations under the License.
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Verifier.h"  // from @llvm-project
@@ -90,6 +91,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" #include "tensorflow/compiler/xla/service/gpu/target_util.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h" #include "tensorflow/compiler/xla/service/gpu/while_thunk.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" @@ -1585,7 +1587,77 @@ Status IrEmitterUnnested::EmitFftThunkFromMlir(MlirEmitterInput input) { } Status IrEmitterUnnested::HandleTriangularSolve(HloInstruction* hlo) { - return ThunkEmitter(this).HandleTriangularSolve(hlo); + TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo)); + return EmitTriangularSolveFromMlir(input); +} + +Status IrEmitterUnnested::EmitTriangularSolveFromMlir(MlirEmitterInput input) { + auto triangular_solve_op = + mlir::cast(input.op); + auto has_fortran_layout = [](mlir::DenseIntElementsAttr layout_attr) { + int64_t n = layout_attr.getNumElements(); + return layout_attr.getValue({0}) == n - 2 && + layout_attr.getValue({1}) == n - 1; + }; + TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_a())); + TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_b())); + TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_output())); + + const Shape b_shape = TypeToShape(triangular_solve_op.b().getType()); + + const Shape output_shape = + TypeToShape(triangular_solve_op.output().getType()); + + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a_slice, + GetAllocationSliceForMlir(triangular_solve_op.a())); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice b_slice, + GetAllocationSliceForMlir(triangular_solve_op.b())); + TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice, + GetAllocationSliceForMlir(triangular_solve_op.output())); + TF_ASSIGN_OR_RETURN(TriangularSolveOptions_Transpose transpose_a, + ConvertTranspose(triangular_solve_op.transpose_a())); + + std::vector> thunks; + + // Triangular solve is in-place on 'b', so copy 'b' to the output if they + // aren't the same buffer. + if (b_slice != output_slice) { + thunks.push_back(absl::make_unique( + Thunk::ThunkInfo(), + /*source_address=*/b_slice, + /*destination_buffer=*/output_slice, + /*mem_size=*/ShapeUtil::ByteSizeOf(b_shape))); + } + + int64 m = b_shape.dimensions(b_shape.rank() - 2); + int64 n = b_shape.dimensions(b_shape.rank() - 1); + int64 batch_size = std::accumulate(b_shape.dimensions().begin(), + b_shape.dimensions().end() - 2, int64{1}, + [](int64 a, int64 b) { return a * b; }); + int64 elem_size = + ShapeUtil::ByteSizeOfPrimitiveType(output_shape.element_type()); + int64 a_batch_stride = + triangular_solve_op.left_side() ? m * m * elem_size : n * n * elem_size; + int64 b_batch_stride = m * n * elem_size; + TriangularSolveOptions options; + options.set_left_side(triangular_solve_op.left_side()); + options.set_lower(triangular_solve_op.lower()); + options.set_unit_diagonal(triangular_solve_op.unit_diagonal()); + options.set_transpose_a(transpose_a); + thunks.push_back(absl::make_unique( + input.thunk_info, options, + /*a_input_buffer=*/a_slice, + /*b_input_buffer=*/output_slice, output_shape.element_type(), batch_size, + m, n, a_batch_stride, b_batch_stride)); + + // Elide the sequential thunk if there's no copy. 
+  if (thunks.size() == 1) {
+    AddThunkToThunkSequence(std::move(thunks[0]));
+  } else {
+    AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
+        input.thunk_info, std::move(thunks)));
+  }
+  return Status::OK();
 }
 
 // Convert the following form of fusion region:
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 3fe0df27823..ce1eac934e9 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -201,6 +201,7 @@ class IrEmitterUnnested : public IrEmitter,
   Status HandleSort(HloInstruction* sort) override;
   Status EmitSortFromMlir(MlirEmitterInput mlir_input);
   Status HandleTriangularSolve(HloInstruction* hlo) override;
+  Status EmitTriangularSolveFromMlir(MlirEmitterInput mlir_input);
 
   template <typename NcclThunkType, typename OpTy>
   Status EmitNcclThunkFromMlir(MlirEmitterInput mlir_input);
diff --git a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc
index 9201c08362f..7624f2b25d5 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/thunk_emitter.cc
@@ -21,35 +21,11 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
-#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 
 namespace xla {
 namespace gpu {
 
-std::unique_ptr<Thunk> ThunkEmitter::BuildTriangularSolveThunk(
-    const HloInstruction* inst) {
-  const HloInstruction* a = inst->operand(0);
-  const HloInstruction* b = inst->operand(1);
-  int64 m = b->shape().dimensions(b->shape().rank() - 2);
-  int64 n = b->shape().dimensions(b->shape().rank() - 1);
-  int64 batch_size = std::accumulate(
-      b->shape().dimensions().begin(), b->shape().dimensions().end() - 2,
-      int64{1}, [](int64 a, int64 b) { return a * b; });
-  int64 elem_size =
-      ShapeUtil::ByteSizeOfPrimitiveType(inst->shape().element_type());
-  int64 a_batch_stride = inst->triangular_solve_options().left_side()
-                             ? m * m * elem_size
-                             : n * n * elem_size;
-  int64 b_batch_stride = m * n * elem_size;
-  return absl::make_unique<TriangularSolveThunk>(
-      context_->GetThunkInfo(inst), inst->triangular_solve_options(),
-      /*a_input_buffer=*/GetAllocationSlice(*a),
-      /*b_input_buffer=*/GetAllocationSlice(*inst),
-      inst->shape().element_type(), batch_size, m, n, a_batch_stride,
-      b_batch_stride);
-}
-
 std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
     const HloInstruction* inst) {
   GpuGemmConfig config = GetGpuGemmConfig(inst);
@@ -88,42 +64,6 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
       /*implements_whole_instruction=*/true);
 }
 
-Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) {
-  auto has_fortran_layout = [](const Layout& layout) {
-    int n = layout.minor_to_major_size();
-    return layout.minor_to_major(0) == n - 2 &&
-           layout.minor_to_major(1) == n - 1;
-  };
-  TF_RET_CHECK(has_fortran_layout(hlo->operand(0)->shape().layout()));
-  TF_RET_CHECK(has_fortran_layout(hlo->operand(1)->shape().layout()));
-  TF_RET_CHECK(has_fortran_layout(hlo->shape().layout()));
-
-  std::vector<std::unique_ptr<Thunk>> thunks;
-
-  // Triangular solve is in-place on 'b', so copy 'b' to the output if they
-  // aren't the same buffer.
-  auto operand_buffer = GetAllocationSlice(*hlo->operand(1));
-  auto destination_buffer = GetAllocationSlice(*hlo);
-  if (operand_buffer != destination_buffer) {
-    thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
-        context_->GetThunkInfo(hlo),
-        /*source_address=*/operand_buffer,
-        /*destination_buffer=*/destination_buffer,
-        /*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape())));
-  }
-
-  thunks.push_back(BuildTriangularSolveThunk(hlo));
-
-  // Elide the sequential thunk if there's no copy.
-  if (thunks.size() == 1) {
-    AddThunkToThunkSequence(std::move(thunks[0]));
-  } else {
-    AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
-        context_->GetThunkInfo(hlo), std::move(thunks)));
-  }
-  return Status::OK();
-}
-
 Thunk::ThunkInfo ThunkEmitter::EmissionContext::GetThunkInfo(
     const HloInstruction* hlo) const {
   CHECK(hlo);
diff --git a/tensorflow/compiler/xla/service/gpu/thunk_emitter.h b/tensorflow/compiler/xla/service/gpu/thunk_emitter.h
index 4669d5c7c8d..12c237c448f 100644
--- a/tensorflow/compiler/xla/service/gpu/thunk_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/thunk_emitter.h
@@ -69,9 +69,6 @@ class ThunkEmitter {
   // Returns a CholeskyThunk that calls cuSolver to implement `inst`.
   std::unique_ptr<Thunk> BuildCholeskyThunk(const HloInstruction* inst);
 
-  // Returns a TriangularSolveThunk that calls cuBlas to implement `inst`.
-  std::unique_ptr<Thunk> BuildTriangularSolveThunk(const HloInstruction* inst);
-
   // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs
   // to make sure `inst` outlives the lifetime of the returned Thunk object.
  std::unique_ptr<Thunk> BuildGemmThunk(const HloInstruction* inst);
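
Note: a minimal standalone sketch of the batch/stride arithmetic used in EmitTriangularSolveFromMlir above, assuming a plain std::vector<int64_t> of dimensions in place of the XLA Shape type. All names below are illustrative only and are not XLA APIs or part of this change.

// Standalone sketch (not XLA code): mirrors the m / n / batch_size / stride
// arithmetic from EmitTriangularSolveFromMlir, using a plain vector of
// dimensions instead of the XLA Shape type.
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  // Hypothetical example: 'b' has shape [2, 3, 4] (one batch dimension,
  // then a 3x4 matrix), f32 elements, and the solve is applied from the left.
  const std::vector<int64_t> b_dims = {2, 3, 4};
  const int64_t elem_size = 4;  // sizeof(float)
  const bool left_side = true;

  const size_t rank = b_dims.size();
  const int64_t m = b_dims[rank - 2];
  const int64_t n = b_dims[rank - 1];

  // Product of all leading (batch) dimensions; 1 if there are none.
  const int64_t batch_size =
      std::accumulate(b_dims.begin(), b_dims.end() - 2, int64_t{1},
                      [](int64_t a, int64_t b) { return a * b; });

  // 'a' is m x m when solving A X = B (left side), n x n when solving X A = B.
  const int64_t a_batch_stride =
      left_side ? m * m * elem_size : n * n * elem_size;
  const int64_t b_batch_stride = m * n * elem_size;

  std::cout << "m=" << m << " n=" << n << " batch_size=" << batch_size
            << " a_batch_stride=" << a_batch_stride
            << " b_batch_stride=" << b_batch_stride << "\n";
  return 0;
}

For b_dims = {2, 3, 4} with 4-byte elements and left_side = true, this prints m=3 n=4 batch_size=2 a_batch_stride=36 b_batch_stride=48, the same values the formulas above would compute for an f32 [2,3,4] right-hand side.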