[XLA:GPU] Migrate TriangularSolve thunk emission to use MLIR
PiperOrigin-RevId: 355887905
Change-Id: If996edcb622d92f425947c88f74290c44a792bb0

commit 0446d42b70
parent 3080eef5b0
@@ -152,4 +152,25 @@ StatusOr<FftType> ConvertFftType(llvm::StringRef type_string) {
   }
 }
 
+StatusOr<TriangularSolveOptions::Transpose> ConvertTranspose(
+    llvm::StringRef transpose_string) {
+  llvm::Optional<mlir::mhlo::Transpose> transpose =
+      mlir::mhlo::symbolizeTranspose(transpose_string);
+  if (!transpose)
+    return InvalidArgument("Unknown transpose type %s", transpose_string.str());
+
+  switch (*transpose) {
+    case mlir::mhlo::Transpose::NO_TRANSPOSE:
+      return TriangularSolveOptions::NO_TRANSPOSE;
+    case mlir::mhlo::Transpose::TRANSPOSE:
+      return TriangularSolveOptions::TRANSPOSE;
+    case mlir::mhlo::Transpose::ADJOINT:
+      return TriangularSolveOptions::ADJOINT;
+    case mlir::mhlo::Transpose::TRANSPOSE_INVALID:
+      return TriangularSolveOptions::TRANSPOSE_INVALID;
+    default:
+      return InvalidArgument("Unknown transpose enum value #%d", *transpose);
+  }
+}
+
 }  // namespace xla
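The new ConvertTranspose helper returns a StatusOr<> so that an unrecognized transpose string surfaces as an InvalidArgument error rather than a crash. Below is a minimal sketch of the intended call pattern, assuming the usual XLA status macros; SetTransposeFromAttr and its out-parameter are hypothetical names used only for illustration (the real call site added by this commit is in EmitTriangularSolveFromMlir further down):

// Illustration only: convert the MLIR transpose attribute string and store it
// on the proto options, propagating any conversion error to the caller.
Status SetTransposeFromAttr(llvm::StringRef transpose_attr,
                            TriangularSolveOptions* options) {
  TF_ASSIGN_OR_RETURN(TriangularSolveOptions::Transpose transpose,
                      xla::ConvertTranspose(transpose_attr));
  options->set_transpose_a(transpose);
  return Status::OK();
}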
@@ -41,5 +41,8 @@ StatusOr<std::vector<std::pair<int64, int64>>> ConvertNx2Attribute(
     llvm::Optional<mlir::DenseIntElementsAttr> optional_attr);
 
 StatusOr<FftType> ConvertFftType(llvm::StringRef type_string);
+StatusOr<TriangularSolveOptions::Transpose> ConvertTranspose(
+    llvm::StringRef transpose_string);
+
 }  // namespace xla
 #endif  // TENSORFLOW_COMPILER_MLIR_XLA_ATTRIBUTE_EXPORTER_H_
@@ -181,13 +181,7 @@ static std::vector<xla::ReplicaGroup> Convert_replica_groups(
 // Converts StringRef to xla Transpose enum.
 static xla::TriangularSolveOptions::Transpose Convert_transpose_a(
     llvm::StringRef transpose_str) {
-  xla::TriangularSolveOptions::Transpose transpose_enum;
-  // Illegal tanspose string would be caught by the verifier, so
-  // 'Transpose_Parse' call below should never return false.
-  if (!xla::TriangularSolveOptions::Transpose_Parse(std::string(transpose_str),
-                                                    &transpose_enum))
-    return xla::TriangularSolveOptions::NO_TRANSPOSE;
-  return transpose_enum;
+  return xla::ConvertTranspose(transpose_str).ValueOrDie();
 }
 
 #define I64_ELEMENTS_ATTR_TO_VECTOR(attribute) \
@@ -1183,3 +1183,14 @@ func @main(%arg: tensor<3x4xf32>) -> tensor<3x4x1xf32> {
   %0 = "mhlo.bitcast"(%arg) : (tensor<3x4xf32>) -> tensor<3x4x1xf32>
   return %0 : tensor<3x4x1xf32>
 }
+
+// -----
+
+// CHECK: HloModule
+func @main(%arg0: tensor<4x4xf32>, %arg1: tensor<3x4xf32>) -> tensor<3x4xf32> {
+  // CHECK: %[[ARG0:.*]] = f32[4,4] parameter(0)
+  // CHECK: %[[ARG1:.*]] = f32[3,4] parameter(1)
+  // CHECK: ROOT %[[RESULT:.*]] = f32[3,4] triangular-solve(f32[4,4] %[[ARG0]], f32[3,4] %[[ARG1]]), lower=true, transpose_a=NO_TRANSPOSE
+  %0 = "mhlo.triangular_solve"(%arg0, %arg1) {left_side = false, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = false} : (tensor<4x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32>
+  return %0: tensor<3x4xf32>
+}
@@ -43,6 +43,7 @@ limitations under the License.
 #include "mlir/IR/Attributes.h"  // from @llvm-project
 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
 #include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
 #include "mlir/IR/Verifier.h"  // from @llvm-project
@@ -90,6 +91,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/target_util.h"
 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -1585,7 +1587,77 @@ Status IrEmitterUnnested::EmitFftThunkFromMlir(MlirEmitterInput input) {
 }
 
 Status IrEmitterUnnested::HandleTriangularSolve(HloInstruction* hlo) {
-  return ThunkEmitter(this).HandleTriangularSolve(hlo);
+  TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
+  return EmitTriangularSolveFromMlir(input);
 }
 
+Status IrEmitterUnnested::EmitTriangularSolveFromMlir(MlirEmitterInput input) {
+  auto triangular_solve_op =
+      mlir::cast<mlir::lmhlo::TriangularSolveOp>(input.op);
+  auto has_fortran_layout = [](mlir::DenseIntElementsAttr layout_attr) {
+    int64_t n = layout_attr.getNumElements();
+    return layout_attr.getValue<int64_t>({0}) == n - 2 &&
+           layout_attr.getValue<int64_t>({1}) == n - 1;
+  };
+  TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_a()));
+  TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_b()));
+  TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_output()));
+
+  const Shape b_shape = TypeToShape(triangular_solve_op.b().getType());
+
+  const Shape output_shape =
+      TypeToShape(triangular_solve_op.output().getType());
+
+  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a_slice,
+                      GetAllocationSliceForMlir(triangular_solve_op.a()));
+  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice b_slice,
+                      GetAllocationSliceForMlir(triangular_solve_op.b()));
+  TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
+                      GetAllocationSliceForMlir(triangular_solve_op.output()));
+  TF_ASSIGN_OR_RETURN(TriangularSolveOptions_Transpose transpose_a,
+                      ConvertTranspose(triangular_solve_op.transpose_a()));
+
+  std::vector<std::unique_ptr<Thunk>> thunks;
+
+  // Triangular solve is in-place on 'b', so copy 'b' to the output if they
+  // aren't the same buffer.
+  if (b_slice != output_slice) {
+    thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
+        Thunk::ThunkInfo(),
+        /*source_address=*/b_slice,
+        /*destination_buffer=*/output_slice,
+        /*mem_size=*/ShapeUtil::ByteSizeOf(b_shape)));
+  }
+
+  int64 m = b_shape.dimensions(b_shape.rank() - 2);
+  int64 n = b_shape.dimensions(b_shape.rank() - 1);
+  int64 batch_size = std::accumulate(b_shape.dimensions().begin(),
+                                     b_shape.dimensions().end() - 2, int64{1},
+                                     [](int64 a, int64 b) { return a * b; });
+  int64 elem_size =
+      ShapeUtil::ByteSizeOfPrimitiveType(output_shape.element_type());
+  int64 a_batch_stride =
+      triangular_solve_op.left_side() ? m * m * elem_size : n * n * elem_size;
+  int64 b_batch_stride = m * n * elem_size;
+  TriangularSolveOptions options;
+  options.set_left_side(triangular_solve_op.left_side());
+  options.set_lower(triangular_solve_op.lower());
+  options.set_unit_diagonal(triangular_solve_op.unit_diagonal());
+  options.set_transpose_a(transpose_a);
+  thunks.push_back(absl::make_unique<TriangularSolveThunk>(
+      input.thunk_info, options,
+      /*a_input_buffer=*/a_slice,
+      /*b_input_buffer=*/output_slice, output_shape.element_type(), batch_size,
+      m, n, a_batch_stride, b_batch_stride));
+
+  // Elide the sequential thunk if there's no copy.
+  if (thunks.size() == 1) {
+    AddThunkToThunkSequence(std::move(thunks[0]));
+  } else {
+    AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
+        input.thunk_info, std::move(thunks)));
+  }
+  return Status::OK();
+}
+
 // Convert the following form of fusion region:
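The stride arithmetic in EmitTriangularSolveFromMlir treats the batched solve as batch_size independent dense matrices: b is m x n, while a is m x m when solving from the left and n x n when solving from the right. A small standalone sketch of the same computation on a made-up shape follows; the f32 b of shape [5,3,4] and left_side = true are illustrative values only, not taken from this change:

#include <cstdint>
#include <iostream>

// Illustration only: mirrors the batch/stride arithmetic above for a
// hypothetical b of shape f32[5,3,4] (batch = 5, m = 3, n = 4), left_side.
int main() {
  const int64_t dims[] = {5, 3, 4};  // b_shape.dimensions()
  const int64_t rank = 3;
  const int64_t m = dims[rank - 2];  // 3
  const int64_t n = dims[rank - 1];  // 4
  int64_t batch_size = 1;
  for (int64_t i = 0; i < rank - 2; ++i) batch_size *= dims[i];  // 5
  const int64_t elem_size = 4;  // bytes per f32 element
  const bool left_side = true;
  // 'a' is m x m when solving op(a) x = b, n x n when solving x op(a) = b.
  const int64_t a_batch_stride =
      left_side ? m * m * elem_size : n * n * elem_size;  // 3*3*4 = 36 bytes
  const int64_t b_batch_stride = m * n * elem_size;       // 3*4*4 = 48 bytes
  std::cout << batch_size << " " << a_batch_stride << " " << b_batch_stride
            << "\n";  // prints: 5 36 48
  return 0;
}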
@@ -201,6 +201,7 @@ class IrEmitterUnnested : public IrEmitter,
   Status HandleSort(HloInstruction* sort) override;
   Status EmitSortFromMlir(MlirEmitterInput mlir_input);
   Status HandleTriangularSolve(HloInstruction* hlo) override;
+  Status EmitTriangularSolveFromMlir(MlirEmitterInput mlir_input);
 
   template <typename NcclThunkType, typename OpTy>
   Status EmitNcclThunkFromMlir(MlirEmitterInput mlir_input);
@@ -21,35 +21,11 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
-#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 
 namespace xla {
 namespace gpu {
 
-std::unique_ptr<Thunk> ThunkEmitter::BuildTriangularSolveThunk(
-    const HloInstruction* inst) {
-  const HloInstruction* a = inst->operand(0);
-  const HloInstruction* b = inst->operand(1);
-  int64 m = b->shape().dimensions(b->shape().rank() - 2);
-  int64 n = b->shape().dimensions(b->shape().rank() - 1);
-  int64 batch_size = std::accumulate(
-      b->shape().dimensions().begin(), b->shape().dimensions().end() - 2,
-      int64{1}, [](int64 a, int64 b) { return a * b; });
-  int64 elem_size =
-      ShapeUtil::ByteSizeOfPrimitiveType(inst->shape().element_type());
-  int64 a_batch_stride = inst->triangular_solve_options().left_side()
-                             ? m * m * elem_size
-                             : n * n * elem_size;
-  int64 b_batch_stride = m * n * elem_size;
-  return absl::make_unique<TriangularSolveThunk>(
-      context_->GetThunkInfo(inst), inst->triangular_solve_options(),
-      /*a_input_buffer=*/GetAllocationSlice(*a),
-      /*b_input_buffer=*/GetAllocationSlice(*inst),
-      inst->shape().element_type(), batch_size, m, n, a_batch_stride,
-      b_batch_stride);
-}
-
 std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
     const HloInstruction* inst) {
   GpuGemmConfig config = GetGpuGemmConfig(inst);
@@ -88,42 +64,6 @@ std::unique_ptr<Thunk> ThunkEmitter::BuildGemmThunk(
       /*implements_whole_instruction=*/true);
 }
 
-Status ThunkEmitter::HandleTriangularSolve(HloInstruction* hlo) {
-  auto has_fortran_layout = [](const Layout& layout) {
-    int n = layout.minor_to_major_size();
-    return layout.minor_to_major(0) == n - 2 &&
-           layout.minor_to_major(1) == n - 1;
-  };
-  TF_RET_CHECK(has_fortran_layout(hlo->operand(0)->shape().layout()));
-  TF_RET_CHECK(has_fortran_layout(hlo->operand(1)->shape().layout()));
-  TF_RET_CHECK(has_fortran_layout(hlo->shape().layout()));
-
-  std::vector<std::unique_ptr<Thunk>> thunks;
-
-  // Triangular solve is in-place on 'b', so copy 'b' to the output if they
-  // aren't the same buffer.
-  auto operand_buffer = GetAllocationSlice(*hlo->operand(1));
-  auto destination_buffer = GetAllocationSlice(*hlo);
-  if (operand_buffer != destination_buffer) {
-    thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
-        context_->GetThunkInfo(hlo),
-        /*source_address=*/operand_buffer,
-        /*destination_buffer=*/destination_buffer,
-        /*mem_size=*/ShapeUtil::ByteSizeOf(hlo->operand(1)->shape())));
-  }
-
-  thunks.push_back(BuildTriangularSolveThunk(hlo));
-
-  // Elide the sequential thunk if there's no copy.
-  if (thunks.size() == 1) {
-    AddThunkToThunkSequence(std::move(thunks[0]));
-  } else {
-    AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
-        context_->GetThunkInfo(hlo), std::move(thunks)));
-  }
-  return Status::OK();
-}
-
 Thunk::ThunkInfo ThunkEmitter::EmissionContext::GetThunkInfo(
     const HloInstruction* hlo) const {
   CHECK(hlo);
@@ -69,9 +69,6 @@ class ThunkEmitter {
   // Returns a CholeskyThunk that calls cuSolver to implement `inst`.
   std::unique_ptr<Thunk> BuildCholeskyThunk(const HloInstruction* inst);
 
-  // Returns a TriangularSolveThunk that calls cuBlas to implement `inst`.
-  std::unique_ptr<Thunk> BuildTriangularSolveThunk(const HloInstruction* inst);
-
   // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs
   // to make sure `inst` outlives the lifetime of the returned Thunk object.
   std::unique_ptr<Thunk> BuildGemmThunk(const HloInstruction* inst);