[XLA/GPU] Migrate Copy emitters to take LMHLO.

PiperOrigin-RevId: 342215820
Change-Id: Ibf9a16c30a5f62f8177bed98b000fe59e5192c93
parent fce727c889
commit fbccb57c1f
@@ -268,7 +268,6 @@ cc_library(
         "//tensorflow/compiler/mlir:name_utils",
         "//tensorflow/compiler/mlir/hlo",
         "//tensorflow/compiler/mlir/hlo:lhlo",
-        "//tensorflow/compiler/mlir/xla:hlo_module_importer",
         "//tensorflow/compiler/mlir/xla:hlo_utils",
         "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla",
         "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
@@ -44,7 +44,6 @@ limitations under the License.
 #include "mlir/IR/StandardTypes.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
 #include "tensorflow/compiler/mlir/utils/name_utils.h"
-#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
@@ -547,39 +546,6 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
   return IrEmitter::DefaultAction(hlo);
 }

-Status IrEmitterUnnested::DefaultActionForMlir(MlirEmitterInput input) {
-  // Replace unnested op with a fused nested op.
-  //
-  // TODO(timshen): Ultimately this should be a pass. It's currently not a pass,
-  // because we don't have a fully functioning LMHLO graph yet.
-
-  mlir::Location loc = input.op->getLoc();
-  mlir::lmhlo::FusionOp fusion = nullptr;
-  Shape output_shape;
-  if (auto copy = mlir::dyn_cast<mlir::lmhlo::CopyOp>(input.op)) {
-    fusion = mlir::OpBuilder(copy).create<mlir::lmhlo::FusionOp>(
-        loc, llvm::ArrayRef<mlir::NamedAttribute>());
-    copy.getOperation()->moveBefore(&fusion.region().front().back());
-    mlir::OpBuilder b(copy);
-    auto operand = b.create<mlir::TensorLoadOp>(loc, copy.operand());
-    HloFunctionImporter::SetLayoutForMlir(
-        operand, TypeToShape(copy.operand().getType()));
-    auto fused_copy = b.create<mlir::mhlo::CopyOp>(loc, operand);
-    output_shape = TypeToShape(copy.output().getType());
-    HloFunctionImporter::SetLayoutForMlir(fused_copy, output_shape);
-    b.create<mlir::TensorStoreOp>(loc, fused_copy, copy.output());
-    copy.getOperation()->erase();
-  } else {
-    input.op->dump();
-    LOG(FATAL) << "Unimplemented default action for mlir op";
-  }
-  input.op = fusion;
-  auto ret = EmitLoopFusionFromMlir(
-      input, output_shape,
-      ComputeMaxUnrollFactor(output_shape, hlo_module_config_));
-  return ret;
-}
-
 Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
   AddThunkToThunkSequence(
       BuildKernelThunk(dot, /*implements_whole_instruction=*/true));
@@ -1184,10 +1150,7 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
     CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop)
         << ": " << fusion->ToString();

-    TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(fusion));
-    TF_ASSIGN_OR_RETURN(const bool matched_021,
-                        CheckAndEmitHloWithTile021(input));
-    if (matched_021) {
+    if (CheckAndEmitHloWithTile021(fusion)) {
       return Status::OK();
     }

@@ -1196,46 +1159,35 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
     unroll_factor = ComputeMaxUnrollFactor(fusion);
   }

+  TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(fusion));
   return EmitLoopFusionFromMlir(input, fusion->shape(), unroll_factor);
 }

 Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
-  TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(copy));
-  return EmitCopyForMlir(input);
-}
-
-Status IrEmitterUnnested::EmitCopyForMlir(MlirEmitterInput input) {
-  auto copy = mlir::cast<mlir::lmhlo::CopyOp>(input.op);
-  auto operand_shape = TypeToShape(copy.operand().getType());
-  auto output_shape = TypeToShape(copy.output().getType());
-
-  CHECK(ShapeUtil::Compatible(operand_shape, output_shape));
-  absl::Span<const BufferAllocation> allocations(
-      ir_emitter_context_->buffer_assignment().Allocations());
-
-  auto maybe_slice = GetAllocationSliceForMlir(copy.operand(), allocations);
-  if (LayoutUtil::Equal(operand_shape.layout(), output_shape.layout()) &&
-      maybe_slice.ok()) {
+  CHECK(ShapeUtil::Compatible(copy->operand(0)->shape(), copy->shape()));
+  const BufferAssignment& buffer_assignment =
+      ir_emitter_context_->buffer_assignment();
+  if (LayoutUtil::Equal(copy->operand(0)->shape().layout(),
+                        copy->shape().layout()) &&
+      buffer_assignment.GetUniqueTopLevelSlice(copy->operand(0)).ok()) {
     // Copy the operand into the output if it's not the same buffer already.
-    auto operand_buffer = *maybe_slice;
-    auto destination_buffer =
-        *GetAllocationSliceForMlir(copy.output(), allocations);
+    auto operand_buffer = GetAllocationSlice(*copy->operand(0));
+    auto destination_buffer = GetAllocationSlice(*copy);
     if (operand_buffer != destination_buffer) {
       AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
-          input.thunk_info,
+          GetThunkInfo(copy),
          /*source_address=*/operand_buffer,
          /*destination_buffer=*/destination_buffer,
          /*mem_size=*/
-          ByteSizeOf(operand_shape)));
+          ByteSizeOf(copy->operand(0)->shape())));
     }
     return Status::OK();
   }
