Roll back XLA/GPU LHLO sort emitter again

It breaks an internal msan-enabled test.

PiperOrigin-RevId: 326072372
Change-Id: I245525cefa4da88097725662c75ccb213a328f19
Sanjoy Das 2020-08-11 12:08:22 -07:00 committed by TensorFlower Gardener
parent 9bc641d16c
commit 0572b205b8
16 changed files with 403 additions and 792 deletions

View File

@ -83,9 +83,6 @@ StatusOr<llvm::SmallVector<AffineMap, 1>> GetPermutationIfAvailable(
strides[dim] = accumulated_stride;
accumulated_stride *= shape.dimensions(dim);
}
if (accumulated_stride == 0) {
return llvm::SmallVector<AffineMap, 1>{};
}
return llvm::SmallVector<AffineMap, 1>{
makeStridedLinearLayoutMap(strides, /*offset=*/0, builder.getContext())};
}
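
As a side note on the stride computation in this hunk: below is a minimal standalone sketch (not part of the commit; plain vectors stand in for `shape.dimensions()` and for the minor-to-major iteration, which is an assumption about code outside the hunk) showing how `strides` and `accumulated_stride` come out for the f32[3,2]{0,1} layout used by the copy test in the next file.

#include <cstdint>
#include <iostream>
#include <vector>

// Sketch only: accumulate strides by walking dimensions minor-to-major,
// mirroring the loop above.
std::vector<int64_t> ComputeStrides(const std::vector<int64_t>& dims,
                                    const std::vector<int64_t>& minor_to_major) {
  std::vector<int64_t> strides(dims.size(), 0);
  int64_t accumulated_stride = 1;
  for (int64_t dim : minor_to_major) {
    strides[dim] = accumulated_stride;
    accumulated_stride *= dims[dim];
  }
  // accumulated_stride ends up 0 only when some dimension has size 0, which is
  // what the `accumulated_stride == 0` check above guards against.
  return strides;
}

int main() {
  // f32[3,2] with layout {0,1}: dimension 0 is minor, so strides are {1, 3}
  // and element (i, j) sits at linear offset i*1 + j*3.
  for (int64_t s : ComputeStrides({3, 2}, {0, 1})) std::cout << s << " ";
  std::cout << "\n";  // prints: 1 3
}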

View File

@ -8,6 +8,6 @@ HloModule TestModule
ENTRY TestComputation {
x = f32[3, 2]{1,0} parameter(0)
// CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}}) {name = "copy.1"} : (memref<3x2xf32>, memref<3x2xf32, #[[MAP]]>) -> ()
// CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}}) : (memref<3x2xf32>, memref<3x2xf32, #[[MAP]]>) -> ()
ROOT x.copy = f32[3, 2]{0,1} copy(x)
}

View File

@ -34,6 +34,7 @@ limitations under the License.
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassOptions.h" // from @llvm-project
#include "mlir/Translation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
@ -181,10 +182,7 @@ template <typename OpType>
StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(
HloInstruction* instr) {
Location loc = getLocation(instr);
std::pair<Identifier, Attribute> attrs[] = {
{Identifier::get("name", builder_.getContext()),
builder_.getStringAttr(instr->name())},
};
ArrayRef<std::pair<Identifier, Attribute>> attrs;
ArrayRef<Type> rets{};
llvm::SmallVector<Value, 4> operands;
@ -254,14 +252,15 @@ Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) {
return Status::OK();
}
StatusOr<lmhlo::SortOp> LhloDialectEmitter::EmitSortOp(HloInstruction* instr) {
StatusOr<mlir::Operation*> LhloDialectEmitter::EmitSortOp(
HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs<lmhlo::SortOp>(instr));
auto* sort_instr = ::xla::Cast<::xla::HloSortInstruction>(instr);
sort.dimensionAttr(builder_.getI64IntegerAttr(sort_instr->sort_dimension()));
sort.is_stableAttr(builder_.getBoolAttr(sort_instr->is_stable()));
TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion(
*sort_instr->called_computations()[0], &sort.comparator(), &builder_));
return sort;
return sort.getOperation();
}
Status LhloDialectEmitter::HandleSort(HloInstruction* instr) {

View File

@ -19,7 +19,6 @@ limitations under the License.
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
@ -42,7 +41,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
builder_(module.getContext()),
i8_type_(builder_.getIntegerType(8)) {}
::xla::StatusOr<lmhlo::SortOp> EmitSortOp(::xla::HloInstruction* instr);
::xla::StatusOr<mlir::Operation*> EmitSortOp(::xla::HloInstruction* instr);
private:
template <typename OpType>

View File

@ -254,11 +254,6 @@ cc_library(
":target_util",
":thunk",
":thunk_emitter",
"//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/mlir/xla:hlo_utils",
"//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla",
"//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
"//tensorflow/compiler/mlir/xla:type_to_shape",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@ -296,8 +291,6 @@ cc_library(
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
],
)
@ -1166,7 +1159,6 @@ cc_library(
":target_constants",
":tree_reduction_rewriter",
":variadic_op_splitter",
"//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@ -1225,8 +1217,6 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Core",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR",
],
)

View File

@ -29,8 +29,6 @@ limitations under the License.
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/InitAllDialects.h" // from @llvm-project
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/all_reduce_combiner.h"
@ -518,22 +516,15 @@ static Status CompileModuleToLlvmIrImpl(
DumpHloModuleIfEnabled(*hlo_module, **buffer_assignment,
"after_optimizations");
mlir::registerAllDialects();
mlir::MLIRContext mlir_context;
IrEmitterContext ir_emitter_context(
hlo_module, buffer_assignment->get(), platform_name, gpu_device_info,
cuda_compute_capability, profile_index_map, &mlir_context,
llvm_module->get());
cuda_compute_capability, profile_index_map, llvm_module->get());
HloComputation* entry_computation = hlo_module->entry_computation();
IrEmitterUnnested ir_emitter(hlo_module->config(), entry_computation,
&ir_emitter_context);
TF_ASSIGN_OR_RETURN(
auto ir_emitter,
IrEmitterUnnested::Create(hlo_module->config(), entry_computation,
&ir_emitter_context));
TF_RETURN_IF_ERROR(ir_emitter->EmitConstantGlobals());
TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());
{
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
@ -542,10 +533,9 @@ static Status CompileModuleToLlvmIrImpl(
ThunkSequence thunk_sequence;
absl::Span<HloInstruction* const> order = hlo_schedule->ThunkLaunchOrder();
for (HloInstruction* instruction : order) {
TF_RETURN_IF_ERROR(instruction->Visit(ir_emitter.get()));
TF_RETURN_IF_ERROR(ir_emitter->Postprocess(instruction));
std::unique_ptr<ThunkSequence> thunks =
ir_emitter->ConsumeThunkSequence();
TF_RETURN_IF_ERROR(instruction->Visit(&ir_emitter));
TF_RETURN_IF_ERROR(ir_emitter.Postprocess(instruction));
std::unique_ptr<ThunkSequence> thunks = ir_emitter.ConsumeThunkSequence();
// The invariants between each input HloInstruction* and output Thunk* are
// not all explicitly checked, but at least we can document them here:

View File

@ -117,11 +117,11 @@ static bool HasMeaningfulName(llvm::Value* value) {
return false;
}
llvm::Value* CastToTypedValue(const Shape& shape, llvm::Value* ir_value,
llvm::IRBuilder<>* b) {
llvm::Type* pointee_type =
llvm_ir::ShapeToIrType(shape, b->GetInsertBlock()->getModule());
llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
ShapeIndexView shape_index,
llvm::Value* ir_value) {
llvm::Type* pointee_type = llvm_ir::ShapeToIrType(
ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_);
llvm::Type* dest_type = pointee_type->getPointerTo();
llvm::Value* typed_ir_value;
@ -129,17 +129,9 @@ llvm::Value* CastToTypedValue(const Shape& shape, llvm::Value* ir_value,
typed_ir_value = llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
llvm::cast<llvm::GlobalVariable>(ir_value), dest_type);
} else {
typed_ir_value = b->CreatePointerBitCastOrAddrSpaceCast(
typed_ir_value = b_->CreatePointerBitCastOrAddrSpaceCast(
ir_value, pointee_type->getPointerTo());
}
return typed_ir_value;
}
llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
ShapeIndexView shape_index,
llvm::Value* ir_value) {
auto typed_ir_value = CastToTypedValue(
ShapeUtil::GetSubshape(hlo.shape(), shape_index), ir_value, b_);
if (!HasMeaningfulName(ir_value)) {
ir_value->setName(llvm_ir::IrName(&hlo, "raw"));
}

View File

@ -116,10 +116,6 @@ class HloToIrBindings {
llvm::Value* temp_buffer_base_ = nullptr;
};
// Converts `ir_value` with type i8* to a typed LLVM Value* based on `shape`.
llvm::Value* CastToTypedValue(const Shape& shape, llvm::Value* ir_value,
llvm::IRBuilder<>* b);
} // namespace gpu
} // namespace xla

View File

@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_CONTEXT_H_
#include "llvm/IR/Module.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
@ -35,15 +34,13 @@ class IrEmitterContext {
const HloModule* hlo_module, const BufferAssignment* buffer_assignment,
std::string platform_name, GpuDeviceInfo gpu_device_info,
absl::optional<CudaComputeCapability> cuda_compute_capability,
const HloProfileIndexMap* profile_index_map,
mlir::MLIRContext* mlir_context, llvm::Module* llvm_module)
const HloProfileIndexMap* profile_index_map, llvm::Module* llvm_module)
: hlo_module_(hlo_module),
buffer_assignment_(buffer_assignment),
platform_name_(std::move(platform_name)),
gpu_device_info_(gpu_device_info),
cuda_compute_capability_(cuda_compute_capability),
profile_index_map_(profile_index_map),
mlir_context_(mlir_context),
llvm_module_(llvm_module) {}
// Disallow copy and assign.
IrEmitterContext(const IrEmitterContext&) = delete;
@ -60,7 +57,6 @@ class IrEmitterContext {
return cuda_compute_capability_;
}
const HloProfileIndexMap* profile_index_map() { return profile_index_map_; }
mlir::MLIRContext* mlir_context() { return mlir_context_; }
llvm::Module* llvm_module() { return llvm_module_; }
NameUniquer* name_uniquer() { return &name_uniquer_; }
@ -71,7 +67,6 @@ class IrEmitterContext {
GpuDeviceInfo gpu_device_info_;
absl::optional<CudaComputeCapability> cuda_compute_capability_;
const HloProfileIndexMap* profile_index_map_;
mlir::MLIRContext* mlir_context_;
llvm::Module* llvm_module_;
NameUniquer name_uniquer_;
};

View File

@ -37,13 +37,6 @@ limitations under the License.
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
@ -151,86 +144,13 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk,
llvm::ConstantAsMetadata::get(threads_per_block_ir_value)}));
}
const BufferAllocation* GetAllocation(
mlir::BlockArgument func_arg, const BufferAssignment& buffer_assignment) {
auto func_op =
mlir::cast<mlir::FuncOp>(func_arg.getParentRegion()->getParentOp());
int64 allocation_index = func_op
.getArgAttrOfType<mlir::IntegerAttr>(
func_arg.getArgNumber(), "lmhlo.alloc")
.getValue()
.getSExtValue();
return &buffer_assignment.GetAllocation(allocation_index);
}
StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(
mlir::Value v, const BufferAssignment& buffer_assignment) {
int64 size = v.getType().cast<mlir::MemRefType>().getSizeInBits() / 8;
if (auto arg = v.dyn_cast<mlir::BlockArgument>()) {
return BufferAllocation::Slice(GetAllocation(arg, buffer_assignment), 0,
size);
}
// We match two patterns here:
// * v = ViewOp(arg);
// * v = StaticMemRefCastOp(ViewOp(arg));
if (mlir::Operation* op = v.getDefiningOp()) {
if (auto cast = mlir::dyn_cast<mlir::lmhlo::StaticMemRefCastOp>(op)) {
mlir::Value source = cast.getViewSource();
op = source.getDefiningOp();
if (!op) {
return Unimplemented("StaticMemRefCastOp has to wrap an op");
}
}
if (auto view = mlir::dyn_cast<mlir::ViewOp>(op)) {
return BufferAllocation::Slice(
GetAllocation(view.source().cast<mlir::BlockArgument>(),
buffer_assignment),
mlir::cast<mlir::ConstantOp>(view.byte_shift().getDefiningOp())
.value()
.cast<mlir::IntegerAttr>()
.getValue()
.getSExtValue(),
size);
}
return Unimplemented("StaticMemRefCastOp has to wrap a ViewOp");
}
return Unimplemented(
"Operand has to be in the form of ViewOp(arg) or "
"StaticMemRefCastOp(ViewOp(arg))");
}
absl::string_view GetHloName(mlir::Operation* op) {
if (auto attr = op->getAttrOfType<mlir::StringAttr>("name")) {
auto ref = attr.getValue();
return absl::string_view(ref.data(), ref.size());
}
return "";
}
} // namespace
IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
const HloComputation* hlo_computation,
IrEmitterContext* ir_emitter_context)
: IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false),
hlo_computation_(hlo_computation),
mlir_scratch_module_(mlir::ModuleOp::create(
mlir::Builder(ir_emitter_context->mlir_context()).getUnknownLoc())),
lhlo_scratch_emitter_(ir_emitter_context_->buffer_assignment(),
*hlo_computation, mlir_scratch_module_.get()) {}
StatusOr<std::unique_ptr<IrEmitterUnnested>> IrEmitterUnnested::Create(
const HloModuleConfig& hlo_module_config,
const HloComputation* hlo_computation,
IrEmitterContext* ir_emitter_context) {
auto emitter = std::unique_ptr<IrEmitterUnnested>(new IrEmitterUnnested(
hlo_module_config, hlo_computation, ir_emitter_context));
TF_RETURN_IF_ERROR(emitter->lhlo_scratch_emitter_.Initialize());
return std::move(emitter);
}
hlo_computation_(hlo_computation) {}
Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) {
bindings_.UnbindAllLocalIrValues();
@ -238,11 +158,12 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) {
}
llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
absl::string_view name, absl::Span<const BufferAllocation* const> args) {
const HloInstruction& inst,
absl::Span<const BufferAllocation* const> args) {
// Compute the kernel name. The opcode string may contain "-" which cannot be
// in a PTX function name, so sanitize the name before uniquifying it.
string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName(
llvm_ir::SanitizeFunctionName(std::string(name)));
llvm_ir::SanitizeFunctionName(inst.name()));
// Create the kernel and add it to the module.
llvm::Module* module = ir_emitter_context_->llvm_module();
@ -438,8 +359,7 @@ Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
}
Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
TF_ASSIGN_OR_RETURN(auto thunk, BuildConditionalThunk(conditional));
AddThunkToThunkSequence(std::move(thunk));
AddThunkToThunkSequence(BuildConditionalThunk(conditional));
return Status::OK();
}
@ -1118,13 +1038,10 @@ Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) {
// Build ForThunk for conformant while loops, otherwise build WhileThunk.
auto config = xla_while->backend_config<WhileLoopBackendConfig>();
if (config.ok() && config.ValueOrDie().has_known_trip_count()) {
TF_ASSIGN_OR_RETURN(
auto thunk,
AddThunkToThunkSequence(
BuildForThunk(xla_while, config.ValueOrDie().known_trip_count().n()));
AddThunkToThunkSequence(std::move(thunk));
} else {
TF_ASSIGN_OR_RETURN(auto thunk, BuildWhileThunk(xla_while));
AddThunkToThunkSequence(std::move(thunk));
AddThunkToThunkSequence(BuildWhileThunk(xla_while));
}
return Status::OK();
}
@ -1347,109 +1264,39 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
return IrEmitter::HandleSelect(select);
}
StatusOr<const HloComputation*>
IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region) {
std::unique_ptr<HloModule>& module = scratch_nested_computations_[region];
if (module == nullptr) {
xla::XlaComputation xla_computation;
TF_RETURN_IF_ERROR(ConvertRegionToComputation(region, &xla_computation));
TF_ASSIGN_OR_RETURN(auto program_shape, xla_computation.GetProgramShape());
TF_ASSIGN_OR_RETURN(
module, HloModule::CreateFromProto(xla_computation.proto(),
HloModuleConfig(program_shape)));
}
return module->entry_computation();
}
Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
MlirEmitterInput result;
TF_ASSIGN_OR_RETURN(auto sort_op, lhlo_scratch_emitter_.EmitSortOp(sort));
result.op = sort_op;
result.name = GetHloName(sort_op);
// The name in the sort op has no semantics; it is for debugging only. If the name
// doesn't exist, we should use a namer (e.g. count-based).
// TODO(timshen): use a namer instead of relying on the HloInstruction names.
if (result.name.empty()) {
result.name = sort->name();
}
const auto& buffer_assignment = ir_emitter_context_->buffer_assignment();
auto& slice = result.extra_slice;
TF_ASSIGN_OR_RETURN(slice.buffer_slice,
buffer_assignment.GetUniqueSlice(sort, {}));
slice.written = true;
slice.shape = sort->shape();
result.thunk_info = GetThunkInfo(sort);
return EmitMlirSort(result);
}
Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) {
const auto& buffer_assignment = ir_emitter_context_->buffer_assignment();
auto sort_op = mlir::cast<mlir::lmhlo::SortOp>(input.op);
int operand_count = sort_op.operands().size();
std::vector<xla::Shape> operand_shapes(operand_count);
std::vector<MlirBufferSlice> slices;
std::vector<xla::Shape> output_shapes(sort_op.output().size());
for (int i = 0; i < operand_count; i++) {
operand_shapes[i] =
TypeToShape(sort_op.operands()[i].getType().cast<mlir::MemRefType>());
}
// Craft n + 1 slices, where the first n are output parameters, and the last
// is the on-device tuple storage. We don't need n operands because sorting
// kernels are always in-place.
for (int i = 0; i < operand_count; i++) {
output_shapes[i] =
TypeToShape(sort_op.output()[i].getType().cast<mlir::MemRefType>());
MlirBufferSlice slice;
TF_ASSIGN_OR_RETURN(
slice.buffer_slice,
GetAllocationSliceForMlir(sort_op.output()[i], buffer_assignment));
slice.written = true;
slice.shape = operand_shapes[i];
slices.push_back(slice);
}
slices.push_back(input.extra_slice);
std::vector<std::unique_ptr<Thunk>> thunks;
Shape keys_shape = operand_shapes[0];
int64 dimension_to_sort = sort_op.dimension().getSExtValue();
for (int64 i = 0; i < operand_count; ++i) {
Shape keys_shape = sort->operand(0)->shape();
int64 dimension_to_sort = sort->dimensions(0);
for (int64 i = 0; i < sort->operand_count(); ++i) {
ShapeIndex shape_index =
sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({});
// We assume that the layout of all involved operands and outputs is the
// same.
TF_RET_CHECK(
LayoutUtil::LayoutsInShapesEqual(keys_shape, operand_shapes[i]));
TF_RET_CHECK(
LayoutUtil::LayoutsInShapesEqual(keys_shape, output_shapes[i]));
TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(keys_shape,
sort->operand(i)->shape()));
TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index)));
// If possible, we share buffers. If that is not possible, we need to copy
// the values, because the emitter does the sorting in-place.
TF_ASSIGN_OR_RETURN(
auto destination_buffer,
GetAllocationSliceForMlir(sort_op.output()[i], buffer_assignment));
TF_ASSIGN_OR_RETURN(
auto source_address,
GetAllocationSliceForMlir(sort_op.operands()[i], buffer_assignment));
auto destination_buffer = GetAllocationSlice(*sort, shape_index);
auto source_address = GetAllocationSlice(*sort->operand(i));
if (destination_buffer != source_address) {
// TODO(b/26783907): Figure out why we never seem to share buffers for
// key/value sort.
VLOG(2) << input.name << " requires initial D2D copy for operand " << i;
VLOG(2) << sort->name() << " requires initial D2D copy for operand " << i;
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
Thunk::ThunkInfo(),
/*source_address=*/source_address,
/*destination_buffer=*/destination_buffer,
/*mem_size=*/ShapeUtil::ByteSizeOf(operand_shapes[i])));
/*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape())));
}
}
uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound);
VLOG(2) << input.name << " requires " << num_stages << " stages.";
VLOG(2) << sort->name() << " requires " << num_stages << " stages.";
CHECK_GE(1ULL << num_stages, dimension_to_sort_bound);
CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound);
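
A quick worked example for the stage count above (illustrative only; `Log2CeilingSketch` below is a hypothetical stand-in for `tensorflow::Log2Ceiling`): a sort dimension of size 6 needs 3 stages, and both CHECKs hold because 1 << 3 = 8 >= 6 while 1 << 2 = 4 < 6.

#include <cassert>
#include <cstdint>

// Hypothetical stand-in for tensorflow::Log2Ceiling: smallest s with 2^s >= n.
int64_t Log2CeilingSketch(uint64_t n) {
  int64_t s = 0;
  while ((uint64_t{1} << s) < n) ++s;
  return s;
}

int main() {
  const uint64_t dimension_to_sort_bound = 6;
  const int64_t num_stages = Log2CeilingSketch(dimension_to_sort_bound);  // 3
  assert((uint64_t{1} << num_stages) >= dimension_to_sort_bound);         // 8 >= 6
  assert((uint64_t{1} << (num_stages - 1)) < dimension_to_sort_bound);    // 4 < 6
}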
@ -1513,10 +1360,10 @@ Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) {
// we have not enough threads, or not enough shared memory. Also it does not
// give a speedup if the tile size is < 128.
int64 total_shared_memory_needed = 0;
for (int64 i = 0; i < operand_count; ++i) {
for (int64 i = 0; i < sort->operand_count(); ++i) {
total_shared_memory_needed +=
kTileSize *
ShapeUtil::ByteSizeOfPrimitiveType(operand_shapes[i].element_type());
kTileSize * ShapeUtil::ByteSizeOfPrimitiveType(
sort->operand(i)->shape().element_type());
}
bool no_tiling =
kTileSize < 128 ||
@ -1529,7 +1376,7 @@ Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) {
"kTileSize=%d < 128, "
"kThreadsPerBlock=%d > threads_per_block_limit=%d, "
"total_shared_memory_needed=%d > shared_memory_per_block=%d",
input.name, (no_tiling ? "won't" : "will"), kTileSize, kThreadsPerBlock,
sort->name(), (no_tiling ? "won't" : "will"), kTileSize, kThreadsPerBlock,
ir_emitter_context_->gpu_device_info().threads_per_block_limit,
total_shared_memory_needed,
ir_emitter_context_->gpu_device_info().shared_memory_per_block);
@ -1537,38 +1384,37 @@ Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) {
uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock);
LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock);
VLOG(2) << absl::StreamFormat("%s launch dims: %d blocks, %d threads/block",
input.name, num_blocks, kThreadsPerBlock);
sort->name(), num_blocks, kThreadsPerBlock);
std::vector<llvm_ir::IrArray> ir_arrays;
auto emit_kernel = [&](absl::Span<const int64> xor_masks) {
VLOG(2) << absl::StreamFormat(
"%s uses kernel for xor masks [%s]", input.name,
"%s uses kernel for xor masks [%s]", sort->name(),
absl::StrJoin(xor_masks, ", ", [](std::string* out, int64 xor_mask) {
absl::StrAppendFormat(out, "0x%x", xor_mask);
}));
thunks.push_back(BuildKernelThunkForMlir(input.name, Thunk::ThunkInfo(),
slices, &ir_arrays));
thunks.push_back(
BuildKernelThunk(sort, /*implements_whole_instruction=*/false));
LaunchDimensions launch_dimensions = xor_masks.size() > 1
? tiled_launch_dimensions
: standard_launch_dimensions;
UpdateLaunchDimensions(launch_dimensions, thunks.back().get(),
ir_emitter_context_->llvm_module());
std::vector<IrArray> values_arrays;
values_arrays.reserve(operand_count);
for (int64 i = 0; i < operand_count; ++i) {
values_arrays.push_back(ir_arrays[i]);
values_arrays.reserve(sort->operand_count());
for (int64 i = 0; i < sort->operand_count(); ++i) {
ShapeIndex shape_index =
sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({});
values_arrays.push_back(GetIrArray(*sort, *sort, shape_index));
}
TF_ASSIGN_OR_RETURN(
const HloComputation* comparator,
GetOrCreateSubComputationFromRegion(&sort_op.comparator()));
return llvm_ir::EmitSortInPlace(
dimension_to_sort, values_arrays, IrName(input.name), xor_masks, &b_,
dimension_to_sort, values_arrays, IrName(sort), xor_masks, &b_,
launch_dimensions,
xor_masks.size() > 1 ? num_iterations_in_sort_dim
: standard_num_iterations_in_sort_dim,
kTileSize,
[&](absl::Span<llvm::Value* const> operands, llvm::Value* output) {
return EmitCallToNestedComputation(*comparator, operands, output);
return EmitCallToNestedComputation(*sort->to_apply(), operands,
output);
});
};
std::vector<int64> xor_masks;
@ -1595,18 +1441,17 @@ Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) {
TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
}
VLOG(2) << absl::StreamFormat(
"%s requires %d thunks (including any D2D copies)", input.name,
"%s requires %d thunks (including any D2D copies)", sort->name(),
thunks.size());
AddThunkToThunkSequence(
absl::make_unique<SequentialThunk>(input.thunk_info, std::move(thunks)));
if (operand_count > 1) {
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
GetThunkInfo(sort), std::move(thunks)));
if (sort->operand_count() > 1) {
// Emit the tuple as part of the last stage of sorting.
// We are currently in the block sorted.in_bounds.after.
b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
llvm_ir::EmitTuple(
ir_arrays[operand_count],
absl::MakeSpan(ir_arrays).subspan(0, ir_arrays.size() - 1), &b_);
llvm_ir::EmitTuple(GetIrArray(*sort, *sort),
ConstructIrArrayForOutputs(*sort), &b_);
}
return Status::OK();
}
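
To make the xor-mask passes used above concrete, here is a small CPU-side sketch of a single comparison pass (illustrative only, not XLA code; the emitted GPU kernels additionally handle tiling, multiple operands, and an arbitrary comparator computation): each element at index i is paired with the element at i ^ xor_mask, and the pair is swapped if out of order.

#include <algorithm>
#include <cstdint>
#include <vector>

// Sketch of one bitonic-style pass: pair element i with element i ^ xor_mask
// and compare-and-swap. A full sort runs such passes for the sequence of xor
// masks that the code above accumulates before calling emit_kernel.
void XorMaskPass(std::vector<float>& data, int64_t xor_mask) {
  const int64_t n = static_cast<int64_t>(data.size());
  for (int64_t i = 0; i < n; ++i) {
    const int64_t j = i ^ xor_mask;
    // Only the lower index of each pair does the work, and out-of-range
    // partners are skipped (mirroring the in-bounds checks in the kernel).
    if (j > i && j < n && data[j] < data[i]) {
      std::swap(data[i], data[j]);
    }
  }
}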
@ -1744,6 +1589,24 @@ Status IrEmitterUnnested::HandleAfterAll(HloInstruction* after_all) {
return Status::OK();
}
// Describes how to access a particular subshape for an HLO. For instance if
// `.hlo_index` is {1} and `.gte_index` is {3, 4}, then the buffer for `.instr` at
// ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo) is found
// at `.buffer_slice`[3][4]. That is, `.buffer_slice` is a void***, which we
// dereference twice -- first at index 3, and then at index 4 -- to get the
// address of our buffer.
struct HloBufferSlice {
const HloInstruction* instr;
ShapeIndex hlo_index;
// The root buffer to look at.
BufferAllocation::Slice buffer_slice;
// Describes how to dereference starting at that buffer to get to the buffer
// in question.
ShapeIndex gte_index;
};
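
The `gte_index` walk described in the comment above can be pictured with a short host-side sketch (purely illustrative; the emitter performs the equivalent walk on LLVM IR values with InBoundsGEP and Load, as shown later in this hunk):

#include <cstdint>
#include <vector>

// Sketch: starting from the root buffer of `.buffer_slice`, dereference once
// per gte_index entry; for gte_index {3, 4} the final address is root[3][4]
// when the root is viewed as a table of pointers.
void* ResolveGteIndex(void* root, const std::vector<int64_t>& gte_index) {
  void* loc = root;
  for (int64_t idx : gte_index) {
    loc = static_cast<void**>(loc)[idx];  // one load per index
  }
  return loc;
}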
// Figures out how to access the buffers for all subshapes of hlo's operands and
// for hlo itself (i.e. all the buffers produced by HLO).
//
@ -1852,22 +1715,22 @@ static std::vector<HloBufferSlice> GetHloBufferSlices(
return result;
}
std::unique_ptr<KernelThunk>
IrEmitterUnnested::BuildKernelThunkFromBufferSlices(
absl::string_view name, Thunk::ThunkInfo thunk_info,
absl::Span<const BufferSlice* const> slices,
std::function<void(const BufferSlice*, llvm::Value*)>
bind_slice_to_ir_value) {
const auto& buffer_assn = ir_emitter_context_->buffer_assignment();
std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
const HloInstruction* inst, bool implements_whole_instruction) {
const BufferAssignment& buffer_assn =
ir_emitter_context_->buffer_assignment();
std::vector<HloBufferSlice> hlo_slices =
GetHloBufferSlices(inst, buffer_assn);
// Figure out which buffer allocations need to be passed as arguments to our
// kernel. This is simply all of the allocations referenced in slices,
// kernel. This is simply all of the allocations referenced in hlo_slices,
// plus the XLA temp buffer (if we have it). We always include the temp
// buffer because even if the kernel itself doesn't use it, a nested
// subcomputation within the kernel (e.g. a kMap's computation) might.
std::unordered_set<const BufferAllocation*> buffers_needed;
for (auto* slice : slices) {
buffers_needed.insert(slice->buffer_slice.allocation());
for (const auto& hlo_buffer_slice : hlo_slices) {
buffers_needed.insert(hlo_buffer_slice.buffer_slice.allocation());
}
absl::optional<const BufferAllocation*> temp_buffer;
for (const BufferAllocation& alloc : buffer_assn.Allocations()) {
@ -1896,7 +1759,7 @@ IrEmitterUnnested::BuildKernelThunkFromBufferSlices(
return a->index() < b->index();
});
llvm::Function* kernel = BuildKernelPrototype(name, non_constant_buffers);
llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers);
// Build a map from a BufferAllocation to the corresponding argument in our
// kernel.
@ -1930,19 +1793,24 @@ IrEmitterUnnested::BuildKernelThunkFromBufferSlices(
// For each buffer our kernel might want to touch, bind it to a value derived
// from our kernel args.
for (auto* slice : slices) {
const BufferAllocation::Slice& buffer_slice = slice->buffer_slice;
const ShapeIndex& gte_index = slice->gte_index;
for (const auto& hlo_buffer_slice : hlo_slices) {
const HloInstruction* instr = hlo_buffer_slice.instr;
const ShapeIndex& index = hlo_buffer_slice.hlo_index;
const BufferAllocation::Slice& slice = hlo_buffer_slice.buffer_slice;
const ShapeIndex& gte_index = hlo_buffer_slice.gte_index;
VLOG(3) << "Buffer for " << instr->ToString() << " at " << index.ToString()
<< " is found in slice " << slice.ToString() << " at GTE index "
<< gte_index.ToString();
llvm::Value* loc;
if (buffer_slice.allocation()->is_constant()) {
if (slice.allocation()->is_constant()) {
loc = ir_emitter_context_->llvm_module()->getGlobalVariable(
llvm_ir::ConstantBufferAllocationToGlobalName(
*buffer_slice.allocation()));
llvm_ir::ConstantBufferAllocationToGlobalName(*slice.allocation()));
CHECK_NE(loc, nullptr);
} else {
loc = InBoundsGEP(kernel_args.at(buffer_slice.allocation()),
{b_.getInt64(buffer_slice.offset())});
loc = InBoundsGEP(kernel_args.at(slice.allocation()),
{b_.getInt64(slice.offset())});
}
// If gte_index is nonempty, we have to dereference `loc` to get to the
@ -1954,7 +1822,7 @@ IrEmitterUnnested::BuildKernelThunkFromBufferSlices(
loc = Load(InBoundsGEP(loc, {b_.getInt64(idx)}));
}
bind_slice_to_ir_value(slice, loc);
bindings_.BindHloToIrValue(*instr, loc, index);
}
// Bind the temp buffer so that nested subcomputations can find it if they
@ -1966,66 +1834,9 @@ IrEmitterUnnested::BuildKernelThunkFromBufferSlices(
llvm::ConstantPointerNull::get(b_.getInt8PtrTy()));
}
return absl::make_unique<KernelThunk>(thunk_info, non_constant_buffers,
std::string(kernel->getName()));
}
std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
const HloInstruction* inst, bool implements_whole_instruction) {
std::vector<HloBufferSlice> hlo_slices =
GetHloBufferSlices(inst, ir_emitter_context_->buffer_assignment());
std::vector<BufferSlice*> slice_ptrs;
slice_ptrs.reserve(hlo_slices.size());
for (auto& slice : hlo_slices) {
slice_ptrs.push_back(&slice);
}
return BuildKernelThunkFromBufferSlices(
inst->name(),
return absl::make_unique<KernelThunk>(
implements_whole_instruction ? GetThunkInfo(inst) : Thunk::ThunkInfo(),
slice_ptrs, [this](const BufferSlice* slice, llvm::Value* value) {
const HloBufferSlice* hlo_buffer_slice =
static_cast<const HloBufferSlice*>(slice);
const HloInstruction* instr = hlo_buffer_slice->instr;
const ShapeIndex& index = hlo_buffer_slice->hlo_index;
VLOG(3) << "Buffer for " << instr->ToString() << " at "
<< index.ToString() << " is found in slice "
<< hlo_buffer_slice->buffer_slice.ToString() << " at GTE index "
<< hlo_buffer_slice->gte_index.ToString();
bindings_.BindHloToIrValue(*instr, value, index);
});
}
std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunkForMlir(
absl::string_view name, Thunk::ThunkInfo thunk_info,
absl::Span<const MlirBufferSlice> slices,
std::vector<llvm_ir::IrArray>* ir_arrays) {
absl::flat_hash_set<BufferAllocation::Slice> buffers_written;
std::vector<const BufferSlice*> slice_ptrs;
slice_ptrs.reserve(slices.size());
for (auto& slice : slices) {
slice_ptrs.push_back(&slice);
if (slice.written) {
buffers_written.insert(slice.buffer_slice);
}
}
ir_arrays->clear();
return BuildKernelThunkFromBufferSlices(
name, thunk_info, slice_ptrs,
[&](const BufferSlice* slice, llvm::Value* value) {
const auto& mlir_slice = static_cast<const MlirBufferSlice&>(*slice);
llvm_ir::IrArray ir_array(
CastToTypedValue(mlir_slice.shape, value, &b_), mlir_slice.shape);
if (!buffers_written.contains(slice->buffer_slice)) {
ir_array.MarkInvariantOverWholeProgram(&value->getContext());
}
ir_arrays->push_back(ir_array);
});
non_constant_buffers, std::string(kernel->getName()));
}
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
@ -2232,7 +2043,7 @@ Status CheckConditionalBuffersShareAllocation(
} // namespace
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildWhileThunk(
std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk(
const HloInstruction* hlo) {
// Check that all while-related buffers share an allocation.
TF_CHECK_OK(CheckWhileBuffersShareAllocation(
@ -2240,26 +2051,24 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildWhileThunk(
// Generate thunk sequence for while 'condition'.
HloComputation* condition = hlo->while_condition();
TF_ASSIGN_OR_RETURN(auto ir_emitter_condition,
IrEmitterUnnested::Create(hlo_module_config_, condition,
ir_emitter_context_));
TF_RETURN_IF_ERROR(condition->Accept(ir_emitter_condition.get()));
IrEmitterUnnested ir_emitter_condition(hlo_module_config_, condition,
ir_emitter_context_);
TF_CHECK_OK(condition->Accept(&ir_emitter_condition));
// Generate thunk sequence for while 'body'.
HloComputation* body = hlo->while_body();
TF_ASSIGN_OR_RETURN(
auto ir_emitter_body,
IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_));
TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get()));
IrEmitterUnnested ir_emitter_body(hlo_module_config_, body,
ir_emitter_context_);
TF_CHECK_OK(body->Accept(&ir_emitter_body));
return std::unique_ptr<Thunk>(new WhileThunk(
return absl::make_unique<WhileThunk>(
GetThunkInfo(hlo),
GetAllocationSlice(*condition->root_instruction()), // cond result
ir_emitter_condition->ConsumeThunkSequence(),
ir_emitter_body->ConsumeThunkSequence()));
ir_emitter_condition.ConsumeThunkSequence(),
ir_emitter_body.ConsumeThunkSequence());
}
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildForThunk(
std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
const HloInstruction* hlo, const int64 loop_limit) {
// Check that all while-related buffers share an allocation.
TF_CHECK_OK(CheckWhileBuffersShareAllocation(
@ -2267,16 +2076,15 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildForThunk(
// Generate thunk sequence for while 'body' (will be used as a For loop body).
HloComputation* body = hlo->while_body();
TF_ASSIGN_OR_RETURN(
auto ir_emitter_body,
IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_));
TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get()));
IrEmitterUnnested ir_emitter_body(hlo_module_config_, body,
ir_emitter_context_);
TF_CHECK_OK(body->Accept(&ir_emitter_body));
return std::unique_ptr<Thunk>(new ForThunk(
GetThunkInfo(hlo), loop_limit, ir_emitter_body->ConsumeThunkSequence()));
return absl::make_unique<ForThunk>(GetThunkInfo(hlo), loop_limit,
ir_emitter_body.ConsumeThunkSequence());
}
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildConditionalThunk(
std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
const HloInstruction* hlo) {
// Check that the buffers used in conditional are shared with the operands and
// result appropriately.
@ -2288,17 +2096,15 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildConditionalThunk(
for (int j = 0; j < hlo->branch_count(); ++j) {
branch_operands.emplace_back(GetAllocationSlice(*hlo->operand(j + 1)));
HloComputation* branch_computation = hlo->branch_computation(j);
TF_ASSIGN_OR_RETURN(
auto ir_emitter,
IrEmitterUnnested::Create(hlo_module_config_, branch_computation,
ir_emitter_context_));
TF_CHECK_OK(branch_computation->Accept(ir_emitter.get()));
branch_thunks.push_back(std::move(*ir_emitter->ConsumeThunkSequence()));
IrEmitterUnnested ir_emitter(hlo_module_config_, branch_computation,
ir_emitter_context_);
TF_CHECK_OK(branch_computation->Accept(&ir_emitter));
branch_thunks.push_back(std::move(*ir_emitter.ConsumeThunkSequence()));
}
return std::unique_ptr<Thunk>(new ConditionalThunk(
return absl::make_unique<ConditionalThunk>(
GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)), branch_operands,
std::move(branch_thunks)));
std::move(branch_thunks));
}
Status IrEmitterUnnested::EmitTargetElementLoopInThunk(

View File

@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
@ -29,40 +28,6 @@ limitations under the License.
namespace xla {
namespace gpu {
struct BufferSlice {
// The root buffer to look at.
BufferAllocation::Slice buffer_slice;
// Describes how to dereference starting at that buffer to get to the buffer
// in question.
ShapeIndex gte_index;
};
// Describes how to access a particular subshape for an HLO. For instance if
// `.hlo_index` is {1} and `.gte_index` is {3, 4}, then the buffer for `.instr` at
// ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo) is
// found at `.buffer_slice`[3][4]. That is, `.buffer_slice` is a void***, which we
// dereference twice -- first at index 3, and then at index 4 -- to get the
// address of our buffer.
struct HloBufferSlice : public BufferSlice {
const HloInstruction* instr;
ShapeIndex hlo_index;
};
struct MlirBufferSlice : public BufferSlice {
// The buffer is modified by the kernel.
bool written;
Shape shape;
};
struct MlirEmitterInput {
mlir::Operation* op;
absl::string_view name;
Thunk::ThunkInfo thunk_info;
MlirBufferSlice extra_slice;
};
// Emits LLVM IR for an "unnested computation".
//
// An unnested computation is an HloComputation which you run by executing one
@ -124,14 +89,12 @@ class IrEmitterUnnested : public IrEmitter,
const string& loop_name, llvm::Value* tile_height,
llvm::Value* tile_width, KernelSupportLibrary* ksl)>;
IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
const HloComputation* hlo_computation,
IrEmitterContext* ir_emitter_context);
IrEmitterUnnested(const IrEmitterUnnested&) = delete;
IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;
static StatusOr<std::unique_ptr<IrEmitterUnnested>> Create(
const HloModuleConfig& hlo_module_config,
const HloComputation* hlo_computation,
IrEmitterContext* ir_emitter_context);
// Transfers the ownership of thunk_sequence_ out.
std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
return std::make_unique<ThunkSequence>(std::move(thunk_sequence_));
@ -161,7 +124,6 @@ class IrEmitterUnnested : public IrEmitter,
Status HandleScatter(HloInstruction* scatter) override;
Status HandleSelect(HloInstruction* select) override;
Status HandleSort(HloInstruction* sort) override;
Status EmitMlirSort(MlirEmitterInput input);
Status HandleTriangularSolve(HloInstruction* hlo) override;
Status HandleTupleSelect(HloInstruction* tuple_select) override;
Status HandleAllReduce(HloInstruction* crs) override;
@ -186,10 +148,6 @@ class IrEmitterUnnested : public IrEmitter,
Status Postprocess(HloInstruction* hlo) override;
private:
IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
const HloComputation* hlo_computation,
IrEmitterContext* ir_emitter_context);
// Add an owning Thunk object to the thunk sequence.
void AddThunkToThunkSequence(std::unique_ptr<Thunk> thunk) override {
thunk_sequence_.emplace_back(std::move(thunk));
@ -306,7 +264,8 @@ class IrEmitterUnnested : public IrEmitter,
// Builds the prototype of the IR kernel for `inst` and adds it to the module.
// This kernel takes as arguments pointers to the given buffer allocations.
llvm::Function* BuildKernelPrototype(
absl::string_view name, absl::Span<const BufferAllocation* const> args);
const HloInstruction& inst,
absl::Span<const BufferAllocation* const> args);
// Helper for writing extra outputs from inside a reduce kernel.
Status EmitExtraOutputsForReduce(
@ -531,12 +490,6 @@ class IrEmitterUnnested : public IrEmitter,
HloComputation* reducer, llvm::Type* element_type,
llvm::Value* partial_result_address);
std::unique_ptr<KernelThunk> BuildKernelThunkFromBufferSlices(
absl::string_view name, Thunk::ThunkInfo thunk_info,
absl::Span<const BufferSlice* const> slices,
std::function<void(const BufferSlice*, llvm::Value*)>
bind_slice_to_ir_value);
// Returns a KernelThunk that invokes the kernel emitted for `inst`. The
// caller needs to make sure `inst` outlives the lifetime of the returned
// Thunk object. 'implements_whole_instruction' specifies whether this
@ -545,11 +498,6 @@ class IrEmitterUnnested : public IrEmitter,
std::unique_ptr<KernelThunk> BuildKernelThunk(
const HloInstruction* inst, bool implements_whole_instruction);
std::unique_ptr<KernelThunk> BuildKernelThunkForMlir(
absl::string_view name, Thunk::ThunkInfo thunk_info,
absl::Span<const MlirBufferSlice> slices,
std::vector<llvm_ir::IrArray>* ir_arrays);
// Returns a thunk that, given a reduce or select-and-scatter op,
// initializes its memory to the appropriate initial value.
StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk(
@ -557,18 +505,17 @@ class IrEmitterUnnested : public IrEmitter,
// Returns a WhileThunk that invokes thunk sequences for 'condition' and
// 'body' sub-computations of while instruction 'hlo'.
StatusOr<std::unique_ptr<Thunk>> BuildWhileThunk(const HloInstruction* hlo);
std::unique_ptr<Thunk> BuildWhileThunk(const HloInstruction* hlo);
// Returns a ForThunk which executes 'loop_limit' invocations of a thunk
// sequence from the 'body' sub-computation of the while instruction 'hlo'.
StatusOr<std::unique_ptr<Thunk>> BuildForThunk(const HloInstruction* hlo,
const int64 loop_limit);
std::unique_ptr<Thunk> BuildForThunk(const HloInstruction* hlo,
const int64 loop_limit);
// Returns a ConditionalThunk which executes the thunk sequence for the
// 'branch_computation' corresponding to the predicate/branch_index of the
// given conditional instruction.
StatusOr<std::unique_ptr<Thunk>> BuildConditionalThunk(
const HloInstruction* hlo);
std::unique_ptr<Thunk> BuildConditionalThunk(const HloInstruction* hlo);
// Emits current thread id with the given type.
//
@ -598,9 +545,6 @@ class IrEmitterUnnested : public IrEmitter,
absl::optional<int64> thread_id_filter = absl::nullopt,
absl::optional<int64> block_id_filter = absl::nullopt);
StatusOr<const HloComputation*> GetOrCreateSubComputationFromRegion(
mlir::Region* region);
// Returns the last generated thunk.
Thunk* LastThunk() const { return thunk_sequence_.back().get(); }
@ -611,14 +555,6 @@ class IrEmitterUnnested : public IrEmitter,
// The HloComputation that this IrEmitter emits code for.
const HloComputation* hlo_computation_;
mlir::OwningModuleRef mlir_scratch_module_;
// This is for caching purposes only. It has no significant semantics.
mlir::LhloDialectEmitter lhlo_scratch_emitter_;
absl::flat_hash_map<const mlir::Region*, std::unique_ptr<HloModule>>
scratch_nested_computations_;
};
} // namespace gpu

View File

@ -458,35 +458,6 @@ xla_test(
],
)
tf_cc_test(
name = "sorting_test",
srcs = [
"sorting_test.cc",
],
tags = tf_cuda_tests_tags() + [
"no_rocm",
],
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_proto_cc",
"//tensorflow/compiler/xla/service:gpu_plugin",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service/gpu:gpu_executable",
"//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/memory",
],
)
tf_cc_binary(
name = "hlo_to_llvm_ir",
srcs = ["hlo_to_llvm_ir.cc"],

View File

@ -8,162 +8,162 @@ compare {
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}
// CHECK: define void @sort(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]])
// CHECK: define void @sort(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC1:%.*]])
// CHECK-NEXT: entry:
// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0
// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x float]]*
// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0
// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]*
// CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP4]] to i64
// CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP5]] to i64
// CHECK-NEXT: [[TMP6:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP6]], [[THREAD_ID]]
// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0
// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x [3 x float]]*
// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1]], i64 0
// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x float]]*
// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64
// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64
// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]]
// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]])
// CHECK-NEXT: [[TMP7:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
// CHECK-NEXT: [[TMP8:%.*]] = urem i64 [[TMP7]], 2
// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
// CHECK-NEXT: [[TMP10:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
// CHECK-NEXT: br i1 [[TMP10]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2
// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
// CHECK: sort.in_bounds-after:
// CHECK-NEXT: ret void
// CHECK: sort.in_bounds-true:
// CHECK-NEXT: [[TMP11:%.*]] = mul i64 [[TMP8]], 2
// CHECK-NEXT: [[TMP12:%.*]] = xor i64 [[TMP11]], 1
// CHECK-NEXT: [[TMP13:%.*]] = icmp slt i64 [[TMP11]], [[TMP12]]
// CHECK-NEXT: [[TMP14:%.*]] = icmp slt i64 [[TMP12]], 3
// CHECK-NEXT: [[TMP15:%.*]] = and i1 [[TMP13]], [[TMP14]]
// CHECK-NEXT: br i1 [[TMP15]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
// CHECK-NEXT: [[TMP7:%.*]] = mul i64 [[TMP4]], 2
// CHECK-NEXT: [[TMP8:%.*]] = xor i64 [[TMP7]], 1
// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], [[TMP8]]
// CHECK-NEXT: [[TMP10:%.*]] = icmp slt i64 [[TMP8]], 3
// CHECK-NEXT: [[TMP11:%.*]] = and i1 [[TMP9]], [[TMP10]]
// CHECK-NEXT: br i1 [[TMP11]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
// CHECK: smaller_comparison_index-after:
// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]]
// CHECK: smaller_comparison_index-true:
// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP12]]
// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]]
// CHECK-NEXT: call void @region_0_4(float* [[TMP16]], float* [[TMP17]], i8* [[COMPARE_RETURN_BUFFER]])
// CHECK-NEXT: [[TMP18:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP18]], 0
// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP8]]
// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
// CHECK-NEXT: call void @compare(float* [[TMP12]], float* [[TMP13]], i8* [[COMPARE_RETURN_BUFFER]])
// CHECK-NEXT: [[TMP14:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP14]], 0
// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]]
// CHECK: is_smaller_than-after:
// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]]
// CHECK: is_smaller_than-true:
// CHECK-NEXT: [[TMP19:%.*]] = load float, float* [[TMP16]], align 4
// CHECK-NEXT: [[TMP20:%.*]] = load float, float* [[TMP17]], align 4
// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]]
// CHECK-NEXT: store float [[TMP19]], float* [[TMP21]], align 4
// CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP12]]
// CHECK-NEXT: store float [[TMP20]], float* [[TMP22]], align 4
// CHECK-NEXT: [[TMP15:%.*]] = load float, float* [[TMP12]], align 4
// CHECK-NEXT: [[TMP16:%.*]] = load float, float* [[TMP13]], align 4
// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
// CHECK-NEXT: store float [[TMP15]], float* [[TMP17]], align 4
// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP8]]
// CHECK-NEXT: store float [[TMP16]], float* [[TMP18]], align 4
// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]]
// CHECK: define internal void @region_0_4(float* dereferenceable(4) [[P_0_LHS_TYPED:%.*]], float* dereferenceable(4) [[P_0_RHS_TYPED:%.*]], i8* dereferenceable(1) [[OUTPUT_ARG:%.*]])
// CHECK: define internal void @compare(float* dereferenceable(4) [[P_0_LHS_TYPED:%.*]], float* dereferenceable(4) [[P_0_RHS_TYPED:%.*]], i8* dereferenceable(1) [[OUTPUT_ARG:%.*]])
// CHECK-NEXT: entry:
// CHECK-NEXT: [[COMPARE_3_TYPED:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[TMP0:%.*]] = load float, float* [[ARG_0_1_TYPED:%.*]], align 4
// CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[ARG_1_2_TYPED:%.*]], align 4
// CHECK-NEXT: [[LT_TYPED:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[TMP0:%.*]] = load float, float* [[P_0_LHS_TYPED]], align 4
// CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[P_0_RHS_TYPED]], align 4
// CHECK-NEXT: [[TMP2:%.*]] = fcmp olt float [[TMP0]], [[TMP1]]
// CHECK-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i8
// CHECK-NEXT: store i8 [[TMP3]], i8* [[COMPARE_3_TYPED]], align 1
// CHECK-NEXT: [[LOAD_RET_VALUE:%.*]] = load i8, i8* [[COMPARE_3_TYPED]], align 1
// CHECK-NEXT: store i8 [[LOAD_RET_VALUE]], i8* [[OUTPUT_ARG:%.*]], align 1
// CHECK-NEXT: store i8 [[TMP3]], i8* [[LT_TYPED]], align 1
// CHECK-NEXT: [[LOAD_RET_VALUE:%.*]] = load i8, i8* [[LT_TYPED]], align 1
// CHECK-NEXT: store i8 [[LOAD_RET_VALUE]], i8* [[OUTPUT_ARG]], align 1
// CHECK-NEXT: ret void
// CHECK: define void @sort__1(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]]) {
// CHECK: define void @sort__1(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC1:%.*]]) {
// CHECK-NEXT: entry:
// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0
// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x float]]*
// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0
// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]*
// CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP4]] to i64
// CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP5]] to i64
// CHECK-NEXT: [[TMP6:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP6]], [[THREAD_ID]]
// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0
// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x [3 x float]]*
// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1]], i64 0
// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x float]]*
// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64
// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64
// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]]
// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]])
// CHECK-NEXT: [[TMP7:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
// CHECK-NEXT: [[TMP8:%.*]] = urem i64 [[TMP7]], 2
// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
// CHECK-NEXT: [[TMP10:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
// CHECK-NEXT: br i1 [[TMP10]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2
// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
// CHECK: sort.in_bounds-after:
// CHECK-NEXT: ret void
// CHECK: sort.in_bounds-true:
// CHECK-NEXT: [[TMP11:%.*]] = xor i64 [[TMP8]], 3
// CHECK-NEXT: [[TMP12:%.*]] = icmp slt i64 [[TMP8]], [[TMP11]]
// CHECK-NEXT: [[TMP13:%.*]] = icmp slt i64 [[TMP11]], 3
// CHECK-NEXT: [[TMP14:%.*]] = and i1 [[TMP12]], [[TMP13]]
// CHECK-NEXT: br i1 [[TMP14]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
// CHECK-NEXT: [[TMP7:%.*]] = xor i64 [[TMP4]], 3
// CHECK-NEXT: [[TMP8:%.*]] = icmp slt i64 [[TMP4]], [[TMP7]]
// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], 3
// CHECK-NEXT: [[TMP10:%.*]] = and i1 [[TMP8]], [[TMP9]]
// CHECK-NEXT: br i1 [[TMP10]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
// CHECK: smaller_comparison_index-after:
// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]]
// CHECK: smaller_comparison_index-true:
// CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]]
// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP8]]
// CHECK-NEXT: call void @region_0_4(float* [[TMP15]], float* [[TMP16]], i8* [[COMPARE_RETURN_BUFFER]])
// CHECK-NEXT: [[TMP17:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP17]], 0
// CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP4]]
// CHECK-NEXT: call void @compare(float* [[TMP11]], float* [[TMP12]], i8* [[COMPARE_RETURN_BUFFER]])
// CHECK-NEXT: [[TMP13:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP13]], 0
// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]]
// CHECK: is_smaller_than-after:
// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]]
// CHECK: is_smaller_than-true:
// CHECK-NEXT: [[TMP18:%.*]] = load float, float* [[TMP15]], align 4
// CHECK-NEXT: [[TMP19:%.*]] = load float, float* [[TMP16]], align 4
// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP8]]
// CHECK-NEXT: store float [[TMP18]], float* [[TMP20]], align 4
// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]]
// CHECK-NEXT: store float [[TMP19]], float* [[TMP21]], align 4
// CHECK-NEXT: [[TMP14:%.*]] = load float, float* [[TMP11]], align 4
// CHECK-NEXT: [[TMP15:%.*]] = load float, float* [[TMP12]], align 4
// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP4]]
// CHECK-NEXT: store float [[TMP14]], float* [[TMP16]], align 4
// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
// CHECK-NEXT: store float [[TMP15]], float* [[TMP17]], align 4
// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]]
// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]]) {
// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC1:%.*]]) {
// CHECK-NEXT: entry:
// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0
// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x float]]*
// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0
// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]*
// CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP4]] to i64
// CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP5]] to i64
// CHECK-NEXT: [[TMP6:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP6]], [[THREAD_ID]]
// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0
// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x [3 x float]]*
// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0
// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x float]]*
// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64
// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64
// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]]
// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]])
// CHECK-NEXT: [[TMP7:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
// CHECK-NEXT: [[TMP8:%.*]] = urem i64 [[TMP7]], 2
// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
// CHECK-NEXT: [[TMP10:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
// CHECK-NEXT: br i1 [[TMP10]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2
// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
// CHECK: sort.in_bounds-after:
// CHECK-NEXT: ret void
// CHECK: sort.in_bounds-true:
// CHECK-NEXT: [[TMP11:%.*]] = mul i64 [[TMP8]], 2
// CHECK-NEXT: [[TMP12:%.*]] = xor i64 [[TMP11]], 1
// CHECK-NEXT: [[TMP13:%.*]] = icmp slt i64 [[TMP11]], [[TMP12]]
// CHECK-NEXT: [[TMP14:%.*]] = icmp slt i64 [[TMP12]], 3
// CHECK-NEXT: [[TMP15:%.*]] = and i1 [[TMP13]], [[TMP14]]
// CHECK-NEXT: br i1 [[TMP15]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
// CHECK-NEXT: [[TMP7:%.*]] = mul i64 [[TMP4]], 2
// CHECK-NEXT: [[TMP8:%.*]] = xor i64 [[TMP7]], 1
// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], [[TMP8]]
// CHECK-NEXT: [[TMP10:%.*]] = icmp slt i64 [[TMP8]], 3
// CHECK-NEXT: [[TMP11:%.*]] = and i1 [[TMP9]], [[TMP10]]
// CHECK-NEXT: br i1 [[TMP11]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
// CHECK: smaller_comparison_index-after:
// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]]
// CHECK: smaller_comparison_index-true:
// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP12]]
// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]]
// CHECK-NEXT: call void @region_0_4(float* [[TMP16]], float* [[TMP17]], i8* [[COMPARE_RETURN_BUFFER]])
// CHECK-NEXT: [[TMP18:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP18]], 0
// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP8]]
// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
// CHECK-NEXT: call void @compare(float* [[TMP12]], float* [[TMP13]], i8* [[COMPARE_RETURN_BUFFER]])
// CHECK-NEXT: [[TMP14:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP14]], 0
// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]]
// CHECK: is_smaller_than-after:
// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]]
// CHECK: is_smaller_than-true:
// CHECK-NEXT: [[TMP19:%.*]] = load float, float* [[TMP16]], align 4
// CHECK-NEXT: [[TMP20:%.*]] = load float, float* [[TMP17]], align 4
// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]]
// CHECK-NEXT: store float [[TMP19]], float* [[TMP21]], align 4
// CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP12]]
// CHECK-NEXT: store float [[TMP20]], float* [[TMP22]], align 4
// CHECK-NEXT: [[TMP15:%.*]] = load float, float* [[TMP12]], align 4
// CHECK-NEXT: [[TMP16:%.*]] = load float, float* [[TMP13]], align 4
// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
// CHECK-NEXT: store float [[TMP15]], float* [[TMP17]], align 4
// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP8]]
// CHECK-NEXT: store float [[TMP16]], float* [[TMP18]], align 4
// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]]
ENTRY main {
x = f32[2, 3] parameter(0)
@ -182,198 +182,210 @@ compare {
ROOT lt = pred[] compare(p.1.lhs, p.1.rhs), direction=LT
}
// CHECK: define void @sort(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]])
// CHECK: define void @sort(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC2:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC3:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]])
// CHECK-NEXT: entry:
// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0
// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x i32]]*
// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0
// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]*
// CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0
// CHECK-NEXT: [[TMP5:%.*]] = bitcast i8* [[TMP4]] to [2 x i8*]*
// CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP6]] to i64
// CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP7]] to i64
// CHECK-NEXT: [[TMP8:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP8]], [[THREAD_ID]]
// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4]], i64 0
// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x i8*]*
// CHECK-NEXT: [[SORT_RAW1:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0
// CHECK-NEXT: [[SORT_TYPED2:%.*]] = bitcast i8* [[SORT_RAW1]] to [2 x [3 x i32]]*
// CHECK-NEXT: [[SORT_RAW3:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1]], i64 0
// CHECK-NEXT: [[SORT_TYPED4:%.*]] = bitcast i8* [[SORT_RAW3]] to [2 x [3 x float]]*
// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC2]], i64 0
// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x i32]]*
// CHECK-NEXT: [[Y_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC3]], i64 0
// CHECK-NEXT: [[Y_TYPED:%.*]] = bitcast i8* [[Y_RAW]] to [2 x [3 x float]]*
// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64
// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64
// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]]
// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]])
// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
// CHECK-NEXT: [[TMP10:%.*]] = urem i64 [[TMP9]], 2
// CHECK-NEXT: [[TMP11:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
// CHECK-NEXT: [[TMP12:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
// CHECK-NEXT: br i1 [[TMP12]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2
// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
// CHECK: sort.in_bounds-after:
// CHECK-NEXT: ret void
// CHECK: sort.in_bounds-true:
// CHECK-NEXT: [[TMP13:%.*]] = mul i64 [[TMP10]], 2
// CHECK-NEXT: [[TMP14:%.*]] = xor i64 [[TMP13]], 1
// CHECK-NEXT: [[TMP15:%.*]] = icmp slt i64 [[TMP13]], [[TMP14]]
// CHECK-NEXT: [[TMP16:%.*]] = icmp slt i64 [[TMP14]], 3
// CHECK-NEXT: [[TMP17:%.*]] = and i1 [[TMP15]], [[TMP16]]
// CHECK-NEXT: br i1 [[TMP17]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
// CHECK-NEXT: [[TMP7:%.*]] = mul i64 [[TMP4]], 2
// CHECK-NEXT: [[TMP8:%.*]] = xor i64 [[TMP7]], 1
// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], [[TMP8]]
// CHECK-NEXT: [[TMP10:%.*]] = icmp slt i64 [[TMP8]], 3
// CHECK-NEXT: [[TMP11:%.*]] = and i1 [[TMP9]], [[TMP10]]
// CHECK-NEXT: br i1 [[TMP11]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
// CHECK: smaller_comparison_index-after:
// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]]
// CHECK: smaller_comparison_index-true:
// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP14]]
// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP13]]
// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP14]]
// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP13]]
// CHECK-NEXT: call void @region_0_6(i32* [[TMP18]], i32* [[TMP19]], float* [[TMP20]], float* [[TMP21]], i8* [[COMPARE_RETURN_BUFFER]])
// CHECK-NEXT: [[TMP22:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP22]], 0
// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP8]]
// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
// CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP8]]
// CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
// CHECK-NEXT: call void @compare(i32* [[TMP12]], i32* [[TMP13]], float* [[TMP14]], float* [[TMP15]], i8* [[COMPARE_RETURN_BUFFER]])
// CHECK-NEXT: [[TMP16:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP16]], 0
// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]]
// CHECK: is_smaller_than-after:
// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]]
// CHECK: is_smaller_than-true:
// CHECK-NEXT: [[TMP23:%.*]] = load i32, i32* [[TMP18]], align 4
// CHECK-NEXT: [[TMP24:%.*]] = load i32, i32* [[TMP19]], align 4
// CHECK-NEXT: [[TMP25:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP13]]
// CHECK-NEXT: store i32 [[TMP23]], i32* [[TMP25]], align 4
// CHECK-NEXT: [[TMP26:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP14]]
// CHECK-NEXT: store i32 [[TMP24]], i32* [[TMP26]], align 4
// CHECK-NEXT: [[TMP27:%.*]] = load float, float* [[TMP20]], align 4
// CHECK-NEXT: [[TMP28:%.*]] = load float, float* [[TMP21]], align 4
// CHECK-NEXT: [[TMP29:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP13]]
// CHECK-NEXT: store float [[TMP27]], float* [[TMP29]], align 4
// CHECK-NEXT: [[TMP30:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP14]]
// CHECK-NEXT: store float [[TMP28]], float* [[TMP30]], align 4
// CHECK-NEXT: [[TMP17:%.*]] = load i32, i32* [[TMP12]], align 4
// CHECK-NEXT: [[TMP18:%.*]] = load i32, i32* [[TMP13]], align 4
// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
// CHECK-NEXT: store i32 [[TMP17]], i32* [[TMP19]], align 4
// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP8]]
// CHECK-NEXT: store i32 [[TMP18]], i32* [[TMP20]], align 4
// CHECK-NEXT: [[TMP21:%.*]] = load float, float* [[TMP14]], align 4
// CHECK-NEXT: [[TMP22:%.*]] = load float, float* [[TMP15]], align 4
// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
// CHECK-NEXT: store float [[TMP21]], float* [[TMP23]], align 4
// CHECK-NEXT: [[TMP24:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP8]]
// CHECK-NEXT: store float [[TMP22]], float* [[TMP24]], align 4
// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]]
// CHECK: define internal void @region_0_6(i32* dereferenceable(4) [[P_0_LHS_TYPED:%.*]], i32* dereferenceable(4) [[P_0_RHS_TYPED:%.*]], float* dereferenceable(4) [[P_1_LHS_TYPED:%.*]], float* dereferenceable(4) [[P_1_RHS_TYPED:%.*]], i8* dereferenceable(1) [[OUTPUT_ARG:%.*]])
// CHECK: define internal void @compare(i32* dereferenceable(4) [[P_0_LHS_TYPED:%.*]], i32* dereferenceable(4) [[P_0_RHS_TYPED:%.*]], float* dereferenceable(4) [[P_1_LHS_TYPED:%.*]], float* dereferenceable(4) [[P_1_RHS_TYPED:%.*]], i8* dereferenceable(1) [[OUTPUT_ARG:%.*]])
// CHECK-NEXT: entry:
// CHECK-NEXT: [[COMPARE_5_TYPED:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[TMP0:%.*]] = load float, float* [[ARG_2_3_TYPED:%.*]], align 4
// CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[ARG_3_4_TYPED:%.*]], align 4
// CHECK-NEXT: [[LT_TYPED:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[TMP0:%.*]] = load float, float* [[P_1_LHS_TYPED]], align 4
// CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[P_1_RHS_TYPED]], align 4
// CHECK-NEXT: [[TMP2:%.*]] = fcmp olt float [[TMP0]], [[TMP1]]
// CHECK-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i8
// CHECK-NEXT: store i8 [[TMP3]], i8* [[COMPARE_5_TYPED]], align 1
// CHECK-NEXT: [[LOAD_RET_VALUE:%.*]] = load i8, i8* [[COMPARE_5_TYPED]], align 1
// CHECK-NEXT: store i8 [[LOAD_RET_VALUE]], i8* [[OUTPUT_ARG:%.*]], align 1
// CHECK-NEXT: store i8 [[TMP3]], i8* [[LT_TYPED]], align 1
// CHECK-NEXT: [[LOAD_RET_VALUE:%.*]] = load i8, i8* [[LT_TYPED]], align 1
// CHECK-NEXT: store i8 [[LOAD_RET_VALUE]], i8* [[OUTPUT_ARG]], align 1
// CHECK-NEXT: ret void
// CHECK: define void @sort__1(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]])
// CHECK: define void @sort__1(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC2:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC3:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]])
// CHECK-NEXT: entry:
// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0
// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x i32]]*
// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0
// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]*
// CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0
// CHECK-NEXT: [[TMP5:%.*]] = bitcast i8* [[TMP4]] to [2 x i8*]*
// CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP6]] to i64
// CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP7]] to i64
// CHECK-NEXT: [[TMP8:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP8]], [[THREAD_ID]]
// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0
// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x i8*]*
// CHECK-NEXT: [[SORT_RAW1:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0
// CHECK-NEXT: [[SORT_TYPED2:%.*]] = bitcast i8* [[SORT_RAW1]] to [2 x [3 x i32]]*
// CHECK-NEXT: [[SORT_RAW3:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0
// CHECK-NEXT: [[SORT_TYPED4:%.*]] = bitcast i8* [[SORT_RAW3]] to [2 x [3 x float]]*
// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC2:%.*]], i64 0
// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x i32]]*
// CHECK-NEXT: [[Y_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC3:%.*]], i64 0
// CHECK-NEXT: [[Y_TYPED:%.*]] = bitcast i8* [[Y_RAW]] to [2 x [3 x float]]*
// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64
// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64
// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]]
// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]])
// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
// CHECK-NEXT: [[TMP10:%.*]] = urem i64 [[TMP9]], 2
// CHECK-NEXT: [[TMP11:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
// CHECK-NEXT: [[TMP12:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
// CHECK-NEXT: br i1 [[TMP12]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2
// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
// CHECK: sort.in_bounds-after:
// CHECK-NEXT: ret void
// CHECK: sort.in_bounds-true:
// CHECK-NEXT: [[TMP13:%.*]] = xor i64 [[TMP10]], 3
// CHECK-NEXT: [[TMP14:%.*]] = icmp slt i64 [[TMP10]], [[TMP13]]
// CHECK-NEXT: [[TMP15:%.*]] = icmp slt i64 [[TMP13]], 3
// CHECK-NEXT: [[TMP16:%.*]] = and i1 [[TMP14]], [[TMP15]]
// CHECK-NEXT: br i1 [[TMP16]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
// CHECK-NEXT: [[TMP7:%.*]] = xor i64 [[TMP4]], 3
// CHECK-NEXT: [[TMP8:%.*]] = icmp slt i64 [[TMP4]], [[TMP7]]
// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], 3
// CHECK-NEXT: [[TMP10:%.*]] = and i1 [[TMP8]], [[TMP9]]
// CHECK-NEXT: br i1 [[TMP10]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
// CHECK: smaller_comparison_index-after:
// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]]
// CHECK: smaller_comparison_index-true:
// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP13]]
// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP10]]
// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP13]]
// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP10]]
// CHECK-NEXT: call void @region_0_6(i32* [[TMP17]], i32* [[TMP18]], float* [[TMP19]], float* [[TMP20]], i8* [[COMPARE_RETURN_BUFFER]])
// CHECK-NEXT: [[TMP21:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP21]], 0
// CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP4]]
// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
// CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP4]]
// CHECK-NEXT: call void @compare(i32* [[TMP11]], i32* [[TMP12]], float* [[TMP13]], float* [[TMP14]], i8* [[COMPARE_RETURN_BUFFER]])
// CHECK-NEXT: [[TMP15:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP15]], 0
// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]]
// CHECK: is_smaller_than-after:
// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]]
// CHECK: is_smaller_than-true:
// CHECK-NEXT: [[TMP22:%.*]] = load i32, i32* [[TMP17]], align 4
// CHECK-NEXT: [[TMP23:%.*]] = load i32, i32* [[TMP18]], align 4
// CHECK-NEXT: [[TMP24:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP10]]
// CHECK-NEXT: store i32 [[TMP22]], i32* [[TMP24]], align 4
// CHECK-NEXT: [[TMP25:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP13]]
// CHECK-NEXT: store i32 [[TMP23]], i32* [[TMP25]], align 4
// CHECK-NEXT: [[TMP26:%.*]] = load float, float* [[TMP19]], align 4
// CHECK-NEXT: [[TMP27:%.*]] = load float, float* [[TMP20]], align 4
// CHECK-NEXT: [[TMP28:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP10]]
// CHECK-NEXT: store float [[TMP26]], float* [[TMP28]], align 4
// CHECK-NEXT: [[TMP29:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP13]]
// CHECK-NEXT: store float [[TMP27]], float* [[TMP29]], align 4
// CHECK-NEXT: [[TMP16:%.*]] = load i32, i32* [[TMP11]], align 4
// CHECK-NEXT: [[TMP17:%.*]] = load i32, i32* [[TMP12]], align 4
// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP4]]
// CHECK-NEXT: store i32 [[TMP16]], i32* [[TMP18]], align 4
// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
// CHECK-NEXT: store i32 [[TMP17]], i32* [[TMP19]], align 4
// CHECK-NEXT: [[TMP20:%.*]] = load float, float* [[TMP13]], align 4
// CHECK-NEXT: [[TMP21:%.*]] = load float, float* [[TMP14]], align 4
// CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP4]]
// CHECK-NEXT: store float [[TMP20]], float* [[TMP22]], align 4
// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
// CHECK-NEXT: store float [[TMP21]], float* [[TMP23]], align 4
// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]]
// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]])
// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC2:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC3:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]])
// CHECK-NEXT: entry:
// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1
// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0
// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x i32]]*
// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0
// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]*
// CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0
// CHECK-NEXT: [[TMP5:%.*]] = bitcast i8* [[TMP4]] to [2 x i8*]*
// CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP6]] to i64
// CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP7]] to i64
// CHECK-NEXT: [[TMP8:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP8]], [[THREAD_ID]]
// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0
// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x i8*]*
// CHECK-NEXT: [[SORT_RAW1:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0
// CHECK-NEXT: [[SORT_TYPED2:%.*]] = bitcast i8* [[SORT_RAW1]] to [2 x [3 x i32]]*
// CHECK-NEXT: [[SORT_RAW3:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0
// CHECK-NEXT: [[SORT_TYPED4:%.*]] = bitcast i8* [[SORT_RAW3]] to [2 x [3 x float]]*
// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC2:%.*]], i64 0
// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x i32]]*
// CHECK-NEXT: [[Y_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC3:%.*]], i64 0
// CHECK-NEXT: [[Y_TYPED:%.*]] = bitcast i8* [[Y_RAW]] to [2 x [3 x float]]*
// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64
// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64
// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]]
// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]])
// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
// CHECK-NEXT: [[TMP10:%.*]] = urem i64 [[TMP9]], 2
// CHECK-NEXT: [[TMP11:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
// CHECK-NEXT: [[TMP12:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
// CHECK-NEXT: br i1 [[TMP12]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2
// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
// CHECK: sort.in_bounds-after:
// CHECK-NEXT: [[TMP13:%.*]] = bitcast [2 x [3 x i32]]* [[TMP1]] to i8*
// CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[TMP5]], i64 0, i64 0
// CHECK-NEXT: store i8* [[TMP13]], i8** [[TMP14]], align 8
// CHECK-NEXT: [[TMP15:%.*]] = bitcast [2 x [3 x float]]* [[TMP3]] to i8*
// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[TMP5]], i64 0, i64 1
// CHECK-NEXT: store i8* [[TMP15]], i8** [[TMP16]], align 8
// CHECK-NEXT: [[TMP7:%.*]] = bitcast [2 x [3 x i32]]* [[SORT_TYPED2]] to i8*
// CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[SORT_TYPED]], i64 0, i64 0
// CHECK-NEXT: store i8* [[TMP7]], i8** [[TMP8]], align 8
// CHECK-NEXT: [[TMP9:%.*]] = bitcast [2 x [3 x float]]* [[SORT_TYPED4]] to i8*
// CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[SORT_TYPED]], i64 0, i64 1
// CHECK-NEXT: store i8* [[TMP9]], i8** [[TMP10]], align 8
// CHECK-NEXT: ret void
// CHECK: sort.in_bounds-true:
// CHECK-NEXT: [[TMP17:%.*]] = mul i64 [[TMP10]], 2
// CHECK-NEXT: [[TMP18:%.*]] = xor i64 [[TMP17]], 1
// CHECK-NEXT: [[TMP19:%.*]] = icmp slt i64 [[TMP17]], [[TMP18]]
// CHECK-NEXT: [[TMP20:%.*]] = icmp slt i64 [[TMP18]], 3
// CHECK-NEXT: [[TMP21:%.*]] = and i1 [[TMP19]], [[TMP20]]
// CHECK-NEXT: br i1 [[TMP21]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
// CHECK-NEXT: [[TMP11:%.*]] = mul i64 [[TMP4]], 2
// CHECK-NEXT: [[TMP12:%.*]] = xor i64 [[TMP11]], 1
// CHECK-NEXT: [[TMP13:%.*]] = icmp slt i64 [[TMP11]], [[TMP12]]
// CHECK-NEXT: [[TMP14:%.*]] = icmp slt i64 [[TMP12]], 3
// CHECK-NEXT: [[TMP15:%.*]] = and i1 [[TMP13]], [[TMP14]]
// CHECK-NEXT: br i1 [[TMP15]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
// CHECK: smaller_comparison_index-after:
// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]]
// CHECK: smaller_comparison_index-true:
// CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP18]]
// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP17]]
// CHECK-NEXT: [[TMP24:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP18]]
// CHECK-NEXT: [[TMP25:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP17]]
// CHECK-NEXT: call void @region_0_6(i32* [[TMP22]], i32* [[TMP23]], float* [[TMP24]], float* [[TMP25]], i8* [[COMPARE_RETURN_BUFFER]])
// CHECK-NEXT: [[TMP26:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP26]], 0
// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP12]]
// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP11]]
// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP12]]
// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP11]]
// CHECK-NEXT: call void @compare(i32* [[TMP16]], i32* [[TMP17]], float* [[TMP18]], float* [[TMP19]], i8* [[COMPARE_RETURN_BUFFER]])
// CHECK-NEXT: [[TMP20:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP20]], 0
// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]]
// CHECK: is_smaller_than-after:
// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]]
// CHECK: is_smaller_than-true:
// CHECK-NEXT: [[TMP27:%.*]] = load i32, i32* [[TMP22]], align 4
// CHECK-NEXT: [[TMP28:%.*]] = load i32, i32* [[TMP23]], align 4
// CHECK-NEXT: [[TMP29:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP17]]
// CHECK-NEXT: store i32 [[TMP27]], i32* [[TMP29]], align 4
// CHECK-NEXT: [[TMP30:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP18]]
// CHECK-NEXT: store i32 [[TMP28]], i32* [[TMP30]], align 4
// CHECK-NEXT: [[TMP31:%.*]] = load float, float* [[TMP24]], align 4
// CHECK-NEXT: [[TMP32:%.*]] = load float, float* [[TMP25]], align 4
// CHECK-NEXT: [[TMP33:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP17]]
// CHECK-NEXT: store float [[TMP31]], float* [[TMP33]], align 4
// CHECK-NEXT: [[TMP34:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP18]]
// CHECK-NEXT: store float [[TMP32]], float* [[TMP34]], align 4
// CHECK-NEXT: [[TMP21:%.*]] = load i32, i32* [[TMP16]], align 4
// CHECK-NEXT: [[TMP22:%.*]] = load i32, i32* [[TMP17]], align 4
// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP11]]
// CHECK-NEXT: store i32 [[TMP21]], i32* [[TMP23]], align 4
// CHECK-NEXT: [[TMP24:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP12]]
// CHECK-NEXT: store i32 [[TMP22]], i32* [[TMP24]], align 4
// CHECK-NEXT: [[TMP25:%.*]] = load float, float* [[TMP18]], align 4
// CHECK-NEXT: [[TMP26:%.*]] = load float, float* [[TMP19]], align 4
// CHECK-NEXT: [[TMP27:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP11]]
// CHECK-NEXT: store float [[TMP25]], float* [[TMP27]], align 4
// CHECK-NEXT: [[TMP28:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP12]]
// CHECK-NEXT: store float [[TMP26]], float* [[TMP28]], align 4
// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]]
ENTRY main {
x = s32[2, 3] parameter(0)


@ -1,71 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <utility>
#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/stream_executor/lib/statusor.h"
namespace xla {
namespace gpu {
namespace {
class SortingTest : public GpuCodegenTest {
 protected:
  HloModuleConfig ConfigWithoutLayoutAssignment() {
    HloModuleConfig config;
    auto debug_options = HloTestBase::GetDebugOptionsForTest();
    // Disable layout_assignment to use the preassigned layouts.
    debug_options.add_xla_disable_hlo_passes("layout-assignment");
    config.set_debug_options(debug_options);
    return config;
  }
};

TEST_F(SortingTest, Regression1) {
  const char* hlo_text = R"(
HloModule TestModule

compare {
  p.0.lhs = f32[] parameter(0)
  p.0.rhs = f32[] parameter(1)
  ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}

ENTRY TestComputation {
  x = f32[3, 2]{1, 0} parameter(0)
  x.copy = f32[3, 2]{0, 1} copy(x)
  ROOT sort = f32[3, 2]{0, 1} sort(x.copy), dimensions={1}, to_apply=compare
}
)";

  EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_text, ErrorSpec{1e-5, 1e-5}));
}
} // namespace
} // namespace gpu
} // namespace xla


@ -415,10 +415,9 @@ llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper,
  return inst;
}

string IrName(absl::string_view a) {
  std::string s(a);
  s.erase(std::remove(s.begin(), s.end(), '%'), s.end());
  return s;
string IrName(string a) {
  a.erase(std::remove(a.begin(), a.end(), '%'), a.end());
  return a;
}
string IrName(absl::string_view a, absl::string_view b) {


@ -87,7 +87,7 @@ string DumpModuleToString(const llvm::Module& module);
// - joining all of the nonempty inputs by '.', and then
// - removing all '%'s.
//
string IrName(absl::string_view a);
string IrName(string a);
string IrName(absl::string_view a, absl::string_view b);
string IrName(const HloInstruction* a, absl::string_view b = "");
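
For illustration only, a minimal standalone sketch (not part of the diff) of the behavior the header comment above documents for the single-argument overload: the name is copied and every '%' character is removed. IrNameSketch and the sample input are hypothetical names used just for this example, not the library's API.

#include <algorithm>
#include <iostream>
#include <string>

// Hypothetical stand-in for the single-argument IrName overload shown above,
// assuming only the documented behavior: strip every '%' from the given name.
std::string IrNameSketch(std::string a) {
  a.erase(std::remove(a.begin(), a.end(), '%'), a.end());
  return a;
}

int main() {
  std::cout << IrNameSketch("%sort.1") << "\n";  // prints "sort.1"
  return 0;
}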