[XLA/GPU] Migrate Copy emitters to take LMHLO.

PiperOrigin-RevId: 342215820
Change-Id: Ibf9a16c30a5f62f8177bed98b000fe59e5192c93
This commit is contained in:
A. Unique TensorFlower 2020-11-13 00:56:23 -08:00 committed by TensorFlower Gardener
parent fce727c889
commit fbccb57c1f
5 changed files with 104 additions and 216 deletions

View File

@ -268,7 +268,6 @@ cc_library(
"//tensorflow/compiler/mlir:name_utils", "//tensorflow/compiler/mlir:name_utils",
"//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/hlo:lhlo", "//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/mlir/xla:hlo_module_importer",
"//tensorflow/compiler/mlir/xla:hlo_utils", "//tensorflow/compiler/mlir/xla:hlo_utils",
"//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla", "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla",
"//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo", "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",

View File

@ -44,7 +44,6 @@ limitations under the License.
#include "mlir/IR/StandardTypes.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/utils/name_utils.h" #include "tensorflow/compiler/mlir/utils/name_utils.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
@ -547,39 +546,6 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
return IrEmitter::DefaultAction(hlo); return IrEmitter::DefaultAction(hlo);
} }
Status IrEmitterUnnested::DefaultActionForMlir(MlirEmitterInput input) {
// Replace unnested op with a fused nested op.
//
// TODO(timshen): Ultimately this should be a pass. It's currently not a pass,
// because we don't have a fully functioning LMHLO graph yet.
mlir::Location loc = input.op->getLoc();
mlir::lmhlo::FusionOp fusion = nullptr;
Shape output_shape;
if (auto copy = mlir::dyn_cast<mlir::lmhlo::CopyOp>(input.op)) {
fusion = mlir::OpBuilder(copy).create<mlir::lmhlo::FusionOp>(
loc, llvm::ArrayRef<mlir::NamedAttribute>());
copy.getOperation()->moveBefore(&fusion.region().front().back());
mlir::OpBuilder b(copy);
auto operand = b.create<mlir::TensorLoadOp>(loc, copy.operand());
HloFunctionImporter::SetLayoutForMlir(
operand, TypeToShape(copy.operand().getType()));
auto fused_copy = b.create<mlir::mhlo::CopyOp>(loc, operand);
output_shape = TypeToShape(copy.output().getType());
HloFunctionImporter::SetLayoutForMlir(fused_copy, output_shape);
b.create<mlir::TensorStoreOp>(loc, fused_copy, copy.output());
copy.getOperation()->erase();
} else {
input.op->dump();
LOG(FATAL) << "Unimplemented default action for mlir op";
}
input.op = fusion;
auto ret = EmitLoopFusionFromMlir(
input, output_shape,
ComputeMaxUnrollFactor(output_shape, hlo_module_config_));
return ret;
}
Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
AddThunkToThunkSequence( AddThunkToThunkSequence(
BuildKernelThunk(dot, /*implements_whole_instruction=*/true)); BuildKernelThunk(dot, /*implements_whole_instruction=*/true));
@ -1184,10 +1150,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop) CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop)
<< ": " << fusion->ToString(); << ": " << fusion->ToString();
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(fusion)); if (CheckAndEmitHloWithTile021(fusion)) {
TF_ASSIGN_OR_RETURN(const bool matched_021,
CheckAndEmitHloWithTile021(input));
if (matched_021) {
return Status::OK(); return Status::OK();
} }
@ -1196,46 +1159,35 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
unroll_factor = ComputeMaxUnrollFactor(fusion); unroll_factor = ComputeMaxUnrollFactor(fusion);
} }
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(fusion));
return EmitLoopFusionFromMlir(input, fusion->shape(), unroll_factor); return EmitLoopFusionFromMlir(input, fusion->shape(), unroll_factor);
} }
Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(copy)); CHECK(ShapeUtil::Compatible(copy->operand(0)->shape(), copy->shape()));
return EmitCopyForMlir(input); const BufferAssignment& buffer_assignment =
} ir_emitter_context_->buffer_assignment();
if (LayoutUtil::Equal(copy->operand(0)->shape().layout(),
Status IrEmitterUnnested::EmitCopyForMlir(MlirEmitterInput input) { copy->shape().layout()) &&
auto copy = mlir::cast<mlir::lmhlo::CopyOp>(input.op); buffer_assignment.GetUniqueTopLevelSlice(copy->operand(0)).ok()) {
auto operand_shape = TypeToShape(copy.operand().getType());
auto output_shape = TypeToShape(copy.output().getType());
CHECK(ShapeUtil::Compatible(operand_shape, output_shape));
absl::Span<const BufferAllocation> allocations(
ir_emitter_context_->buffer_assignment().Allocations());
auto maybe_slice = GetAllocationSliceForMlir(copy.operand(), allocations);
if (LayoutUtil::Equal(operand_shape.layout(), output_shape.layout()) &&
maybe_slice.ok()) {
// Copy the operand into the output if it's not the same buffer already. // Copy the operand into the output if it's not the same buffer already.
auto operand_buffer = *maybe_slice; auto operand_buffer = GetAllocationSlice(*copy->operand(0));
auto destination_buffer = auto destination_buffer = GetAllocationSlice(*copy);
*GetAllocationSliceForMlir(copy.output(), allocations);
if (operand_buffer != destination_buffer) { if (operand_buffer != destination_buffer) {
AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>( AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
input.thunk_info, GetThunkInfo(copy),
/*source_address=*/operand_buffer, /*source_address=*/operand_buffer,
/*destination_buffer=*/destination_buffer, /*destination_buffer=*/destination_buffer,
/*mem_size=*/ /*mem_size=*/
ByteSizeOf(operand_shape))); ByteSizeOf(copy->operand(0)->shape())));
} }
return Status::OK(); return Status::OK();
} }
TF_ASSIGN_OR_RETURN(bool matched_021, CheckAndEmitHloWithTile021(input)); if (CheckAndEmitHloWithTile021(copy)) {
if (matched_021) {
return Status::OK(); return Status::OK();
} }
return DefaultActionForMlir(input); return IrEmitter::HandleCopy(copy);
} }
Status IrEmitterUnnested::EmitExtraOutputsForReduce( Status IrEmitterUnnested::EmitExtraOutputsForReduce(
@ -2552,15 +2504,11 @@ static void GetFusionOperandsAndOutputs(mlir::lmhlo::FusionOp fusion,
std::vector<mlir::Value>* operands, std::vector<mlir::Value>* operands,
std::vector<mlir::Value>* outputs) { std::vector<mlir::Value>* outputs) {
fusion.region().walk([&](mlir::TensorLoadOp load) { fusion.region().walk([&](mlir::TensorLoadOp load) {
CHECK(load.memref().getParentRegion() != &fusion.region()) CHECK(load.memref().getParentRegion() != &fusion.region());
<< "TensorLoadOp shows should be only expected for accessing captured "
"memrefs.";
operands->push_back(load.memref()); operands->push_back(load.memref());
}); });
fusion.region().walk([&](mlir::TensorStoreOp store) { fusion.region().walk([&](mlir::TensorStoreOp store) {
CHECK(store.memref().getParentRegion() != &fusion.region()) CHECK(store.memref().getParentRegion() != &fusion.region());
<< "TensorStoreOp shows should be only expected for accessing captured "
"memrefs.";
outputs->push_back(store.memref()); outputs->push_back(store.memref());
}); });
} }
@ -3170,16 +3118,16 @@ void IrEmitterUnnested::EmitTile(
// dimensions that play the same role in the transpose. // dimensions that play the same role in the transpose.
// mapping_scheme: Kernel mapping scheme specifying the tiling // mapping_scheme: Kernel mapping scheme specifying the tiling
void IrEmitterUnnested::EmitTileElementForCopy( void IrEmitterUnnested::EmitTileElementForCopy(
const Shape& output_shape, const llvm_ir::IrArray& output_array, HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
const llvm_ir::IrArray::Index& index,
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc, const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) { llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) {
// TODO(jlebar): Add AA metadata to this load. // TODO(jlebar): Add AA metadata to this load.
llvm::Instruction* load_from_shmem_buffer = llvm::Instruction* load_from_shmem_buffer =
Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x_loc, y_loc}), Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x_loc, y_loc}),
"output_element"); "output_element");
llvm_ir::IrArray output_array = GetIrArray(*hlo, *hlo);
Shape output_reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout( Shape output_reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
output_shape.element_type(), mapping_scheme.GetDimsInElems()); hlo->shape().element_type(), mapping_scheme.GetDimsInElems());
// When the output_reduced_shape is a 0-2-1 transpose of the input shape, // When the output_reduced_shape is a 0-2-1 transpose of the input shape,
// the 0-2-1 transpose is achieved through EmitWriteArrayElement. // the 0-2-1 transpose is achieved through EmitWriteArrayElement.
output_array.CastToShape(output_reduced_shape, &b_) output_array.CastToShape(output_reduced_shape, &b_)
@ -3217,19 +3165,14 @@ static IrArray::Index GetUnnormalizedIndex(
// the same role in the transpose. // the same role in the transpose.
// kernel_info: Other information to support the kernel code generation. // kernel_info: Other information to support the kernel code generation.
void IrEmitterUnnested::EmitTileElementForFusion( void IrEmitterUnnested::EmitTileElementForFusion(
mlir::lmhlo::FusionOp fusion, HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
absl::Span<const llvm_ir::IrArray> operand_arrays,
absl::Span<const llvm_ir::IrArray> output_arrays,
const llvm_ir::IrArray::Index& index,
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc, const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) { llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) {
const HloComputation* fused_computation = std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
*GetOrCreateSubComputationFromRegion(&fusion.region(),
/*is_fusion*/ true);
GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_, GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
GetNestedComputer()); GetNestedComputer());
FusedIrEmitter fused_emitter(&elem_emitter); FusedIrEmitter fused_emitter(&elem_emitter);
for (int i = 0; i < operand_arrays.size(); i++) { for (int i = 0; i < hlo->operand_count(); i++) {
llvm_ir::ElementGenerator gen; llvm_ir::ElementGenerator gen;
if (llvm::Value* param_tile_buffer = param_shmem_buffers[i]) { if (llvm::Value* param_tile_buffer = param_shmem_buffers[i]) {
gen = [this, param_tile_buffer, x_loc, gen = [this, param_tile_buffer, x_loc,
@ -3246,24 +3189,23 @@ void IrEmitterUnnested::EmitTileElementForFusion(
"tiled_buffer"); "tiled_buffer");
}; };
} else { } else {
auto array = operand_arrays[i]; const HloInstruction* operand = hlo->operand(i);
gen = [this, array](llvm_ir::IrArray::Index index) { gen = [this, operand, hlo](llvm_ir::IrArray::Index index) {
return array.EmitReadArrayElement(index, &b_); return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_);
}; };
} }
fused_emitter.BindGenerator(fused_computation->parameter_instruction(i), fused_emitter.BindGenerator(hlo->fused_parameter(i), std::move(gen));
std::move(gen));
} }
IrArray::Index untiled_index = GetUnnormalizedIndex( IrArray::Index untiled_index = GetUnnormalizedIndex(
index, output_arrays[0].GetShape(), &b_, mapping_scheme); index, output_arrays[0].GetShape(), &b_, mapping_scheme);
llvm_ir::ElementGenerator output_generator = llvm_ir::ElementGenerator output_generator =
*fused_emitter.GetGenerator(fused_computation->root_instruction()); *fused_emitter.GetGenerator(hlo->fused_expression_root());
llvm::Value* output_value = output_generator(untiled_index).ValueOrDie(); llvm::Value* output_value = output_generator(untiled_index).ValueOrDie();
if (output_arrays.size() > 1) { if (hlo->IsMultiOutputFusion()) {
DCHECK(output_value->getType()->isStructTy()); DCHECK(output_value->getType()->isStructTy());
DCHECK_EQ(output_value->getType()->getStructNumElements(), DCHECK_EQ(output_value->getType()->getStructNumElements(),
output_arrays.size()); output_arrays.size());
for (int64 i = 0; i < output_arrays.size() - 1; ++i) { for (int64 i = 0; i < output_arrays.size(); ++i) {
output_arrays[i].EmitWriteArrayElement( output_arrays[i].EmitWriteArrayElement(
untiled_index, ExtractValue(output_value, i), &b_); untiled_index, ExtractValue(output_value, i), &b_);
} }
@ -3867,15 +3809,10 @@ llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() {
// TODO(b/33320379): Here each block transposes 1 tile. It may be more // TODO(b/33320379): Here each block transposes 1 tile. It may be more
// efficient to launch fewer blocks so each transposes many tiles. // efficient to launch fewer blocks so each transposes many tiles.
void IrEmitterUnnested::EmitHlo021Tile( void IrEmitterUnnested::EmitHlo021Tile(
mlir::Operation* op, Thunk* kernel_thunk, const MlirEmitterContext& context, HloInstruction* hlo, Thunk* kernel_thunk,
absl::Span<const llvm_ir::IrArray> operand_arrays,
absl::Span<const llvm_ir::IrArray> output_arrays,
absl::Span<const int64> reduced_output_dims, absl::Span<const int64> reduced_output_dims,
absl::Span<const int64> tiled_param_ids) { absl::Span<const int64> tiled_param_ids) {
constexpr int kNumRows = 4; constexpr int kNumRows = 4;
std::string name = mlir::GetNameFromLoc(op->getLoc());
KernelMappingScheme mapping_scheme(reduced_output_dims, KernelMappingScheme mapping_scheme(reduced_output_dims,
/*tile_sizes=*/{1, kWarpSize, kWarpSize}, /*tile_sizes=*/{1, kWarpSize, kWarpSize},
/*num_threads_y=*/kNumRows, /*num_threads_y=*/kNumRows,
@ -3885,16 +3822,14 @@ void IrEmitterUnnested::EmitHlo021Tile(
/*is_row_contiguous=*/false); /*is_row_contiguous=*/false);
LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(), LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
mapping_scheme.GetThreadsPerBlock()); mapping_scheme.GetThreadsPerBlock());
llvm::Type* index_type = llvm::Type* index_type =
GetIndexTypeForKernelFromMlir(op, launch_dimensions.launch_bound(), &b_); GetIndexTypeForKernel(hlo, launch_dimensions.launch_bound(), &b_);
std::vector<IrArray> param_arrays; std::vector<IrArray> param_arrays;
// For each tiled parameter, cast its input IrArray to the corresponding // For each tiled parameter, cast its input IrArray to the corresponding
// reduced shape and keep the reduced shape live during IR emission. // reduced shape and keep the reduced shape live during IR emission.
std::vector<IrArray> param_in_reduced_shape_arrays; std::vector<IrArray> param_in_reduced_shape_arrays;
std::vector<llvm::Value*> param_shmem_buffers(context.operand_shapes.size(), std::vector<llvm::Value*> param_shmem_buffers(hlo->operand_count(), nullptr);
nullptr);
auto get_shared_memory_buffer = [&](llvm::Type* elem_ty, auto get_shared_memory_buffer = [&](llvm::Type* elem_ty,
absl::string_view buffer_name) { absl::string_view buffer_name) {
@ -3911,18 +3846,20 @@ void IrEmitterUnnested::EmitHlo021Tile(
buffer_type, buffer_name); buffer_type, buffer_name);
}; };
for (int64 id = 0; id < context.operand_shapes.size(); id++) { for (int64 id = 0; id < hlo->operand_count(); id++) {
const Shape& param_shape = context.operand_shapes[id]; const HloInstruction* param = hlo->operand(id);
param_arrays.push_back(operand_arrays[id]); param_arrays.push_back(GetIrArray(*param, *hlo));
if (absl::c_linear_search(tiled_param_ids, id)) { if (absl::c_linear_search(tiled_param_ids, id)) {
param_shmem_buffers[id] = get_shared_memory_buffer( param_shmem_buffers[id] =
llvm_ir::PrimitiveTypeToIrType(param_shape.element_type(), module_), get_shared_memory_buffer(llvm_ir::PrimitiveTypeToIrType(
IrName(name, StrCat("tile", id))); param->shape().element_type(), module_),
IrName(hlo, StrCat("tile", id)));
VLOG(3) << "Added shmem buffer for parameter " << id << ": " VLOG(3) << "Added shmem buffer for parameter " << id << ": "
<< llvm_ir::DumpToString(*param_shmem_buffers[id]); << llvm_ir::DumpToString(*param_shmem_buffers[id]);
Shape reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout( Shape reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
param_shape.element_type(), Permute({0, 2, 1}, reduced_output_dims)); param->shape().element_type(),
Permute({0, 2, 1}, reduced_output_dims));
param_in_reduced_shape_arrays.push_back( param_in_reduced_shape_arrays.push_back(
param_arrays[id].CastToShape(reduced_shape, &b_)); param_arrays[id].CastToShape(reduced_shape, &b_));
} else { } else {
@ -3933,18 +3870,13 @@ void IrEmitterUnnested::EmitHlo021Tile(
EmitElementFunction element_generator = EmitElementFunction element_generator =
[&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc, [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
llvm::Value* x_loc, int64 x_iter_num) { llvm::Value* x_loc, int64 x_iter_num) {
if (auto copy = mlir::dyn_cast<mlir::lmhlo::CopyOp>(op)) { if (hlo->opcode() == HloOpcode::kCopy) {
CHECK_EQ(1, context.output_shapes.size()); EmitTileElementForCopy(hlo, index, mapping_scheme, y_loc, x_loc,
EmitTileElementForCopy(context.output_shapes[0], output_arrays[0],
index, mapping_scheme, y_loc, x_loc,
param_shmem_buffers);
} else if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
EmitTileElementForFusion(fusion, operand_arrays, output_arrays, index,
mapping_scheme, y_loc, x_loc,
param_shmem_buffers); param_shmem_buffers);
} else { } else {
op->dump(); CHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
LOG(FATAL) << "Unexpected op type"; EmitTileElementForFusion(hlo, index, mapping_scheme, y_loc, x_loc,
param_shmem_buffers);
} }
}; };
@ -3971,9 +3903,9 @@ void IrEmitterUnnested::EmitHlo021Tile(
llvm::Value* x_loc, int64 /*x_iter_num*/) { llvm::Value* x_loc, int64 /*x_iter_num*/) {
for (int64 id : tiled_param_ids) { for (int64 id : tiled_param_ids) {
IrArray& input_in_logical_shape = IrArray& input_in_logical_shape =
param_in_reduced_shape_arrays.at(id); param_in_reduced_shape_arrays[id];
llvm::Value* shmem_buffer = param_shmem_buffers.at(id); llvm::Value* shmem_buffer = param_shmem_buffers[id];
llvm::Value* zero = llvm::Value* zero =
llvm::ConstantInt::get(index_type, 0); llvm::ConstantInt::get(index_type, 0);
// TODO(jlebar): Add AA metadata to this store. Tile // TODO(jlebar): Add AA metadata to this store. Tile
@ -4010,11 +3942,10 @@ void IrEmitterUnnested::EmitHlo021Tile(
// the hopes of reducing register pressure, since we touch // the hopes of reducing register pressure, since we touch
// threadIdx.x and blockIdx.x at the beginning of the kernel // threadIdx.x and blockIdx.x at the beginning of the kernel
// *anyway*. // *anyway*.
if (output_arrays.size() > 1) { if (hlo->IsMultiOutputFusion()) {
KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] { KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] {
llvm_ir::EmitTuple(output_arrays.back(), llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo),
output_arrays.subspan(0, output_arrays.size() - 1), ConstructIrArrayForOutputs(*hlo), &b_);
&b_);
}); });
} }
@ -4024,7 +3955,6 @@ void IrEmitterUnnested::EmitHlo021Tile(
} }
namespace { namespace {
// A recursive function to inspect the users of a parameter to determine // A recursive function to inspect the users of a parameter to determine
// whether it's safe for a parameter to participate in a shared-memory // whether it's safe for a parameter to participate in a shared-memory
// transpose. // transpose.
@ -4063,29 +3993,14 @@ namespace {
// a reduce operations. In this case, the above description on "output" apply // a reduce operations. In this case, the above description on "output" apply
// to the result of such a use-chain, which provides the input to the reduce // to the result of such a use-chain, which provides the input to the reduce
// operation. // operation.
bool IsInstructionSafeForShmemTranspose(mlir::Operation* op) { bool IsInstructionSafeForShmemTranspose(const HloInstruction* hlo) {
if (mlir::isa<mlir::TensorStoreOp>(op)) { if (hlo->IsElementwise()) {
return true; return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
return IsInstructionSafeForShmemTranspose(user);
});
} }
HloOpcode opcode; switch (hlo->opcode()) {
if (mlir::isa<mlir::TensorLoadOp>(op)) {
opcode = HloOpcode::kParameter;
} else {
opcode = *mlir::MhloToHloOpcode(op);
}
if (HloInstruction::IsOpElementwise(opcode)) {
for (mlir::Value v : op->getResults()) {
for (mlir::OpOperand use : v.getUsers()) {
if (!IsInstructionSafeForShmemTranspose(use.getOwner())) {
return false;
}
}
}
return true;
}
switch (opcode) {
// Non-elementwise instructions that don't cause the shmem transpose // Non-elementwise instructions that don't cause the shmem transpose
// to be unsafe, including the instructions that don't currently fuse. // to be unsafe, including the instructions that don't currently fuse.
case HloOpcode::kGetDimensionSize: case HloOpcode::kGetDimensionSize:
@ -4097,14 +4012,9 @@ bool IsInstructionSafeForShmemTranspose(mlir::Operation* op) {
case HloOpcode::kParameter: case HloOpcode::kParameter:
case HloOpcode::kTuple: case HloOpcode::kTuple:
case HloOpcode::kTupleSelect: case HloOpcode::kTupleSelect:
for (mlir::Value v : op->getResults()) { return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
for (mlir::OpOperand use : v.getUsers()) { return IsInstructionSafeForShmemTranspose(user);
if (!IsInstructionSafeForShmemTranspose(use.getOwner())) { });
return false;
}
}
}
return true;
default: default:
return false; return false;
@ -4123,21 +4033,15 @@ bool IsInstructionSafeForShmemTranspose(mlir::Operation* op) {
// preloaded tile. We inspect all the transitive users of the input parameter // preloaded tile. We inspect all the transitive users of the input parameter
// up to the fusion root instruction to see if we can find any instruction // up to the fusion root instruction to see if we can find any instruction
// that can make preloading the input tile unsafe. // that can make preloading the input tile unsafe.
std::vector<int64> FilterInputsForShmemTranspose(mlir::lmhlo::FusionOp fusion, std::vector<int64> FilterInputsForShmemTranspose(const HloInstruction* fusion,
std::vector<int64> input_ids) { std::vector<int64> input_ids) {
std::vector<mlir::Value> params;
fusion.region().walk([&](mlir::TensorLoadOp load) {
CHECK(load.memref().getParentRegion() != &fusion.region())
<< "TensorLoadOp shows should be only expected for accessing captured "
"memrefs.";
params.push_back(load);
});
std::vector<int64> filtered_input_ids; std::vector<int64> filtered_input_ids;
for (int64 input_id : input_ids) { for (int64 i = 0; i < input_ids.size(); ++i) {
mlir::Value input = params.at(input_id); const HloInstruction* input = fusion->fused_parameter(input_ids[i]);
if (IsInstructionSafeForShmemTranspose(input.getDefiningOp())) { if (IsInstructionSafeForShmemTranspose(input)) {
filtered_input_ids.push_back(input_id); filtered_input_ids.push_back(input_ids[i]);
} else {
VLOG(10) << "Input not safe for shmem transpose " << input->ToString();
} }
} }
return filtered_input_ids; return filtered_input_ids;
@ -4145,23 +4049,24 @@ std::vector<int64> FilterInputsForShmemTranspose(mlir::lmhlo::FusionOp fusion,
} // namespace } // namespace
StatusOr<bool> IrEmitterUnnested::CheckAndEmitHloWithTile021( bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
MlirEmitterInput input) { HloOpcode opcode = hlo->opcode();
CHECK(mlir::isa<mlir::lmhlo::FusionOp>(input.op) ||
mlir::isa<mlir::lmhlo::CopyOp>(input.op));
MlirEmitterContext context; CHECK(hlo->IsLoopFusion() || opcode == HloOpcode::kCopy);
context.SetOperation(input.op);
const Shape& output_shape = hlo->IsMultiOutputFusion()
? ShapeUtil::GetSubshape(hlo->shape(), {0})
: hlo->shape();
// If the output_shape is reduced to 021 shape, find all the parameters of // If the output_shape is reduced to 021 shape, find all the parameters of
// the HLO that are in the corresponding 012 shape. // the HLO that are in the corresponding 012 shape.
std::vector<int64> params_012; std::vector<int64> params_012;
optional<std::vector<int64>> reduced_dims_021; optional<std::vector<int64>> reduced_dims_021;
for (int64 operand_idx = 0; operand_idx < context.operand_shapes.size(); for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
++operand_idx) { ++operand_idx) {
const Shape& operand_shape = context.operand_shapes[operand_idx]; HloInstruction* operand = hlo->mutable_operand(operand_idx);
auto find_transpose_result = auto find_transpose_result =
ShapeUtil::FindTranspose021(operand_shape, context.output_shapes[0]); ShapeUtil::FindTranspose021(operand->shape(), output_shape);
if (!find_transpose_result.has_value()) { if (!find_transpose_result.has_value()) {
continue; continue;
} }
@ -4186,8 +4091,8 @@ StatusOr<bool> IrEmitterUnnested::CheckAndEmitHloWithTile021(
return false; return false;
} }
if (auto fusion_op = mlir::dyn_cast<mlir::lmhlo::FusionOp>(input.op)) { if (opcode == HloOpcode::kFusion) {
params_012 = FilterInputsForShmemTranspose(fusion_op, params_012); params_012 = FilterInputsForShmemTranspose(hlo, params_012);
if (params_012.empty()) { if (params_012.empty()) {
return false; return false;
} }
@ -4215,10 +4120,10 @@ StatusOr<bool> IrEmitterUnnested::CheckAndEmitHloWithTile021(
constexpr int64 kShmemPerCore = 48 * 1024; constexpr int64 kShmemPerCore = 48 * 1024;
int64 shmem_used = 0; int64 shmem_used = 0;
for (int64 i = 0; i < params_012.size(); ++i) { for (int64 i = 0; i < params_012.size(); ++i) {
const Shape& operand_shape = context.operand_shapes[params_012[i]]; const HloInstruction* operand = hlo->operand(params_012[i]);
shmem_used += shmem_used +=
32 * 33 * 32 * 33 *
ShapeUtil::ByteSizeOfPrimitiveType(operand_shape.element_type()); ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());
if (kMinBlocksPerCore * shmem_used > kShmemPerCore) { if (kMinBlocksPerCore * shmem_used > kShmemPerCore) {
// Erase this element and everything after it from params_012. // Erase this element and everything after it from params_012.
@ -4231,15 +4136,10 @@ StatusOr<bool> IrEmitterUnnested::CheckAndEmitHloWithTile021(
return false; return false;
} }
std::vector<llvm_ir::IrArray> ir_arrays; VLOG(3) << "EmitHlo021Tile Emitting hlo tile 0-2-1" << hlo->ToString();
TF_ASSIGN_OR_RETURN(std::unique_ptr<KernelThunk> kernel_thunk, std::unique_ptr<KernelThunk> kernel_thunk =
BuildKernelThunkForMlir(input.op, input.thunk_info, BuildKernelThunk(hlo, /*implements_whole_instruction=*/true);
input.extra_slice, &ir_arrays)); EmitHlo021Tile(hlo, kernel_thunk.get(), *reduced_dims_021, params_012);
EmitHlo021Tile(
input.op, kernel_thunk.get(), context,
absl::MakeSpan(ir_arrays).subspan(0, context.operand_shapes.size()),
absl::MakeSpan(ir_arrays).subspan(context.operand_shapes.size()),
*reduced_dims_021, params_012);
AddThunkToThunkSequence(std::move(kernel_thunk)); AddThunkToThunkSequence(std::move(kernel_thunk));
return true; return true;
} }

View File

@ -73,7 +73,6 @@ struct MlirEmitterInput {
}; };
// Convenience struct that contains useful data structures in MLIR emitter. // Convenience struct that contains useful data structures in MLIR emitter.
// Not all fields may be filled. It's entiredly dependent on the uses.
struct MlirEmitterContext { struct MlirEmitterContext {
void SetOperation(mlir::Operation* op); void SetOperation(mlir::Operation* op);
@ -157,14 +156,11 @@ class IrEmitterUnnested : public IrEmitter,
} }
Status DefaultAction(HloInstruction* hlo) override; Status DefaultAction(HloInstruction* hlo) override;
Status DefaultActionForMlir(MlirEmitterInput input);
// IrEmitterUnnested handles the following instructions differently from // IrEmitterUnnested handles the following instructions differently from
// IrEmitter. It also mixes in some special handling for custom kernels // IrEmitter. It also mixes in some special handling for custom kernels
// via the ThunkEmitter. // via the ThunkEmitter.
Status HandleCopy(HloInstruction* copy) override; Status HandleCopy(HloInstruction* copy) override;
Status EmitCopyForMlir(MlirEmitterInput input);
Status HandleConditional(HloInstruction* conditional) override; Status HandleConditional(HloInstruction* conditional) override;
Status HandleConvolution(HloInstruction* convolution) override; Status HandleConvolution(HloInstruction* convolution) override;
Status HandleCustomCall(HloInstruction* custom_call) override; Status HandleCustomCall(HloInstruction* custom_call) override;
@ -471,15 +467,12 @@ class IrEmitterUnnested : public IrEmitter,
// Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel // Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel
// for the hlo instruction. // for the hlo instruction.
StatusOr<bool> CheckAndEmitHloWithTile021(MlirEmitterInput input); bool CheckAndEmitHloWithTile021(HloInstruction* hlo);
// Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and
// sets the corresponding launch dimensions. This is a helper to support // sets the corresponding launch dimensions. This is a helper to support
// the implementation of CheckAndEmitHloWithTile021. // the implementation of CheckAndEmitHloWithTile021.
void EmitHlo021Tile(mlir::Operation* op, Thunk* kernel_thunk, void EmitHlo021Tile(HloInstruction* hlo, Thunk* kernel_thunk,
const MlirEmitterContext& context,
absl::Span<const llvm_ir::IrArray> operand_arrays,
absl::Span<const llvm_ir::IrArray> output_arrays,
absl::Span<const int64> reduced_output_dims, absl::Span<const int64> reduced_output_dims,
absl::Span<const int64> tiled_param_ids); absl::Span<const int64> tiled_param_ids);
@ -534,8 +527,7 @@ class IrEmitterUnnested : public IrEmitter,
// y_loc: The y coordinate within a tile. // y_loc: The y coordinate within a tile.
// x_loc: The x coordinate within a tile. // x_loc: The x coordinate within a tile.
void EmitTileElementForCopy( void EmitTileElementForCopy(
const Shape& output_shape, const llvm_ir::IrArray& ir_array, HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
const llvm_ir::IrArray::Index& index,
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc, const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers); llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers);
@ -544,10 +536,7 @@ class IrEmitterUnnested : public IrEmitter,
// y_loc: The y coordinate within a tile. // y_loc: The y coordinate within a tile.
// x_loc: The x coordinate within a tile. // x_loc: The x coordinate within a tile.
void EmitTileElementForFusion( void EmitTileElementForFusion(
mlir::lmhlo::FusionOp fusion, HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
absl::Span<const llvm_ir::IrArray> operand_arrays,
absl::Span<const llvm_ir::IrArray> output_arrays,
const llvm_ir::IrArray::Index& index,
const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc, const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers); llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers);

View File

@ -2,12 +2,12 @@
// CHECK-LABEL: entry: // CHECK-LABEL: entry:
// CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0 // CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [100 x [200 x [300 x float]]]* // CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [200 x [100 x [300 x float]]]*
// CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0 // CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0
// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [200 x [100 x [300 x float]]]* // CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [100 x [200 x [300 x float]]]*
// CHECK: %[[VAL_6:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2 // CHECK: %[[VAL_6:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
// CHECK: %[[VAL_7:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3 // CHECK: %[[VAL_7:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
// CHECK: %[[VAL_8:.*]] = mul nuw nsw i32 %[[VAL_6]], 256 // CHECK: %[[VAL_8:.*]] = mul nuw nsw i32 %[[VAL_6]], 128
// CHECK: %[[VAL_9:.*]] = add nuw nsw i32 %[[VAL_8]], %[[VAL_7]] // CHECK: %[[VAL_9:.*]] = add nuw nsw i32 %[[VAL_8]], %[[VAL_7]]
// CHECK: %[[VAL_10:.*]] = icmp ult i32 %[[VAL_9]], {{.*}} // CHECK: %[[VAL_10:.*]] = icmp ult i32 %[[VAL_9]], {{.*}}
// CHECK: call void @llvm.assume(i1 %[[VAL_10]]) // CHECK: call void @llvm.assume(i1 %[[VAL_10]])
@ -40,24 +40,24 @@
// CHECK: b.in_bounds-after: ; preds = %[[VAL_36]], %[[VAL_38:.*]] // CHECK: b.in_bounds-after: ; preds = %[[VAL_36]], %[[VAL_38:.*]]
// CHECK: ret void // CHECK: ret void
// CHECK: b.in_bounds-true: ; preds = %[[VAL_38]] // CHECK: b.in_bounds-true: ; preds = %[[VAL_38]]
// CHECK: %[[VAL_39:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_2]], i32 0, i32 %[[VAL_15]], i32 %[[VAL_16]], i32 %[[VAL_13]] // CHECK: %[[VAL_39:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_5]], i32 0, i32 %[[VAL_15]], i32 %[[VAL_16]], i32 %[[VAL_13]]
// CHECK: %[[VAL_40:.*]] = load float, float* %[[VAL_39]], align 4, !invariant.load !4 // CHECK: %[[VAL_40:.*]] = load float, float* %[[VAL_39]], align 4, !invariant.load !4
// CHECK: %[[VAL_41:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_5]] to float* // CHECK: %[[VAL_41:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_2]] to float*
// CHECK: %[[VAL_42:.*]] = getelementptr inbounds float, float* %[[VAL_41]], i32 %[[VAL_11]] // CHECK: %[[VAL_42:.*]] = getelementptr inbounds float, float* %[[VAL_41]], i32 %[[VAL_11]]
// CHECK: store float %[[VAL_40]], float* %[[VAL_42]], align 4 // CHECK: store float %[[VAL_40]], float* %[[VAL_42]], align 4
// CHECK: %[[VAL_43:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_2]], i32 0, i32 %[[VAL_21]], i32 %[[VAL_22]], i32 %[[VAL_19]] // CHECK: %[[VAL_43:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_5]], i32 0, i32 %[[VAL_21]], i32 %[[VAL_22]], i32 %[[VAL_19]]
// CHECK: %[[VAL_44:.*]] = load float, float* %[[VAL_43]], align 4, !invariant.load !4 // CHECK: %[[VAL_44:.*]] = load float, float* %[[VAL_43]], align 4, !invariant.load !4
// CHECK: %[[VAL_45:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_5]] to float* // CHECK: %[[VAL_45:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_2]] to float*
// CHECK: %[[VAL_46:.*]] = getelementptr inbounds float, float* %[[VAL_45]], i32 %[[VAL_17]] // CHECK: %[[VAL_46:.*]] = getelementptr inbounds float, float* %[[VAL_45]], i32 %[[VAL_17]]
// CHECK: store float %[[VAL_44]], float* %[[VAL_46]], align 4 // CHECK: store float %[[VAL_44]], float* %[[VAL_46]], align 4
// CHECK: %[[VAL_47:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_2]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_25]] // CHECK: %[[VAL_47:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_5]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_25]]
// CHECK: %[[VAL_48:.*]] = load float, float* %[[VAL_47]], align 4, !invariant.load !4 // CHECK: %[[VAL_48:.*]] = load float, float* %[[VAL_47]], align 4, !invariant.load !4
// CHECK: %[[VAL_49:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_5]] to float* // CHECK: %[[VAL_49:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_2]] to float*
// CHECK: %[[VAL_50:.*]] = getelementptr inbounds float, float* %[[VAL_49]], i32 %[[VAL_23]] // CHECK: %[[VAL_50:.*]] = getelementptr inbounds float, float* %[[VAL_49]], i32 %[[VAL_23]]
// CHECK: store float %[[VAL_48]], float* %[[VAL_50]], align 4 // CHECK: store float %[[VAL_48]], float* %[[VAL_50]], align 4
// CHECK: %[[VAL_51:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_2]], i32 0, i32 %[[VAL_33]], i32 %[[VAL_34]], i32 %[[VAL_31]] // CHECK: %[[VAL_51:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_5]], i32 0, i32 %[[VAL_33]], i32 %[[VAL_34]], i32 %[[VAL_31]]
// CHECK: %[[VAL_52:.*]] = load float, float* %[[VAL_51]], align 4, !invariant.load !4 // CHECK: %[[VAL_52:.*]] = load float, float* %[[VAL_51]], align 4, !invariant.load !4
// CHECK: %[[VAL_53:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_5]] to float* // CHECK: %[[VAL_53:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_2]] to float*
// CHECK: %[[VAL_54:.*]] = getelementptr inbounds float, float* %[[VAL_53]], i32 %[[VAL_29]] // CHECK: %[[VAL_54:.*]] = getelementptr inbounds float, float* %[[VAL_53]], i32 %[[VAL_29]]
// CHECK: store float %[[VAL_52]], float* %[[VAL_54]], align 4 // CHECK: store float %[[VAL_52]], float* %[[VAL_54]], align 4
// CHECK: br label %[[VAL_37]] // CHECK: br label %[[VAL_37]]

View File

@ -4,10 +4,10 @@
// CHECK: %[[VAL_0:.*]] = alloca i32, align 4 // CHECK: %[[VAL_0:.*]] = alloca i32, align 4
// CHECK: %[[VAL_1:.*]] = alloca i32, align 4 // CHECK: %[[VAL_1:.*]] = alloca i32, align 4
// CHECK: %[[VAL_2:.*]] = getelementptr inbounds i8, i8* %[[VAL_3:.*]], i64 0 // CHECK: %[[VAL_2:.*]] = getelementptr inbounds i8, i8* %[[VAL_3:.*]], i64 0
// CHECK: %[[VAL_4:.*]] = bitcast i8* %[[VAL_2]] to [100 x [200 x float]]* // CHECK: %[[VAL_4:.*]] = bitcast i8* %[[VAL_2]] to [200 x [100 x float]]*
// CHECK: %[[VAL_5:.*]] = getelementptr inbounds i8, i8* %[[VAL_6:.*]], i64 0 // CHECK: %[[VAL_5:.*]] = getelementptr inbounds i8, i8* %[[VAL_6:.*]], i64 0
// CHECK: %[[VAL_7:.*]] = bitcast i8* %[[VAL_5]] to [200 x [100 x float]]* // CHECK: %[[VAL_7:.*]] = bitcast i8* %[[VAL_5]] to [100 x [200 x float]]*
// CHECK: %[[VAL_8:.*]] = bitcast [100 x [200 x float]]* %[[VAL_4]] to [1 x [100 x [200 x float]]]* // CHECK: %[[VAL_8:.*]] = bitcast [100 x [200 x float]]* %[[VAL_7]] to [1 x [100 x [200 x float]]]*
// CHECK: %[[VAL_9:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2 // CHECK: %[[VAL_9:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2
// CHECK: %[[VAL_10:.*]] = urem i32 %[[VAL_9]], 32 // CHECK: %[[VAL_10:.*]] = urem i32 %[[VAL_9]], 32
// CHECK: %[[VAL_11:.*]] = udiv i32 %[[VAL_9]], 32 // CHECK: %[[VAL_11:.*]] = udiv i32 %[[VAL_9]], 32
@ -88,7 +88,7 @@
// CHECK: output_x_in_tile-true: ; preds = %[[VAL_59]] // CHECK: output_x_in_tile-true: ; preds = %[[VAL_59]]
// CHECK: %[[VAL_72:.*]] = getelementptr [32 x [33 x float]], [32 x [33 x float]] addrspace(3)* @b.tile0, i64 0, i32 %[[VAL_65]], i32 %[[VAL_63]] // CHECK: %[[VAL_72:.*]] = getelementptr [32 x [33 x float]], [32 x [33 x float]] addrspace(3)* @b.tile0, i64 0, i32 %[[VAL_65]], i32 %[[VAL_63]]
// CHECK: %[[VAL_73:.*]] = load float, float addrspace(3)* %[[VAL_72]], align 4 // CHECK: %[[VAL_73:.*]] = load float, float addrspace(3)* %[[VAL_72]], align 4
// CHECK: %[[VAL_74:.*]] = bitcast [200 x [100 x float]]* %[[VAL_7]] to [1 x [200 x [100 x float]]]* // CHECK: %[[VAL_74:.*]] = bitcast [200 x [100 x float]]* %[[VAL_4]] to [1 x [200 x [100 x float]]]*
// CHECK: %[[VAL_75:.*]] = getelementptr inbounds [1 x [200 x [100 x float]]], [1 x [200 x [100 x float]]]* %[[VAL_74]], i32 0, i32 0, i32 %[[VAL_64]], i32 %[[VAL_66]] // CHECK: %[[VAL_75:.*]] = getelementptr inbounds [1 x [200 x [100 x float]]], [1 x [200 x [100 x float]]]* %[[VAL_74]], i32 0, i32 0, i32 %[[VAL_64]], i32 %[[VAL_66]]
// CHECK: store float %[[VAL_73]], float* %[[VAL_75]], align 4 // CHECK: store float %[[VAL_73]], float* %[[VAL_75]], align 4
// CHECK: br label %[[VAL_55]] // CHECK: br label %[[VAL_55]]