-  TF_ASSIGN_OR_RETURN(bool matched_021, CheckAndEmitHloWithTile021(input));
-  if (matched_021) {
+  if (CheckAndEmitHloWithTile021(copy)) {
     return Status::OK();
   }

-  return DefaultActionForMlir(input);
+  return IrEmitter::HandleCopy(copy);
 }

 Status IrEmitterUnnested::EmitExtraOutputsForReduce(
@@ -2552,15 +2504,11 @@ static void GetFusionOperandsAndOutputs(mlir::lmhlo::FusionOp fusion,
                                         std::vector<mlir::Value>* operands,
                                         std::vector<mlir::Value>* outputs) {
   fusion.region().walk([&](mlir::TensorLoadOp load) {
-    CHECK(load.memref().getParentRegion() != &fusion.region())
-        << "TensorLoadOp shows should be only expected for accessing captured "
-           "memrefs.";
+    CHECK(load.memref().getParentRegion() != &fusion.region());
     operands->push_back(load.memref());
   });
   fusion.region().walk([&](mlir::TensorStoreOp store) {
-    CHECK(store.memref().getParentRegion() != &fusion.region())
-        << "TensorStoreOp shows should be only expected for accessing captured "
-           "memrefs.";
+    CHECK(store.memref().getParentRegion() != &fusion.region());
     outputs->push_back(store.memref());
   });
 }
@@ -3170,16 +3118,16 @@ void IrEmitterUnnested::EmitTile(
 // dimensions that play the same role in the transpose.
 // mapping_scheme: Kernel mapping scheme specifying the tiling
 void IrEmitterUnnested::EmitTileElementForCopy(
-    const Shape& output_shape, const llvm_ir::IrArray& output_array,
-    const llvm_ir::IrArray::Index& index,
+    HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
     const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
     llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) {
   // TODO(jlebar): Add AA metadata to this load.
   llvm::Instruction* load_from_shmem_buffer =
       Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x_loc, y_loc}),
            "output_element");
+  llvm_ir::IrArray output_array = GetIrArray(*hlo, *hlo);
   Shape output_reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
-      output_shape.element_type(), mapping_scheme.GetDimsInElems());
+      hlo->shape().element_type(), mapping_scheme.GetDimsInElems());
   // When the output_reduced_shape is a 0-2-1 transpose of the input shape,
   // the 0-2-1 transpose is achieved through EmitWriteArrayElement.
   output_array.CastToShape(output_reduced_shape, &b_)
