[XLA/GPU] Migrate Copy emitters to take LMHLO.

PiperOrigin-RevId: 342215820
Change-Id: Ibf9a16c30a5f62f8177bed98b000fe59e5192c93
Authored by A. Unique TensorFlower on 2020-11-13 00:56:23 -08:00; committed by TensorFlower Gardener.
Parent: fce727c889
Commit: fbccb57c1f
5 changed files with 104 additions and 216 deletions
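Editorial note: the diff below switches the GPU copy emitter between two variants of the same logic, one driven by HloInstruction* (HandleCopy) and one driven by MlirEmitterInput (EmitCopyForMlir). Both sides make the same three-way choice: a same-layout copy whose source buffer slice can be resolved becomes a DeviceToDeviceCopyThunk, a copy whose layouts form a 0-2-1 transpose goes through the tiled shared-memory path (CheckAndEmitHloWithTile021 / EmitHlo021Tile), and everything else falls back to an elementwise loop kernel. The standalone sketch below models only that decision; every type in it is a simplified stand-in, not an XLA API, and the dimension/layout values are purely illustrative.

// Standalone editorial sketch (stand-in types, not XLA APIs) of the copy
// emission decision that both sides of this diff implement.
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

enum class CopyEmission { kMemcpyThunk, kTiled021Kernel, kLoopFusionKernel };

// Simplified stand-in for a shape: dimensions plus a minor-to-major layout.
struct ShapeStandIn {
  std::vector<int64_t> dims;
  std::vector<int64_t> minor_to_major;
};

// Stand-in for "the source buffer slice could be resolved"
// (GetAllocationSliceForMlir / GetUniqueTopLevelSlice in the diff).
using MaybeSlice = std::optional<int>;

CopyEmission DecideCopyEmission(const ShapeStandIn& operand,
                                const ShapeStandIn& output,
                                const MaybeSlice& operand_slice,
                                bool matches_021_transpose) {
  // Same physical layout and a resolvable source buffer: emit a plain
  // device-to-device memcpy (DeviceToDeviceCopyThunk in the diff).
  if (operand.minor_to_major == output.minor_to_major &&
      operand_slice.has_value()) {
    return CopyEmission::kMemcpyThunk;
  }
  // The layouts form a 0-2-1 transpose: emit the tiled shared-memory kernel
  // (CheckAndEmitHloWithTile021 / EmitHlo021Tile in the diff).
  if (matches_021_transpose) {
    return CopyEmission::kTiled021Kernel;
  }
  // Otherwise fall back to an elementwise loop kernel (DefaultActionForMlir
  // on one side of the diff, IrEmitter::HandleCopy on the other).
  return CopyEmission::kLoopFusionKernel;
}

int main() {
  // Illustrative values only: same dimensions, transposed layout.
  ShapeStandIn operand{{100, 200, 300}, {2, 1, 0}};
  ShapeStandIn output{{100, 200, 300}, {2, 0, 1}};
  CopyEmission decision =
      DecideCopyEmission(operand, output, /*operand_slice=*/MaybeSlice(0),
                         /*matches_021_transpose=*/true);
  std::cout << static_cast<int>(decision) << "\n";  // prints 1 (kTiled021Kernel)
  return 0;
}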


@@ -268,7 +268,6 @@ cc_library(
"//tensorflow/compiler/mlir:name_utils",
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/mlir/xla:hlo_module_importer",
"//tensorflow/compiler/mlir/xla:hlo_utils",
"//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla",
"//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",


@@ -44,7 +44,6 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/utils/name_utils.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
@@ -547,39 +546,6 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
return IrEmitter::DefaultAction(hlo);
}
Status IrEmitterUnnested::DefaultActionForMlir(MlirEmitterInput input) {
// Replace unnested op with a fused nested op.
//
// TODO(timshen): Ultimately this should be a pass. It's currently not a pass,
// because we don't have a fully functioning LMHLO graph yet.
mlir::Location loc = input.op->getLoc();
mlir::lmhlo::FusionOp fusion = nullptr;
Shape output_shape;
if (auto copy = mlir::dyn_cast<mlir::lmhlo::CopyOp>(input.op)) {
fusion = mlir::OpBuilder(copy).create<mlir::lmhlo::FusionOp>(
loc, llvm::ArrayRef<mlir::NamedAttribute>());
copy.getOperation()->moveBefore(&fusion.region().front().back());
mlir::OpBuilder b(copy);
auto operand = b.create<mlir::TensorLoadOp>(loc, copy.operand());
HloFunctionImporter::SetLayoutForMlir(
operand, TypeToShape(copy.operand().getType()));
auto fused_copy = b.create<mlir::mhlo::CopyOp>(loc, operand);
output_shape = TypeToShape(copy.output().getType());
HloFunctionImporter::SetLayoutForMlir(fused_copy, output_shape);
b.create<mlir::TensorStoreOp>(loc, fused_copy, copy.output());
copy.getOperation()->erase();
} else {
input.op->dump();
LOG(FATAL) << "Unimplemented default action for mlir op";
}
input.op = fusion;
auto ret = EmitLoopFusionFromMlir(
input, output_shape,
ComputeMaxUnrollFactor(output_shape, hlo_module_config_));
return ret;
}
Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
AddThunkToThunkSequence(
BuildKernelThunk(dot, /*implements_whole_instruction=*/true));
@@ -1184,10 +1150,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop)
<< ": " << fusion->ToString();
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(fusion));
TF_ASSIGN_OR_RETURN(const bool matched_021,
CheckAndEmitHloWithTile021(input));
if (matched_021) {
if (CheckAndEmitHloWithTile021(fusion)) {
return Status::OK();
}
@@ -1196,46 +1159,35 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
unroll_factor = ComputeMaxUnrollFactor(fusion);
}
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(fusion));
return EmitLoopFusionFromMlir(input, fusion->shape(), unroll_factor);
}
Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(copy));
return EmitCopyForMlir(input);
}
Status IrEmitterUnnested::EmitCopyForMlir(MlirEmitterInput input) {
auto copy = mlir::cast<mlir::lmhlo::CopyOp>(input.op);
auto operand_shape = TypeToShape(copy.operand().getType());
auto output_shape = TypeToShape(copy.output().getType());
CHECK(ShapeUtil::Compatible(operand_shape, output_shape));
absl::Span<const BufferAllocation> allocations(
ir_emitter_context_->buffer_assignment().Allocations());
auto maybe_slice = GetAllocationSliceForMlir(copy.operand(), allocations);
if (LayoutUtil::Equal(operand_shape.layout(), output_shape.layout()) &&
maybe_slice.ok()) {
CHECK(ShapeUtil::Compatible(copy->operand(0)->shape(), copy->shape()));
const BufferAssignment& buffer_assignment =
ir_emitter_context_->buffer_assignment();
if (LayoutUtil::Equal(copy->operand(0)->shape().layout(),
copy->shape().layout()) &&
buffer_assignment.GetUniqueTopLevelSlice(copy->operand(0)).ok()) {
// Copy the operand into the output if it's not the same buffer already.
auto operand_buffer = *maybe_slice;
auto destination_buffer =
*GetAllocationSliceForMlir(copy.output(), allocations);
auto operand_buffer = GetAllocationSlice(*copy->operand(0));
auto destination_buffer = GetAllocationSlice(*copy);
if (operand_buffer != destination_buffer) {
AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
input.thunk_info,
GetThunkInfo(copy),
/*source_address=*/operand_buffer,
/*destination_buffer=*/destination_buffer,
/*mem_size=*/
ByteSizeOf(operand_shape)));
ByteSizeOf(copy->operand(0)->shape())));
}
return Status::OK();
}
TF_ASSIGN_OR_RETURN(bool matched_021, CheckAndEmitHloWithTile021(input));
if (matched_021) {
if (CheckAndEmitHloWithTile021(copy)) {
return Status::OK();
}
return DefaultActionForMlir(input);
return IrEmitter::HandleCopy(copy);
}
Status IrEmitterUnnested::EmitExtraOutputsForReduce(
@@ -2552,15 +2504,11 @@ static void GetFusionOperandsAndOutputs(mlir::lmhlo::FusionOp fusion,
std::vector<mlir::Value>* operands,
std::vector<mlir::Value>* outputs) {
fusion.region().walk([&](mlir::TensorLoadOp load) {
CHECK(load.memref().getParentRegion() != &fusion.region())
<< "TensorLoadOp shows should be only expected for accessing captured "
"memrefs.";
CHECK(load.memref().getParentRegion() != &fusion.region());
operands->push_back(load.memref());
});
fusion.region().walk([&](mlir::TensorStoreOp store) {
CHECK(store.memref().getParentRegion() != &fusion.region())
<< "TensorStoreOp shows should be only expected for accessing captured "
"memrefs.";
CHECK(store.memref().getParentRegion() != &fusion.region());
outputs->push_back(store.memref());
});
}
@@ -3170,16 +3118,16 @@ void IrEmitterUnnested::EmitTile(
// dimensions that play the same role in the transpose.
// mapping_scheme: Kernel mapping scheme specifying the tiling
void IrEmitterUnnested::EmitTileElementForCopy(
const Shape& output_shape, const llvm_ir::IrArray& output_array,
const llvm_ir::IrArray::Index& index,
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) {
// TODO(jlebar): Add AA metadata to this load.
llvm::Instruction* load_from_shmem_buffer =
Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x_loc, y_loc}),
"output_element");
llvm_ir::IrArray output_array = GetIrArray(*hlo, *hlo);
Shape output_reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
output_shape.element_type(), mapping_scheme.GetDimsInElems());
hlo->shape().element_type(), mapping_scheme.GetDimsInElems());
// When the output_reduced_shape is a 0-2-1 transpose of the input shape,
// the 0-2-1 transpose is achieved through EmitWriteArrayElement.
output_array.CastToShape(output_reduced_shape, &b_)
@@ -3217,19 +3165,14 @@ static IrArray::Index GetUnnormalizedIndex(
// the same role in the transpose.
// kernel_info: Other information to support the kernel code generation.
void IrEmitterUnnested::EmitTileElementForFusion(
mlir::lmhlo::FusionOp fusion,
absl::Span<const llvm_ir::IrArray> operand_arrays,
absl::Span<const llvm_ir::IrArray> output_arrays,
const llvm_ir::IrArray::Index& index,
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) {
const HloComputation* fused_computation =
*GetOrCreateSubComputationFromRegion(&fusion.region(),
/*is_fusion*/ true);
std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
GetNestedComputer());
FusedIrEmitter fused_emitter(&elem_emitter);
for (int i = 0; i < operand_arrays.size(); i++) {
for (int i = 0; i < hlo->operand_count(); i++) {
llvm_ir::ElementGenerator gen;
if (llvm::Value* param_tile_buffer = param_shmem_buffers[i]) {
gen = [this, param_tile_buffer, x_loc,
@@ -3246,24 +3189,23 @@ void IrEmitterUnnested::EmitTileElementForFusion(
"tiled_buffer");
};
} else {
auto array = operand_arrays[i];
gen = [this, array](llvm_ir::IrArray::Index index) {
return array.EmitReadArrayElement(index, &b_);
const HloInstruction* operand = hlo->operand(i);
gen = [this, operand, hlo](llvm_ir::IrArray::Index index) {
return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_);
};
}
fused_emitter.BindGenerator(fused_computation->parameter_instruction(i),
std::move(gen));
fused_emitter.BindGenerator(hlo->fused_parameter(i), std::move(gen));
}
IrArray::Index untiled_index = GetUnnormalizedIndex(
index, output_arrays[0].GetShape(), &b_, mapping_scheme);
llvm_ir::ElementGenerator output_generator =
*fused_emitter.GetGenerator(fused_computation->root_instruction());
*fused_emitter.GetGenerator(hlo->fused_expression_root());
llvm::Value* output_value = output_generator(untiled_index).ValueOrDie();
if (output_arrays.size() > 1) {
if (hlo->IsMultiOutputFusion()) {
DCHECK(output_value->getType()->isStructTy());
DCHECK_EQ(output_value->getType()->getStructNumElements(),
output_arrays.size());
for (int64 i = 0; i < output_arrays.size() - 1; ++i) {
for (int64 i = 0; i < output_arrays.size(); ++i) {
output_arrays[i].EmitWriteArrayElement(
untiled_index, ExtractValue(output_value, i), &b_);
}
@@ -3867,15 +3809,10 @@ llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() {
// TODO(b/33320379): Here each block transposes 1 tile. It may be more
// efficient to launch fewer blocks so each transposes many tiles.
void IrEmitterUnnested::EmitHlo021Tile(
mlir::Operation* op, Thunk* kernel_thunk, const MlirEmitterContext& context,
absl::Span<const llvm_ir::IrArray> operand_arrays,
absl::Span<const llvm_ir::IrArray> output_arrays,
HloInstruction* hlo, Thunk* kernel_thunk,
absl::Span<const int64> reduced_output_dims,
absl::Span<const int64> tiled_param_ids) {
constexpr int kNumRows = 4;
std::string name = mlir::GetNameFromLoc(op->getLoc());
KernelMappingScheme mapping_scheme(reduced_output_dims,
/*tile_sizes=*/{1, kWarpSize, kWarpSize},
/*num_threads_y=*/kNumRows,
@@ -3885,16 +3822,14 @@ void IrEmitterUnnested::EmitHlo021Tile(
/*is_row_contiguous=*/false);
LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
mapping_scheme.GetThreadsPerBlock());
llvm::Type* index_type =
GetIndexTypeForKernelFromMlir(op, launch_dimensions.launch_bound(), &b_);
GetIndexTypeForKernel(hlo, launch_dimensions.launch_bound(), &b_);
std::vector<IrArray> param_arrays;
// For each tiled parameter, cast its input IrArray to the corresponding
// reduced shape and keep the reduced shape live during IR emission.
std::vector<IrArray> param_in_reduced_shape_arrays;
std::vector<llvm::Value*> param_shmem_buffers(context.operand_shapes.size(),
nullptr);
std::vector<llvm::Value*> param_shmem_buffers(hlo->operand_count(), nullptr);
auto get_shared_memory_buffer = [&](llvm::Type* elem_ty,
absl::string_view buffer_name) {
@@ -3911,18 +3846,20 @@ void IrEmitterUnnested::EmitHlo021Tile(
buffer_type, buffer_name);
};
for (int64 id = 0; id < context.operand_shapes.size(); id++) {
const Shape& param_shape = context.operand_shapes[id];
param_arrays.push_back(operand_arrays[id]);
for (int64 id = 0; id < hlo->operand_count(); id++) {
const HloInstruction* param = hlo->operand(id);
param_arrays.push_back(GetIrArray(*param, *hlo));
if (absl::c_linear_search(tiled_param_ids, id)) {
param_shmem_buffers[id] = get_shared_memory_buffer(
llvm_ir::PrimitiveTypeToIrType(param_shape.element_type(), module_),
IrName(name, StrCat("tile", id)));
param_shmem_buffers[id] =
get_shared_memory_buffer(llvm_ir::PrimitiveTypeToIrType(
param->shape().element_type(), module_),
IrName(hlo, StrCat("tile", id)));
VLOG(3) << "Added shmem buffer for parameter " << id << ": "
<< llvm_ir::DumpToString(*param_shmem_buffers[id]);
Shape reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
param_shape.element_type(), Permute({0, 2, 1}, reduced_output_dims));
param->shape().element_type(),
Permute({0, 2, 1}, reduced_output_dims));
param_in_reduced_shape_arrays.push_back(
param_arrays[id].CastToShape(reduced_shape, &b_));
} else {
@@ -3933,18 +3870,13 @@ void IrEmitterUnnested::EmitHlo021Tile(
EmitElementFunction element_generator =
[&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
llvm::Value* x_loc, int64 x_iter_num) {
if (auto copy = mlir::dyn_cast<mlir::lmhlo::CopyOp>(op)) {
CHECK_EQ(1, context.output_shapes.size());
EmitTileElementForCopy(context.output_shapes[0], output_arrays[0],
index, mapping_scheme, y_loc, x_loc,
if (hlo->opcode() == HloOpcode::kCopy) {
EmitTileElementForCopy(hlo, index, mapping_scheme, y_loc, x_loc,
param_shmem_buffers);
} else if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
EmitTileElementForFusion(fusion, operand_arrays, output_arrays, index,
mapping_scheme, y_loc, x_loc,
param_shmem_buffers);
} else {
op->dump();
LOG(FATAL) << "Unexpected op type";
CHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
EmitTileElementForFusion(hlo, index, mapping_scheme, y_loc, x_loc,
param_shmem_buffers);
}
};
@@ -3971,9 +3903,9 @@ void IrEmitterUnnested::EmitHlo021Tile(
llvm::Value* x_loc, int64 /*x_iter_num*/) {
for (int64 id : tiled_param_ids) {
IrArray& input_in_logical_shape =
param_in_reduced_shape_arrays.at(id);
param_in_reduced_shape_arrays[id];
llvm::Value* shmem_buffer = param_shmem_buffers.at(id);
llvm::Value* shmem_buffer = param_shmem_buffers[id];
llvm::Value* zero =
llvm::ConstantInt::get(index_type, 0);
// TODO(jlebar): Add AA metadata to this store. Tile
@@ -4010,11 +3942,10 @@ void IrEmitterUnnested::EmitHlo021Tile(
// the hopes of reducing register pressure, since we touch
// threadIdx.x and blockIdx.x at the beginning of the kernel
// *anyway*.
if (output_arrays.size() > 1) {
if (hlo->IsMultiOutputFusion()) {
KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] {
llvm_ir::EmitTuple(output_arrays.back(),
output_arrays.subspan(0, output_arrays.size() - 1),
&b_);
llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo),
ConstructIrArrayForOutputs(*hlo), &b_);
});
}
@ -4024,7 +3955,6 @@ void IrEmitterUnnested::EmitHlo021Tile(
}
namespace {
// A recursive function to inspect the users of a parameter to determine
// whether it's safe for a parameter to participate in a shared-memory
// transpose.
@@ -4063,29 +3993,14 @@ namespace {
// a reduce operation. In this case, the above description of "output" applies
// to the result of such a use-chain, which provides the input to the reduce
// operation.
bool IsInstructionSafeForShmemTranspose(mlir::Operation* op) {
if (mlir::isa<mlir::TensorStoreOp>(op)) {
return true;
bool IsInstructionSafeForShmemTranspose(const HloInstruction* hlo) {
if (hlo->IsElementwise()) {
return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
return IsInstructionSafeForShmemTranspose(user);
});
}
HloOpcode opcode;
if (mlir::isa<mlir::TensorLoadOp>(op)) {
opcode = HloOpcode::kParameter;
} else {
opcode = *mlir::MhloToHloOpcode(op);
}
if (HloInstruction::IsOpElementwise(opcode)) {
for (mlir::Value v : op->getResults()) {
for (mlir::OpOperand use : v.getUsers()) {
if (!IsInstructionSafeForShmemTranspose(use.getOwner())) {
return false;
}
}
}
return true;
}
switch (opcode) {
switch (hlo->opcode()) {
// Non-elementwise instructions that don't cause the shmem transpose
// to be unsafe, including the instructions that don't currently fuse.
case HloOpcode::kGetDimensionSize:
@@ -4097,14 +4012,9 @@ bool IsInstructionSafeForShmemTranspose(mlir::Operation* op) {
case HloOpcode::kParameter:
case HloOpcode::kTuple:
case HloOpcode::kTupleSelect:
for (mlir::Value v : op->getResults()) {
for (mlir::OpOperand use : v.getUsers()) {
if (!IsInstructionSafeForShmemTranspose(use.getOwner())) {
return false;
}
}
}
return true;
return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
return IsInstructionSafeForShmemTranspose(user);
});
default:
return false;
@@ -4123,21 +4033,15 @@ bool IsInstructionSafeForShmemTranspose(mlir::Operation* op) {
// preloaded tile. We inspect all the transitive users of the input parameter
// up to the fusion root instruction to see if we can find any instruction
// that can make preloading the input tile unsafe.
std::vector<int64> FilterInputsForShmemTranspose(mlir::lmhlo::FusionOp fusion,
std::vector<int64> FilterInputsForShmemTranspose(const HloInstruction* fusion,
std::vector<int64> input_ids) {
std::vector<mlir::Value> params;
fusion.region().walk([&](mlir::TensorLoadOp load) {
CHECK(load.memref().getParentRegion() != &fusion.region())
<< "TensorLoadOp shows should be only expected for accessing captured "
"memrefs.";
params.push_back(load);
});
std::vector<int64> filtered_input_ids;
for (int64 input_id : input_ids) {
mlir::Value input = params.at(input_id);
if (IsInstructionSafeForShmemTranspose(input.getDefiningOp())) {
filtered_input_ids.push_back(input_id);
for (int64 i = 0; i < input_ids.size(); ++i) {
const HloInstruction* input = fusion->fused_parameter(input_ids[i]);
if (IsInstructionSafeForShmemTranspose(input)) {
filtered_input_ids.push_back(input_ids[i]);
} else {
VLOG(10) << "Input not safe for shmem transpose " << input->ToString();
}
}
return filtered_input_ids;
@@ -4145,23 +4049,24 @@ std::vector<int64> FilterInputsForShmemTranspose(mlir::lmhlo::FusionOp fusion,
} // namespace
StatusOr<bool> IrEmitterUnnested::CheckAndEmitHloWithTile021(
MlirEmitterInput input) {
CHECK(mlir::isa<mlir::lmhlo::FusionOp>(input.op) ||
mlir::isa<mlir::lmhlo::CopyOp>(input.op));
bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
HloOpcode opcode = hlo->opcode();
MlirEmitterContext context;
context.SetOperation(input.op);
CHECK(hlo->IsLoopFusion() || opcode == HloOpcode::kCopy);
const Shape& output_shape = hlo->IsMultiOutputFusion()
? ShapeUtil::GetSubshape(hlo->shape(), {0})
: hlo->shape();
// If the output_shape is reduced to 021 shape, find all the parameters of
// the HLO that are in the corresponding 012 shape.
std::vector<int64> params_012;
optional<std::vector<int64>> reduced_dims_021;
for (int64 operand_idx = 0; operand_idx < context.operand_shapes.size();
for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
++operand_idx) {
const Shape& operand_shape = context.operand_shapes[operand_idx];
HloInstruction* operand = hlo->mutable_operand(operand_idx);
auto find_transpose_result =
ShapeUtil::FindTranspose021(operand_shape, context.output_shapes[0]);
ShapeUtil::FindTranspose021(operand->shape(), output_shape);
if (!find_transpose_result.has_value()) {
continue;
}
@@ -4186,8 +4091,8 @@ StatusOr<bool> IrEmitterUnnested::CheckAndEmitHloWithTile021(
return false;
}
if (auto fusion_op = mlir::dyn_cast<mlir::lmhlo::FusionOp>(input.op)) {
params_012 = FilterInputsForShmemTranspose(fusion_op, params_012);
if (opcode == HloOpcode::kFusion) {
params_012 = FilterInputsForShmemTranspose(hlo, params_012);
if (params_012.empty()) {
return false;
}
@@ -4215,10 +4120,10 @@ StatusOr<bool> IrEmitterUnnested::CheckAndEmitHloWithTile021(
constexpr int64 kShmemPerCore = 48 * 1024;
int64 shmem_used = 0;
for (int64 i = 0; i < params_012.size(); ++i) {
const Shape& operand_shape = context.operand_shapes[params_012[i]];
const HloInstruction* operand = hlo->operand(params_012[i]);
shmem_used +=
32 * 33 *
ShapeUtil::ByteSizeOfPrimitiveType(operand_shape.element_type());
ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());
if (kMinBlocksPerCore * shmem_used > kShmemPerCore) {
// Erase this element and everything after it from params_012.
@@ -4231,15 +4136,10 @@ StatusOr<bool> IrEmitterUnnested::CheckAndEmitHloWithTile021(
return false;
}
std::vector<llvm_ir::IrArray> ir_arrays;
TF_ASSIGN_OR_RETURN(std::unique_ptr<KernelThunk> kernel_thunk,
BuildKernelThunkForMlir(input.op, input.thunk_info,
input.extra_slice, &ir_arrays));
EmitHlo021Tile(
input.op, kernel_thunk.get(), context,
absl::MakeSpan(ir_arrays).subspan(0, context.operand_shapes.size()),
absl::MakeSpan(ir_arrays).subspan(context.operand_shapes.size()),
*reduced_dims_021, params_012);
VLOG(3) << "EmitHlo021Tile Emitting hlo tile 0-2-1" << hlo->ToString();
std::unique_ptr<KernelThunk> kernel_thunk =
BuildKernelThunk(hlo, /*implements_whole_instruction=*/true);
EmitHlo021Tile(hlo, kernel_thunk.get(), *reduced_dims_021, params_012);
AddThunkToThunkSequence(std::move(kernel_thunk));
return true;
}
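Editorial note on the shared-memory budget in the CheckAndEmitHloWithTile021 hunks above: each tiled parameter costs a 32 x 33 tile of its element type, and params_012 is truncated as soon as kMinBlocksPerCore copies of the running total no longer fit in the 48 KiB (kShmemPerCore) budget. The value of kMinBlocksPerCore does not appear in this diff, so the 3 used in the standalone sketch below is an assumption for illustration only.

// Standalone editorial sketch of the shared-memory budget check above.
// kMinBlocksPerCore = 3 is an assumption for illustration; its real value is
// defined elsewhere in the file and is not part of this diff.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  constexpr int64_t kShmemPerCore = 48 * 1024;   // as in the diff
  constexpr int64_t kMinBlocksPerCore = 3;       // assumed for this sketch
  // Element sizes (bytes) of the candidate 0-1-2 parameters, e.g. four f32s.
  std::vector<int64_t> elem_byte_sizes = {4, 4, 4, 4};
  int64_t shmem_used = 0;
  size_t kept = 0;
  for (size_t i = 0; i < elem_byte_sizes.size(); ++i) {
    // Each tiled parameter gets a 32 x 33 tile of its element type, matching
    // the [32 x [33 x float]] shared buffer checked further below.
    shmem_used += 32 * 33 * elem_byte_sizes[i];
    if (kMinBlocksPerCore * shmem_used > kShmemPerCore) {
      break;  // this parameter and everything after it gets dropped
    }
    kept = i + 1;
  }
  // 32 * 33 * 4 = 4224 bytes per f32 tile; 3 * 4224 * n <= 49152 allows n <= 3.
  std::printf("tiled f32 parameters kept: %zu\n", kept);  // prints 3
  return 0;
}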


@@ -73,7 +73,6 @@ struct MlirEmitterInput {
};
// Convenience struct that contains useful data structures in MLIR emitter.
// Not all fields may be filled. It's entirely dependent on the uses.
struct MlirEmitterContext {
void SetOperation(mlir::Operation* op);
@@ -157,14 +156,11 @@ class IrEmitterUnnested : public IrEmitter,
}
Status DefaultAction(HloInstruction* hlo) override;
Status DefaultActionForMlir(MlirEmitterInput input);
// IrEmitterUnnested handles the following instructions differently from
// IrEmitter. It also mixes in some special handling for custom kernels
// via the ThunkEmitter.
Status HandleCopy(HloInstruction* copy) override;
Status EmitCopyForMlir(MlirEmitterInput input);
Status HandleConditional(HloInstruction* conditional) override;
Status HandleConvolution(HloInstruction* convolution) override;
Status HandleCustomCall(HloInstruction* custom_call) override;
@@ -471,15 +467,12 @@ class IrEmitterUnnested : public IrEmitter,
// Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel
// for the hlo instruction.
StatusOr<bool> CheckAndEmitHloWithTile021(MlirEmitterInput input);
bool CheckAndEmitHloWithTile021(HloInstruction* hlo);
// Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and
// sets the corresponding launch dimensions. This is a helper to support
// the implementation of CheckAndEmitHloWithTile021.
void EmitHlo021Tile(mlir::Operation* op, Thunk* kernel_thunk,
const MlirEmitterContext& context,
absl::Span<const llvm_ir::IrArray> operand_arrays,
absl::Span<const llvm_ir::IrArray> output_arrays,
void EmitHlo021Tile(HloInstruction* hlo, Thunk* kernel_thunk,
absl::Span<const int64> reduced_output_dims,
absl::Span<const int64> tiled_param_ids);
@@ -534,8 +527,7 @@ class IrEmitterUnnested : public IrEmitter,
// y_loc: The y coordinate within a tile.
// x_loc: The x coordinate within a tile.
void EmitTileElementForCopy(
const Shape& output_shape, const llvm_ir::IrArray& ir_array,
const llvm_ir::IrArray::Index& index,
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers);
@@ -544,10 +536,7 @@ class IrEmitterUnnested : public IrEmitter,
// y_loc: The y coordinate within a tile.
// x_loc: The x coordinate within a tile.
void EmitTileElementForFusion(
mlir::lmhlo::FusionOp fusion,
absl::Span<const llvm_ir::IrArray> operand_arrays,
absl::Span<const llvm_ir::IrArray> output_arrays,
const llvm_ir::IrArray::Index& index,
HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers);


@@ -2,12 +2,12 @@
// CHECK-LABEL: entry:
// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [100 x [200 x [300 x float]]]*
// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [200 x [100 x [300 x float]]]*
// CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0
// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [200 x [100 x [300 x float]]]*
// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [100 x [200 x [300 x float]]]*
// CHECK: %[[VAL_6:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
// CHECK: %[[VAL_7:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
// CHECK: %[[VAL_8:.*]] = mul nuw nsw i32 %[[VAL_6]], 256
// CHECK: %[[VAL_8:.*]] = mul nuw nsw i32 %[[VAL_6]], 128
// CHECK: %[[VAL_9:.*]] = add nuw nsw i32 %[[VAL_8]], %[[VAL_7]]
// CHECK: %[[VAL_10:.*]] = icmp ult i32 %[[VAL_9]], {{.*}}
// CHECK: call void @llvm.assume(i1 %[[VAL_10]])
@@ -40,24 +40,24 @@
// CHECK: b.in_bounds-after: ; preds = %[[VAL_36]], %[[VAL_38:.*]]
// CHECK: ret void
// CHECK: b.in_bounds-true: ; preds = %[[VAL_38]]
// CHECK: %[[VAL_39:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_2]], i32 0, i32 %[[VAL_15]], i32 %[[VAL_16]], i32 %[[VAL_13]]
// CHECK: %[[VAL_39:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_5]], i32 0, i32 %[[VAL_15]], i32 %[[VAL_16]], i32 %[[VAL_13]]
// CHECK: %[[VAL_40:.*]] = load float, float* %[[VAL_39]], align 4, !invariant.load !4
// CHECK: %[[VAL_41:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_5]] to float*
// CHECK: %[[VAL_41:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_2]] to float*
// CHECK: %[[VAL_42:.*]] = getelementptr inbounds float, float* %[[VAL_41]], i32 %[[VAL_11]]
// CHECK: store float %[[VAL_40]], float* %[[VAL_42]], align 4
// CHECK: %[[VAL_43:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_2]], i32 0, i32 %[[VAL_21]], i32 %[[VAL_22]], i32 %[[VAL_19]]
// CHECK: %[[VAL_43:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_5]], i32 0, i32 %[[VAL_21]], i32 %[[VAL_22]], i32 %[[VAL_19]]
// CHECK: %[[VAL_44:.*]] = load float, float* %[[VAL_43]], align 4, !invariant.load !4
// CHECK: %[[VAL_45:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_5]] to float*
// CHECK: %[[VAL_45:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_2]] to float*
// CHECK: %[[VAL_46:.*]] = getelementptr inbounds float, float* %[[VAL_45]], i32 %[[VAL_17]]
// CHECK: store float %[[VAL_44]], float* %[[VAL_46]], align 4
// CHECK: %[[VAL_47:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_2]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_25]]
// CHECK: %[[VAL_47:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_5]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_25]]
// CHECK: %[[VAL_48:.*]] = load float, float* %[[VAL_47]], align 4, !invariant.load !4
// CHECK: %[[VAL_49:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_5]] to float*
// CHECK: %[[VAL_49:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_2]] to float*
// CHECK: %[[VAL_50:.*]] = getelementptr inbounds float, float* %[[VAL_49]], i32 %[[VAL_23]]
// CHECK: store float %[[VAL_48]], float* %[[VAL_50]], align 4
// CHECK: %[[VAL_51:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_2]], i32 0, i32 %[[VAL_33]], i32 %[[VAL_34]], i32 %[[VAL_31]]
// CHECK: %[[VAL_51:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_5]], i32 0, i32 %[[VAL_33]], i32 %[[VAL_34]], i32 %[[VAL_31]]
// CHECK: %[[VAL_52:.*]] = load float, float* %[[VAL_51]], align 4, !invariant.load !4
// CHECK: %[[VAL_53:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_5]] to float*
// CHECK: %[[VAL_53:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_2]] to float*
// CHECK: %[[VAL_54:.*]] = getelementptr inbounds float, float* %[[VAL_53]], i32 %[[VAL_29]]
// CHECK: store float %[[VAL_52]], float* %[[VAL_54]], align 4
// CHECK: br label %[[VAL_37]]


@@ -4,10 +4,10 @@
// CHECK: %[[VAL_0:.*]] = alloca i32, align 4
// CHECK: %[[VAL_1:.*]] = alloca i32, align 4
// CHECK: %[[VAL_2:.*]] = getelementptr inbounds i8, i8* %[[VAL_3:.*]], i64 0
// CHECK: %[[VAL_4:.*]] = bitcast i8* %[[VAL_2]] to [100 x [200 x float]]*
// CHECK: %[[VAL_4:.*]] = bitcast i8* %[[VAL_2]] to [200 x [100 x float]]*
// CHECK: %[[VAL_5:.*]] = getelementptr inbounds i8, i8* %[[VAL_6:.*]], i64 0
// CHECK: %[[VAL_7:.*]] = bitcast i8* %[[VAL_5]] to [200 x [100 x float]]*
// CHECK: %[[VAL_8:.*]] = bitcast [100 x [200 x float]]* %[[VAL_4]] to [1 x [100 x [200 x float]]]*
// CHECK: %[[VAL_7:.*]] = bitcast i8* %[[VAL_5]] to [100 x [200 x float]]*
// CHECK: %[[VAL_8:.*]] = bitcast [100 x [200 x float]]* %[[VAL_7]] to [1 x [100 x [200 x float]]]*
// CHECK: %[[VAL_9:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2
// CHECK: %[[VAL_10:.*]] = urem i32 %[[VAL_9]], 32
// CHECK: %[[VAL_11:.*]] = udiv i32 %[[VAL_9]], 32
@@ -88,7 +88,7 @@
// CHECK: output_x_in_tile-true: ; preds = %[[VAL_59]]
// CHECK: %[[VAL_72:.*]] = getelementptr [32 x [33 x float]], [32 x [33 x float]] addrspace(3)* @b.tile0, i64 0, i32 %[[VAL_65]], i32 %[[VAL_63]]
// CHECK: %[[VAL_73:.*]] = load float, float addrspace(3)* %[[VAL_72]], align 4
// CHECK: %[[VAL_74:.*]] = bitcast [200 x [100 x float]]* %[[VAL_7]] to [1 x [200 x [100 x float]]]*
// CHECK: %[[VAL_74:.*]] = bitcast [200 x [100 x float]]* %[[VAL_4]] to [1 x [200 x [100 x float]]]*
// CHECK: %[[VAL_75:.*]] = getelementptr inbounds [1 x [200 x [100 x float]]], [1 x [200 x [100 x float]]]* %[[VAL_74]], i32 0, i32 0, i32 %[[VAL_64]], i32 %[[VAL_66]]
// CHECK: store float %[[VAL_73]], float* %[[VAL_75]], align 4
// CHECK: br label %[[VAL_55]]
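Editorial aside on the [32 x [33 x float]] shared-memory tile (@b.tile0) in the checks above: the 33-wide row is the usual padding trick so that reading a tile column does not serialize on shared-memory banks. The standalone sketch below demonstrates the effect using the standard 32-bank, 4-byte-word shared-memory model; the bank parameters are well-known hardware details, not something this diff states.

// Standalone editorial sketch: why the tile is padded to 33 columns. With a
// 32-wide row every element of a tile column maps to the same shared-memory
// bank; the extra column spreads a column across all 32 banks.
#include <cstdio>
#include <set>

int main() {
  constexpr int kBanks = 32;  // 32 banks of 4-byte words (standard CUDA model)
  for (int row_width : {32, 33}) {
    std::set<int> banks_hit;
    // Walk one column of a 32-row tile of 4-byte floats and record the bank
    // of each element's word address.
    for (int row = 0; row < 32; ++row) {
      int word_index = row * row_width + /*column=*/0;
      banks_hit.insert(word_index % kBanks);
    }
    std::printf("row width %d: one column touches %zu distinct banks\n",
                row_width, banks_hit.size());
  }
  // Prints 1 bank for width 32 (fully serialized) and 32 banks for width 33
  // (conflict-free), which is why @b.tile0 is [32 x [33 x float]].
  return 0;
}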