@@ -3217,19 +3165,14 @@ static IrArray::Index GetUnnormalizedIndex(
 // the same role in the transpose.
 // kernel_info: Other information to support the kernel code generation.
 void IrEmitterUnnested::EmitTileElementForFusion(
-    mlir::lmhlo::FusionOp fusion,
-    absl::Span<const llvm_ir::IrArray> operand_arrays,
-    absl::Span<const llvm_ir::IrArray> output_arrays,
-    const llvm_ir::IrArray::Index& index,
+    HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
     const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
     llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) {
-  const HloComputation* fused_computation =
-      *GetOrCreateSubComputationFromRegion(&fusion.region(),
-                                           /*is_fusion*/ true);
+  std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
   GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
                                      GetNestedComputer());
   FusedIrEmitter fused_emitter(&elem_emitter);
-  for (int i = 0; i < operand_arrays.size(); i++) {
+  for (int i = 0; i < hlo->operand_count(); i++) {
     llvm_ir::ElementGenerator gen;
     if (llvm::Value* param_tile_buffer = param_shmem_buffers[i]) {
       gen = [this, param_tile_buffer, x_loc,
@@ -3246,24 +3189,23 @@ void IrEmitterUnnested::EmitTileElementForFusion(
                  "tiled_buffer");
       };
     } else {
-      auto array = operand_arrays[i];
-      gen = [this, array](llvm_ir::IrArray::Index index) {
-        return array.EmitReadArrayElement(index, &b_);
+      const HloInstruction* operand = hlo->operand(i);
+      gen = [this, operand, hlo](llvm_ir::IrArray::Index index) {
+        return GetIrArray(*operand, *hlo).EmitReadArrayElement(index, &b_);
       };
     }
-    fused_emitter.BindGenerator(fused_computation->parameter_instruction(i),
-                                std::move(gen));
+    fused_emitter.BindGenerator(hlo->fused_parameter(i), std::move(gen));
   }
   IrArray::Index untiled_index = GetUnnormalizedIndex(
       index, output_arrays[0].GetShape(), &b_, mapping_scheme);
   llvm_ir::ElementGenerator output_generator =
-      *fused_emitter.GetGenerator(fused_computation->root_instruction());
+      *fused_emitter.GetGenerator(hlo->fused_expression_root());
   llvm::Value* output_value = output_generator(untiled_index).ValueOrDie();
-  if (output_arrays.size() > 1) {
+  if (hlo->IsMultiOutputFusion()) {
     DCHECK(output_value->getType()->isStructTy());
     DCHECK_EQ(output_value->getType()->getStructNumElements(),
               output_arrays.size());
-    for (int64 i = 0; i < output_arrays.size() - 1; ++i) {
+    for (int64 i = 0; i < output_arrays.size(); ++i) {
       output_arrays[i].EmitWriteArrayElement(
           untiled_index, ExtractValue(output_value, i), &b_);
     }
@@ -3867,15 +3809,10 @@ llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() {
 // TODO(b/33320379): Here each block transposes 1 tile. It may be more
 // efficient to launch fewer blocks so each transposes many tiles.
 void IrEmitterUnnested::EmitHlo021Tile(
-    mlir::Operation* op, Thunk* kernel_thunk, const MlirEmitterContext& context,
-    absl::Span<const llvm_ir::IrArray> operand_arrays,
-    absl::Span<const llvm_ir::IrArray> output_arrays,
+    HloInstruction* hlo, Thunk* kernel_thunk,
     absl::Span<const int64> reduced_output_dims,
     absl::Span<const int64> tiled_param_ids) {
   constexpr int kNumRows = 4;
-
-  std::string name = mlir::GetNameFromLoc(op->getLoc());
-
   KernelMappingScheme mapping_scheme(reduced_output_dims,
                                      /*tile_sizes=*/{1, kWarpSize, kWarpSize},
                                      /*num_threads_y=*/kNumRows,
@@ -3885,16 +3822,14 @@ void IrEmitterUnnested::EmitHlo021Tile(
                                      /*is_row_contiguous=*/false);
   LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
                                      mapping_scheme.GetThreadsPerBlock());
-
   llvm::Type* index_type =
-      GetIndexTypeForKernelFromMlir(op, launch_dimensions.launch_bound(), &b_);
+      GetIndexTypeForKernel(hlo, launch_dimensions.launch_bound(), &b_);
   std::vector<IrArray> param_arrays;

   // For each tiled parameter, cast its input IrArray to the corresponding
   // reduced shape and keep the reduced shape live during IR emission.
   std::vector<IrArray> param_in_reduced_shape_arrays;
-  std::vector<llvm::Value*> param_shmem_buffers(context.operand_shapes.size(),
-                                                nullptr);
+  std::vector<llvm::Value*> param_shmem_buffers(hlo->operand_count(), nullptr);

   auto get_shared_memory_buffer = [&](llvm::Type* elem_ty,
                                       absl::string_view buffer_name) {
@@ -3911,18 +3846,20 @@ void IrEmitterUnnested::EmitHlo021Tile(
                                      buffer_type, buffer_name);
   };

-  for (int64 id = 0; id < context.operand_shapes.size(); id++) {
-    const Shape& param_shape = context.operand_shapes[id];
-    param_arrays.push_back(operand_arrays[id]);
+  for (int64 id = 0; id < hlo->operand_count(); id++) {
+    const HloInstruction* param = hlo->operand(id);
+    param_arrays.push_back(GetIrArray(*param, *hlo));

     if (absl::c_linear_search(tiled_param_ids, id)) {
-      param_shmem_buffers[id] = get_shared_memory_buffer(
-          llvm_ir::PrimitiveTypeToIrType(param_shape.element_type(), module_),
-          IrName(name, StrCat("tile", id)));
+      param_shmem_buffers[id] =
+          get_shared_memory_buffer(llvm_ir::PrimitiveTypeToIrType(
+                                       param->shape().element_type(), module_),
+                                   IrName(hlo, StrCat("tile", id)));
       VLOG(3) << "Added shmem buffer for parameter " << id << ": "
               << llvm_ir::DumpToString(*param_shmem_buffers[id]);
       Shape reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
-          param_shape.element_type(), Permute({0, 2, 1}, reduced_output_dims));
+          param->shape().element_type(),
+          Permute({0, 2, 1}, reduced_output_dims));
       param_in_reduced_shape_arrays.push_back(
           param_arrays[id].CastToShape(reduced_shape, &b_));
     } else {
@@ -3933,18 +3870,13 @@ void IrEmitterUnnested::EmitHlo021Tile(
   EmitElementFunction element_generator =
       [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
           llvm::Value* x_loc, int64 x_iter_num) {
-        if (auto copy = mlir::dyn_cast<mlir::lmhlo::CopyOp>(op)) {
-          CHECK_EQ(1, context.output_shapes.size());
-          EmitTileElementForCopy(context.output_shapes[0], output_arrays[0],
-                                 index, mapping_scheme, y_loc, x_loc,
+        if (hlo->opcode() == HloOpcode::kCopy) {
+          EmitTileElementForCopy(hlo, index, mapping_scheme, y_loc, x_loc,
                                  param_shmem_buffers);
-        } else if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
-          EmitTileElementForFusion(fusion, operand_arrays, output_arrays, index,
-                                   mapping_scheme, y_loc, x_loc,
-                                   param_shmem_buffers);
         } else {
-          op->dump();
-          LOG(FATAL) << "Unexpected op type";
+          CHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
+          EmitTileElementForFusion(hlo, index, mapping_scheme, y_loc, x_loc,
+                                   param_shmem_buffers);
         }
       };

@@ -3971,9 +3903,9 @@ void IrEmitterUnnested::EmitHlo021Tile(
           llvm::Value* x_loc, int64 /*x_iter_num*/) {
         for (int64 id : tiled_param_ids) {
           IrArray& input_in_logical_shape =
-              param_in_reduced_shape_arrays.at(id);
+              param_in_reduced_shape_arrays[id];

-          llvm::Value* shmem_buffer = param_shmem_buffers.at(id);
+          llvm::Value* shmem_buffer = param_shmem_buffers[id];
           llvm::Value* zero =
               llvm::ConstantInt::get(index_type, 0);
           // TODO(jlebar): Add AA metadata to this store. Tile
@@ -4010,11 +3942,10 @@ void IrEmitterUnnested::EmitHlo021Tile(
   // the hopes of reducing register pressure, since we touch
   // threadIdx.x and blockIdx.x at the beginning of the kernel
   // *anyway*.
-  if (output_arrays.size() > 1) {
+  if (hlo->IsMultiOutputFusion()) {
     KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] {
-      llvm_ir::EmitTuple(output_arrays.back(),
-                         output_arrays.subspan(0, output_arrays.size() - 1),
-                         &b_);
+      llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo),
+                         ConstructIrArrayForOutputs(*hlo), &b_);
     });
   }

@@ -4024,7 +3955,6 @@ void IrEmitterUnnested::EmitHlo021Tile(
 }

 namespace {
-
 // A recursive function to inspect the users of a parameter to determine
 // whether it's safe for a parameter to participate in a shared-memory
 // transpose.
@@ -4063,29 +3993,14 @@ namespace {
 // a reduce operations. In this case, the above description on "output" apply
 // to the result of such a use-chain, which provides the input to the reduce
 // operation.
-bool IsInstructionSafeForShmemTranspose(mlir::Operation* op) {
-  if (mlir::isa<mlir::TensorStoreOp>(op)) {
-    return true;
+bool IsInstructionSafeForShmemTranspose(const HloInstruction* hlo) {
+  if (hlo->IsElementwise()) {
+    return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
+      return IsInstructionSafeForShmemTranspose(user);
+    });
   }

-  HloOpcode opcode;
-  if (mlir::isa<mlir::TensorLoadOp>(op)) {
-    opcode = HloOpcode::kParameter;
-  } else {
-    opcode = *mlir::MhloToHloOpcode(op);
-  }
-  if (HloInstruction::IsOpElementwise(opcode)) {
-    for (mlir::Value v : op->getResults()) {
-      for (mlir::OpOperand use : v.getUsers()) {
-        if (!IsInstructionSafeForShmemTranspose(use.getOwner())) {
-          return false;
-        }
-      }
-    }
-    return true;
-  }
-
-  switch (opcode) {
+  switch (hlo->opcode()) {
     // Non-elementwise instructions that don't cause the shmem transpose
     // to be unsafe, including the instructions that don't currently fuse.
     case HloOpcode::kGetDimensionSize:
@@ -4097,14 +4012,9 @@ bool IsInstructionSafeForShmemTranspose(mlir::Operation* op) {
     case HloOpcode::kParameter:
     case HloOpcode::kTuple:
     case HloOpcode::kTupleSelect:
-      for (mlir::Value v : op->getResults()) {
-        for (mlir::OpOperand use : v.getUsers()) {
-          if (!IsInstructionSafeForShmemTranspose(use.getOwner())) {
-            return false;
-          }
-        }
-      }
-      return true;
+      return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
+        return IsInstructionSafeForShmemTranspose(user);
+      });

     default:
       return false;
@@ -4123,21 +4033,15 @@ bool IsInstructionSafeForShmemTranspose(mlir::Operation* op) {
 // preloaded tile. We inspect all the transitive users of the input parameter
 // up to the fusion root instruction to see if we can find any instruction
 // that can make preloading the input tile unsafe.
-std::vector<int64> FilterInputsForShmemTranspose(mlir::lmhlo::FusionOp fusion,
+std::vector<int64> FilterInputsForShmemTranspose(const HloInstruction* fusion,
                                                  std::vector<int64> input_ids) {
-  std::vector<mlir::Value> params;
-  fusion.region().walk([&](mlir::TensorLoadOp load) {
-    CHECK(load.memref().getParentRegion() != &fusion.region())
-        << "TensorLoadOp shows should be only expected for accessing captured "
-           "memrefs.";
-    params.push_back(load);
-  });
-
   std::vector<int64> filtered_input_ids;
-  for (int64 input_id : input_ids) {
-    mlir::Value input = params.at(input_id);
-    if (IsInstructionSafeForShmemTranspose(input.getDefiningOp())) {
-      filtered_input_ids.push_back(input_id);
+  for (int64 i = 0; i < input_ids.size(); ++i) {
+    const HloInstruction* input = fusion->fused_parameter(input_ids[i]);
+    if (IsInstructionSafeForShmemTranspose(input)) {
+      filtered_input_ids.push_back(input_ids[i]);
+    } else {
+      VLOG(10) << "Input not safe for shmem transpose " << input->ToString();
     }
   }
   return filtered_input_ids;
@@ -4145,23 +4049,24 @@ std::vector<int64> FilterInputsForShmemTranspose(mlir::lmhlo::FusionOp fusion,

 }  // namespace

-StatusOr<bool> IrEmitterUnnested::CheckAndEmitHloWithTile021(
-    MlirEmitterInput input) {
-  CHECK(mlir::isa<mlir::lmhlo::FusionOp>(input.op) ||
-        mlir::isa<mlir::lmhlo::CopyOp>(input.op));
+bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
+  HloOpcode opcode = hlo->opcode();

-  MlirEmitterContext context;
-  context.SetOperation(input.op);
+  CHECK(hlo->IsLoopFusion() || opcode == HloOpcode::kCopy);
+
+  const Shape& output_shape = hlo->IsMultiOutputFusion()
+                                  ? ShapeUtil::GetSubshape(hlo->shape(), {0})
+                                  : hlo->shape();

   // If the output_shape is reduced to 021 shape, find all the parameters of
   // the HLO that are in the corresponding 012 shape.
   std::vector<int64> params_012;
   optional<std::vector<int64>> reduced_dims_021;
-  for (int64 operand_idx = 0; operand_idx < context.operand_shapes.size();
+  for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
        ++operand_idx) {
-    const Shape& operand_shape = context.operand_shapes[operand_idx];
+    HloInstruction* operand = hlo->mutable_operand(operand_idx);
     auto find_transpose_result =
-        ShapeUtil::FindTranspose021(operand_shape, context.output_shapes[0]);
+        ShapeUtil::FindTranspose021(operand->shape(), output_shape);
     if (!find_transpose_result.has_value()) {
       continue;
     }
@@ -4186,8 +4091,8 @@ StatusOr<bool> IrEmitterUnnested::CheckAndEmitHloWithTile021(
     return false;
   }

-  if (auto fusion_op = mlir::dyn_cast<mlir::lmhlo::FusionOp>(input.op)) {
-    params_012 = FilterInputsForShmemTranspose(fusion_op, params_012);
+  if (opcode == HloOpcode::kFusion) {
+    params_012 = FilterInputsForShmemTranspose(hlo, params_012);
     if (params_012.empty()) {
       return false;
     }
@@ -4215,10 +4120,10 @@ StatusOr<bool> IrEmitterUnnested::CheckAndEmitHloWithTile021(
   constexpr int64 kShmemPerCore = 48 * 1024;
   int64 shmem_used = 0;
   for (int64 i = 0; i < params_012.size(); ++i) {
-    const Shape& operand_shape = context.operand_shapes[params_012[i]];
+    const HloInstruction* operand = hlo->operand(params_012[i]);
     shmem_used +=
         32 * 33 *
-        ShapeUtil::ByteSizeOfPrimitiveType(operand_shape.element_type());
+        ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());

     if (kMinBlocksPerCore * shmem_used > kShmemPerCore) {
       // Erase this element and everything after it from params_012.
@@ -4231,15 +4136,10 @@ StatusOr<bool> IrEmitterUnnested::CheckAndEmitHloWithTile021(
     return false;
   }

-  std::vector<llvm_ir::IrArray> ir_arrays;
-  TF_ASSIGN_OR_RETURN(std::unique_ptr<KernelThunk> kernel_thunk,
-                      BuildKernelThunkForMlir(input.op, input.thunk_info,
-                                              input.extra_slice, &ir_arrays));
-  EmitHlo021Tile(
-      input.op, kernel_thunk.get(), context,
-      absl::MakeSpan(ir_arrays).subspan(0, context.operand_shapes.size()),
-      absl::MakeSpan(ir_arrays).subspan(context.operand_shapes.size()),
-      *reduced_dims_021, params_012);
+  VLOG(3) << "EmitHlo021Tile Emitting hlo tile 0-2-1" << hlo->ToString();
+  std::unique_ptr<KernelThunk> kernel_thunk =
+      BuildKernelThunk(hlo, /*implements_whole_instruction=*/true);
+  EmitHlo021Tile(hlo, kernel_thunk.get(), *reduced_dims_021, params_012);
   AddThunkToThunkSequence(std::move(kernel_thunk));
   return true;
 }
@@ -73,7 +73,6 @@ struct MlirEmitterInput {
 };

 // Convenience struct that contains useful data structures in MLIR emitter.
-// Not all fields may be filled. It's entiredly dependent on the uses.
 struct MlirEmitterContext {
   void SetOperation(mlir::Operation* op);

@@ -157,14 +156,11 @@ class IrEmitterUnnested : public IrEmitter,
   }

   Status DefaultAction(HloInstruction* hlo) override;
-  Status DefaultActionForMlir(MlirEmitterInput input);

   // IrEmitterUnnested handles the following instructions differently from
   // IrEmitter. It also mixes in some special handling for custom kernels
   // via the ThunkEmitter.
   Status HandleCopy(HloInstruction* copy) override;
-  Status EmitCopyForMlir(MlirEmitterInput input);
-
   Status HandleConditional(HloInstruction* conditional) override;
   Status HandleConvolution(HloInstruction* convolution) override;
   Status HandleCustomCall(HloInstruction* custom_call) override;
@@ -471,15 +467,12 @@ class IrEmitterUnnested : public IrEmitter,

   // Returns true if a 0-2-1 tiling algorithm is already used to emit the kernel
   // for the hlo instruction.
-  StatusOr<bool> CheckAndEmitHloWithTile021(MlirEmitterInput input);
+  bool CheckAndEmitHloWithTile021(HloInstruction* hlo);

   // Emits a kernel for the hlo instruction using a 0-2-1 tiling algorithm and
   // sets the corresponding launch dimensions. This is a helper to support
   // the implementation of CheckAndEmitHloWithTile021.
-  void EmitHlo021Tile(mlir::Operation* op, Thunk* kernel_thunk,
-                      const MlirEmitterContext& context,
-                      absl::Span<const llvm_ir::IrArray> operand_arrays,
-                      absl::Span<const llvm_ir::IrArray> output_arrays,
+  void EmitHlo021Tile(HloInstruction* hlo, Thunk* kernel_thunk,
                       absl::Span<const int64> reduced_output_dims,
                       absl::Span<const int64> tiled_param_ids);

@@ -534,8 +527,7 @@ class IrEmitterUnnested : public IrEmitter,
   // y_loc: The y coordinate within a tile.
   // x_loc: The x coordinate within a tile.
   void EmitTileElementForCopy(
-      const Shape& output_shape, const llvm_ir::IrArray& ir_array,
-      const llvm_ir::IrArray::Index& index,
+      HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
       const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
       llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers);

@@ -544,10 +536,7 @@ class IrEmitterUnnested : public IrEmitter,
   // y_loc: The y coordinate within a tile.
   // x_loc: The x coordinate within a tile.
   void EmitTileElementForFusion(
-      mlir::lmhlo::FusionOp fusion,
-      absl::Span<const llvm_ir::IrArray> operand_arrays,
-      absl::Span<const llvm_ir::IrArray> output_arrays,
-      const llvm_ir::IrArray::Index& index,
+      HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
       const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
       llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers);

@@ -2,12 +2,12 @@

 // CHECK-LABEL: entry:
 // CHECK: %[[VAL_0:.*]] = getelementptr inbounds i8, i8* %[[VAL_1:.*]], i64 0
-// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [100 x [200 x [300 x float]]]*
+// CHECK: %[[VAL_2:.*]] = bitcast i8* %[[VAL_0]] to [200 x [100 x [300 x float]]]*
 // CHECK: %[[VAL_3:.*]] = getelementptr inbounds i8, i8* %[[VAL_4:.*]], i64 0
-// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [200 x [100 x [300 x float]]]*
+// CHECK: %[[VAL_5:.*]] = bitcast i8* %[[VAL_3]] to [100 x [200 x [300 x float]]]*
 // CHECK: %[[VAL_6:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !2
 // CHECK: %[[VAL_7:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !3
-// CHECK: %[[VAL_8:.*]] = mul nuw nsw i32 %[[VAL_6]], 256
+// CHECK: %[[VAL_8:.*]] = mul nuw nsw i32 %[[VAL_6]], 128
 // CHECK: %[[VAL_9:.*]] = add nuw nsw i32 %[[VAL_8]], %[[VAL_7]]
 // CHECK: %[[VAL_10:.*]] = icmp ult i32 %[[VAL_9]], {{.*}}
 // CHECK: call void @llvm.assume(i1 %[[VAL_10]])
@@ -40,24 +40,24 @@
 // CHECK: b.in_bounds-after: ; preds = %[[VAL_36]], %[[VAL_38:.*]]
 // CHECK: ret void
 // CHECK: b.in_bounds-true: ; preds = %[[VAL_38]]
-// CHECK: %[[VAL_39:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_2]], i32 0, i32 %[[VAL_15]], i32 %[[VAL_16]], i32 %[[VAL_13]]
+// CHECK: %[[VAL_39:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_5]], i32 0, i32 %[[VAL_15]], i32 %[[VAL_16]], i32 %[[VAL_13]]
 // CHECK: %[[VAL_40:.*]] = load float, float* %[[VAL_39]], align 4, !invariant.load !4
-// CHECK: %[[VAL_41:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_5]] to float*
+// CHECK: %[[VAL_41:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_2]] to float*
 // CHECK: %[[VAL_42:.*]] = getelementptr inbounds float, float* %[[VAL_41]], i32 %[[VAL_11]]
 // CHECK: store float %[[VAL_40]], float* %[[VAL_42]], align 4
-// CHECK: %[[VAL_43:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_2]], i32 0, i32 %[[VAL_21]], i32 %[[VAL_22]], i32 %[[VAL_19]]
+// CHECK: %[[VAL_43:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_5]], i32 0, i32 %[[VAL_21]], i32 %[[VAL_22]], i32 %[[VAL_19]]
 // CHECK: %[[VAL_44:.*]] = load float, float* %[[VAL_43]], align 4, !invariant.load !4
-// CHECK: %[[VAL_45:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_5]] to float*
+// CHECK: %[[VAL_45:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_2]] to float*
 // CHECK: %[[VAL_46:.*]] = getelementptr inbounds float, float* %[[VAL_45]], i32 %[[VAL_17]]
 // CHECK: store float %[[VAL_44]], float* %[[VAL_46]], align 4
-// CHECK: %[[VAL_47:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_2]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_25]]
+// CHECK: %[[VAL_47:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_5]], i32 0, i32 %[[VAL_27]], i32 %[[VAL_28]], i32 %[[VAL_25]]
 // CHECK: %[[VAL_48:.*]] = load float, float* %[[VAL_47]], align 4, !invariant.load !4
-// CHECK: %[[VAL_49:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_5]] to float*
+// CHECK: %[[VAL_49:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_2]] to float*
 // CHECK: %[[VAL_50:.*]] = getelementptr inbounds float, float* %[[VAL_49]], i32 %[[VAL_23]]
 // CHECK: store float %[[VAL_48]], float* %[[VAL_50]], align 4
-// CHECK: %[[VAL_51:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_2]], i32 0, i32 %[[VAL_33]], i32 %[[VAL_34]], i32 %[[VAL_31]]
+// CHECK: %[[VAL_51:.*]] = getelementptr inbounds [100 x [200 x [300 x float]]], [100 x [200 x [300 x float]]]* %[[VAL_5]], i32 0, i32 %[[VAL_33]], i32 %[[VAL_34]], i32 %[[VAL_31]]
 // CHECK: %[[VAL_52:.*]] = load float, float* %[[VAL_51]], align 4, !invariant.load !4
-// CHECK: %[[VAL_53:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_5]] to float*
+// CHECK: %[[VAL_53:.*]] = bitcast [200 x [100 x [300 x float]]]* %[[VAL_2]] to float*
 // CHECK: %[[VAL_54:.*]] = getelementptr inbounds float, float* %[[VAL_53]], i32 %[[VAL_29]]
 // CHECK: store float %[[VAL_52]], float* %[[VAL_54]], align 4
 // CHECK: br label %[[VAL_37]]
@@ -4,10 +4,10 @@
 // CHECK: %[[VAL_0:.*]] = alloca i32, align 4
 // CHECK: %[[VAL_1:.*]] = alloca i32, align 4
 // CHECK: %[[VAL_2:.*]] = getelementptr inbounds i8, i8* %[[VAL_3:.*]], i64 0
-// CHECK: %[[VAL_4:.*]] = bitcast i8* %[[VAL_2]] to [100 x [200 x float]]*
+// CHECK: %[[VAL_4:.*]] = bitcast i8* %[[VAL_2]] to [200 x [100 x float]]*
 // CHECK: %[[VAL_5:.*]] = getelementptr inbounds i8, i8* %[[VAL_6:.*]], i64 0
-// CHECK: %[[VAL_7:.*]] = bitcast i8* %[[VAL_5]] to [200 x [100 x float]]*
-// CHECK: %[[VAL_8:.*]] = bitcast [100 x [200 x float]]* %[[VAL_4]] to [1 x [100 x [200 x float]]]*
+// CHECK: %[[VAL_7:.*]] = bitcast i8* %[[VAL_5]] to [100 x [200 x float]]*
+// CHECK: %[[VAL_8:.*]] = bitcast [100 x [200 x float]]* %[[VAL_7]] to [1 x [100 x [200 x float]]]*
 // CHECK: %[[VAL_9:.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !2
 // CHECK: %[[VAL_10:.*]] = urem i32 %[[VAL_9]], 32
 // CHECK: %[[VAL_11:.*]] = udiv i32 %[[VAL_9]], 32
@@ -88,7 +88,7 @@
 // CHECK: output_x_in_tile-true: ; preds = %[[VAL_59]]
 // CHECK: %[[VAL_72:.*]] = getelementptr [32 x [33 x float]], [32 x [33 x float]] addrspace(3)* @b.tile0, i64 0, i32 %[[VAL_65]], i32 %[[VAL_63]]
 // CHECK: %[[VAL_73:.*]] = load float, float addrspace(3)* %[[VAL_72]], align 4
-// CHECK: %[[VAL_74:.*]] = bitcast [200 x [100 x float]]* %[[VAL_7]] to [1 x [200 x [100 x float]]]*
+// CHECK: %[[VAL_74:.*]] = bitcast [200 x [100 x float]]* %[[VAL_4]] to [1 x [200 x [100 x float]]]*
 // CHECK: %[[VAL_75:.*]] = getelementptr inbounds [1 x [200 x [100 x float]]], [1 x [200 x [100 x float]]]* %[[VAL_74]], i32 0, i32 0, i32 %[[VAL_64]], i32 %[[VAL_66]]
 // CHECK: store float %[[VAL_73]], float* %[[VAL_75]], align 4
 // CHECK: br label %[[VAL_55]]