Roll back XLA/GPU LHLO sort emitter again

It breaks an internal msan-enabled test.

PiperOrigin-RevId: 326072372
Change-Id: I245525cefa4da88097725662c75ccb213a328f19

parent 9bc641d16c
commit 0572b205b8
@@ -83,9 +83,6 @@ StatusOr<llvm::SmallVector<AffineMap, 1>> GetPermutationIfAvailable(
strides[dim] = accumulated_stride;
accumulated_stride *= shape.dimensions(dim);
}
if (accumulated_stride == 0) {
return llvm::SmallVector<AffineMap, 1>{};
}
return llvm::SmallVector<AffineMap, 1>{
makeStridedLinearLayoutMap(strides, /*offset=*/0, builder.getContext())};
}
@@ -8,6 +8,6 @@ HloModule TestModule
ENTRY TestComputation {
x = f32[3, 2]{1,0} parameter(0)

// CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}}) {name = "copy.1"} : (memref<3x2xf32>, memref<3x2xf32, #[[MAP]]>) -> ()
// CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}}) : (memref<3x2xf32>, memref<3x2xf32, #[[MAP]]>) -> ()
ROOT x.copy = f32[3, 2]{0,1} copy(x)
}
@@ -34,6 +34,7 @@ limitations under the License.
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Pass/PassOptions.h" // from @llvm-project
#include "mlir/Translation.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
@@ -181,10 +182,7 @@ template <typename OpType>
StatusOr<OpType> LhloDialectEmitter::CreateOpWithoutAttrs(
HloInstruction* instr) {
Location loc = getLocation(instr);
std::pair<Identifier, Attribute> attrs[] = {
{Identifier::get("name", builder_.getContext()),
builder_.getStringAttr(instr->name())},
};
ArrayRef<std::pair<Identifier, Attribute>> attrs;
ArrayRef<Type> rets{};

llvm::SmallVector<Value, 4> operands;
@@ -254,14 +252,15 @@ Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) {
return Status::OK();
}

StatusOr<lmhlo::SortOp> LhloDialectEmitter::EmitSortOp(HloInstruction* instr) {
StatusOr<mlir::Operation*> LhloDialectEmitter::EmitSortOp(
HloInstruction* instr) {
TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs<lmhlo::SortOp>(instr));
auto* sort_instr = ::xla::Cast<::xla::HloSortInstruction>(instr);
sort.dimensionAttr(builder_.getI64IntegerAttr(sort_instr->sort_dimension()));
sort.is_stableAttr(builder_.getBoolAttr(sort_instr->is_stable()));
TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion(
*sort_instr->called_computations()[0], &sort.comparator(), &builder_));
return sort;
return sort.getOperation();
}

Status LhloDialectEmitter::HandleSort(HloInstruction* instr) {
@@ -19,7 +19,6 @@ limitations under the License.
#include "mlir/IR/Builders.h" // from @llvm-project
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/IR/StandardTypes.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"

@@ -42,7 +41,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault {
builder_(module.getContext()),
i8_type_(builder_.getIntegerType(8)) {}

::xla::StatusOr<lmhlo::SortOp> EmitSortOp(::xla::HloInstruction* instr);
::xla::StatusOr<mlir::Operation*> EmitSortOp(::xla::HloInstruction* instr);

private:
template <typename OpType>
@@ -254,11 +254,6 @@ cc_library(
":target_util",
":thunk",
":thunk_emitter",
"//tensorflow/compiler/mlir/hlo:lhlo",
"//tensorflow/compiler/mlir/xla:hlo_utils",
"//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla",
"//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo",
"//tensorflow/compiler/mlir/xla:type_to_shape",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -296,8 +291,6 @@ cc_library(
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Core",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:StandardOps",
],
)

@@ -1166,7 +1159,6 @@ cc_library(
":target_constants",
":tree_reduction_rewriter",
":variadic_op_splitter",
"//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla",
"//tensorflow/compiler/xla:protobuf_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
@@ -1225,8 +1217,6 @@ cc_library(
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@llvm-project//llvm:Core",
"@llvm-project//mlir:AllPassesAndDialectsNoRegistration",
"@llvm-project//mlir:IR",
],
)
@@ -29,8 +29,6 @@ limitations under the License.
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "mlir/IR/Module.h" // from @llvm-project
#include "mlir/InitAllDialects.h" // from @llvm-project
#include "tensorflow/compiler/xla/protobuf_util.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/all_reduce_combiner.h"
@@ -518,22 +516,15 @@ static Status CompileModuleToLlvmIrImpl(
DumpHloModuleIfEnabled(*hlo_module, **buffer_assignment,
"after_optimizations");

mlir::registerAllDialects();
mlir::MLIRContext mlir_context;

IrEmitterContext ir_emitter_context(
hlo_module, buffer_assignment->get(), platform_name, gpu_device_info,
cuda_compute_capability, profile_index_map, &mlir_context,
llvm_module->get());
cuda_compute_capability, profile_index_map, llvm_module->get());

HloComputation* entry_computation = hlo_module->entry_computation();
IrEmitterUnnested ir_emitter(hlo_module->config(), entry_computation,
&ir_emitter_context);

TF_ASSIGN_OR_RETURN(
auto ir_emitter,
IrEmitterUnnested::Create(hlo_module->config(), entry_computation,
&ir_emitter_context));

TF_RETURN_IF_ERROR(ir_emitter->EmitConstantGlobals());
TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals());

{
XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission");
@@ -542,10 +533,9 @@ static Status CompileModuleToLlvmIrImpl(
ThunkSequence thunk_sequence;
absl::Span<HloInstruction* const> order = hlo_schedule->ThunkLaunchOrder();
for (HloInstruction* instruction : order) {
TF_RETURN_IF_ERROR(instruction->Visit(ir_emitter.get()));
TF_RETURN_IF_ERROR(ir_emitter->Postprocess(instruction));
std::unique_ptr<ThunkSequence> thunks =
ir_emitter->ConsumeThunkSequence();
TF_RETURN_IF_ERROR(instruction->Visit(&ir_emitter));
TF_RETURN_IF_ERROR(ir_emitter.Postprocess(instruction));
std::unique_ptr<ThunkSequence> thunks = ir_emitter.ConsumeThunkSequence();

// The invariants between each input HloInstruction* and output Thunk* are
// not all explicitly checked, but at least we can document them here:
@@ -117,11 +117,11 @@ static bool HasMeaningfulName(llvm::Value* value) {
return false;
}

llvm::Value* CastToTypedValue(const Shape& shape, llvm::Value* ir_value,
llvm::IRBuilder<>* b) {
llvm::Type* pointee_type =
llvm_ir::ShapeToIrType(shape, b->GetInsertBlock()->getModule());

llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
ShapeIndexView shape_index,
llvm::Value* ir_value) {
llvm::Type* pointee_type = llvm_ir::ShapeToIrType(
ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_);
llvm::Type* dest_type = pointee_type->getPointerTo();

llvm::Value* typed_ir_value;
@@ -129,17 +129,9 @@ llvm::Value* CastToTypedValue(const Shape& shape, llvm::Value* ir_value,
typed_ir_value = llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
llvm::cast<llvm::GlobalVariable>(ir_value), dest_type);
} else {
typed_ir_value = b->CreatePointerBitCastOrAddrSpaceCast(
typed_ir_value = b_->CreatePointerBitCastOrAddrSpaceCast(
ir_value, pointee_type->getPointerTo());
}
return typed_ir_value;
}

llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
ShapeIndexView shape_index,
llvm::Value* ir_value) {
auto typed_ir_value = CastToTypedValue(
ShapeUtil::GetSubshape(hlo.shape(), shape_index), ir_value, b_);
if (!HasMeaningfulName(ir_value)) {
ir_value->setName(llvm_ir::IrName(&hlo, "raw"));
}

@@ -116,10 +116,6 @@ class HloToIrBindings {
llvm::Value* temp_buffer_base_ = nullptr;
};

// Converts `ir_value` with type i8* to a typed LLVM Value* based on `shape`.
llvm::Value* CastToTypedValue(const Shape& shape, llvm::Value* ir_value,
llvm::IRBuilder<>* b);

} // namespace gpu
} // namespace xla
@@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_CONTEXT_H_

#include "llvm/IR/Module.h"
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
@@ -35,15 +34,13 @@ class IrEmitterContext {
const HloModule* hlo_module, const BufferAssignment* buffer_assignment,
std::string platform_name, GpuDeviceInfo gpu_device_info,
absl::optional<CudaComputeCapability> cuda_compute_capability,
const HloProfileIndexMap* profile_index_map,
mlir::MLIRContext* mlir_context, llvm::Module* llvm_module)
const HloProfileIndexMap* profile_index_map, llvm::Module* llvm_module)
: hlo_module_(hlo_module),
buffer_assignment_(buffer_assignment),
platform_name_(std::move(platform_name)),
gpu_device_info_(gpu_device_info),
cuda_compute_capability_(cuda_compute_capability),
profile_index_map_(profile_index_map),
mlir_context_(mlir_context),
llvm_module_(llvm_module) {}
// Disallow copy and assign.
IrEmitterContext(const IrEmitterContext&) = delete;
@@ -60,7 +57,6 @@ class IrEmitterContext {
return cuda_compute_capability_;
}
const HloProfileIndexMap* profile_index_map() { return profile_index_map_; }
mlir::MLIRContext* mlir_context() { return mlir_context_; }
llvm::Module* llvm_module() { return llvm_module_; }
NameUniquer* name_uniquer() { return &name_uniquer_; }

@@ -71,7 +67,6 @@ class IrEmitterContext {
GpuDeviceInfo gpu_device_info_;
absl::optional<CudaComputeCapability> cuda_compute_capability_;
const HloProfileIndexMap* profile_index_map_;
mlir::MLIRContext* mlir_context_;
llvm::Module* llvm_module_;
NameUniquer name_uniquer_;
};
@ -37,13 +37,6 @@ limitations under the License.
|
||||
#include "llvm/IR/Instructions.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/Function.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
|
||||
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
|
||||
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/literal.h"
|
||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||
@ -151,86 +144,13 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk,
|
||||
llvm::ConstantAsMetadata::get(threads_per_block_ir_value)}));
|
||||
}
|
||||
|
||||
const BufferAllocation* GetAllocation(
|
||||
mlir::BlockArgument func_arg, const BufferAssignment& buffer_assignment) {
|
||||
auto func_op =
|
||||
mlir::cast<mlir::FuncOp>(func_arg.getParentRegion()->getParentOp());
|
||||
int64 allocation_index = func_op
|
||||
.getArgAttrOfType<mlir::IntegerAttr>(
|
||||
func_arg.getArgNumber(), "lmhlo.alloc")
|
||||
.getValue()
|
||||
.getSExtValue();
|
||||
return &buffer_assignment.GetAllocation(allocation_index);
|
||||
}
|
||||
|
||||
StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(
|
||||
mlir::Value v, const BufferAssignment& buffer_assignment) {
|
||||
int64 size = v.getType().cast<mlir::MemRefType>().getSizeInBits() / 8;
|
||||
|
||||
if (auto arg = v.dyn_cast<mlir::BlockArgument>()) {
|
||||
return BufferAllocation::Slice(GetAllocation(arg, buffer_assignment), 0,
|
||||
size);
|
||||
}
|
||||
|
||||
// We match two patterns here:
|
||||
// * v = ViewOp(arg);
|
||||
// * v = StaticMemRefCastOp(ViewOp(arg));
|
||||
if (mlir::Operation* op = v.getDefiningOp()) {
|
||||
if (auto cast = mlir::dyn_cast<mlir::lmhlo::StaticMemRefCastOp>(op)) {
|
||||
mlir::Value source = cast.getViewSource();
|
||||
op = source.getDefiningOp();
|
||||
if (!op) {
|
||||
return Unimplemented("StaticMemRefCastOp has to wrap an op");
|
||||
}
|
||||
}
|
||||
if (auto view = mlir::dyn_cast<mlir::ViewOp>(op)) {
|
||||
return BufferAllocation::Slice(
|
||||
GetAllocation(view.source().cast<mlir::BlockArgument>(),
|
||||
buffer_assignment),
|
||||
mlir::cast<mlir::ConstantOp>(view.byte_shift().getDefiningOp())
|
||||
.value()
|
||||
.cast<mlir::IntegerAttr>()
|
||||
.getValue()
|
||||
.getSExtValue(),
|
||||
size);
|
||||
}
|
||||
return Unimplemented("StaticMemRefCastOp has to wrap a ViewOp");
|
||||
}
|
||||
|
||||
return Unimplemented(
|
||||
"Operand has to be in the form of ViewOp(arg) or "
|
||||
"StaticMemRefCastOp(ViewOp(arg))");
|
||||
}
|
||||
|
||||
absl::string_view GetHloName(mlir::Operation* op) {
|
||||
if (auto attr = op->getAttrOfType<mlir::StringAttr>("name")) {
|
||||
auto ref = attr.getValue();
|
||||
return absl::string_view(ref.data(), ref.size());
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
|
||||
const HloComputation* hlo_computation,
|
||||
IrEmitterContext* ir_emitter_context)
|
||||
: IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false),
|
||||
hlo_computation_(hlo_computation),
|
||||
mlir_scratch_module_(mlir::ModuleOp::create(
|
||||
mlir::Builder(ir_emitter_context->mlir_context()).getUnknownLoc())),
|
||||
lhlo_scratch_emitter_(ir_emitter_context_->buffer_assignment(),
|
||||
*hlo_computation, mlir_scratch_module_.get()) {}
|
||||
|
||||
StatusOr<std::unique_ptr<IrEmitterUnnested>> IrEmitterUnnested::Create(
|
||||
const HloModuleConfig& hlo_module_config,
|
||||
const HloComputation* hlo_computation,
|
||||
IrEmitterContext* ir_emitter_context) {
|
||||
auto emitter = std::unique_ptr<IrEmitterUnnested>(new IrEmitterUnnested(
|
||||
hlo_module_config, hlo_computation, ir_emitter_context));
|
||||
TF_RETURN_IF_ERROR(emitter->lhlo_scratch_emitter_.Initialize());
|
||||
return std::move(emitter);
|
||||
}
|
||||
hlo_computation_(hlo_computation) {}
|
||||
|
||||
Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) {
|
||||
bindings_.UnbindAllLocalIrValues();
|
||||
@ -238,11 +158,12 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) {
|
||||
}
|
||||
|
||||
llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
|
||||
absl::string_view name, absl::Span<const BufferAllocation* const> args) {
|
||||
const HloInstruction& inst,
|
||||
absl::Span<const BufferAllocation* const> args) {
|
||||
// Compute the kernel name. The opcode string may contain "-" which cannot be
|
||||
// in a PTX function name, so sanitize the name before uniquifying it.
|
||||
string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName(
|
||||
llvm_ir::SanitizeFunctionName(std::string(name)));
|
||||
llvm_ir::SanitizeFunctionName(inst.name()));
|
||||
|
||||
// Create the kernel and add it to the module.
|
||||
llvm::Module* module = ir_emitter_context_->llvm_module();
|
||||
@ -438,8 +359,7 @@ Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
|
||||
TF_ASSIGN_OR_RETURN(auto thunk, BuildConditionalThunk(conditional));
|
||||
AddThunkToThunkSequence(std::move(thunk));
|
||||
AddThunkToThunkSequence(BuildConditionalThunk(conditional));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1118,13 +1038,10 @@ Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) {
|
||||
// Build ForThunk for conformant while loops, otherwise build WhileThunk.
|
||||
auto config = xla_while->backend_config<WhileLoopBackendConfig>();
|
||||
if (config.ok() && config.ValueOrDie().has_known_trip_count()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto thunk,
|
||||
AddThunkToThunkSequence(
|
||||
BuildForThunk(xla_while, config.ValueOrDie().known_trip_count().n()));
|
||||
AddThunkToThunkSequence(std::move(thunk));
|
||||
} else {
|
||||
TF_ASSIGN_OR_RETURN(auto thunk, BuildWhileThunk(xla_while));
|
||||
AddThunkToThunkSequence(std::move(thunk));
|
||||
AddThunkToThunkSequence(BuildWhileThunk(xla_while));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -1347,109 +1264,39 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
|
||||
return IrEmitter::HandleSelect(select);
|
||||
}
|
||||
|
||||
StatusOr<const HloComputation*>
|
||||
IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region) {
|
||||
std::unique_ptr<HloModule>& module = scratch_nested_computations_[region];
|
||||
if (module == nullptr) {
|
||||
xla::XlaComputation xla_computation;
|
||||
TF_RETURN_IF_ERROR(ConvertRegionToComputation(region, &xla_computation));
|
||||
TF_ASSIGN_OR_RETURN(auto program_shape, xla_computation.GetProgramShape());
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
module, HloModule::CreateFromProto(xla_computation.proto(),
|
||||
HloModuleConfig(program_shape)));
|
||||
}
|
||||
return module->entry_computation();
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
|
||||
MlirEmitterInput result;
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto sort_op, lhlo_scratch_emitter_.EmitSortOp(sort));
|
||||
result.op = sort_op;
|
||||
result.name = GetHloName(sort_op);
|
||||
// The name in sort op has no semantics, and it's for debug only. If the name
|
||||
// doesn't exist, we should use a namer (e.g. count-based).
|
||||
// TODO(timshen): use a namer instead of relying on the HloInstruction names.
|
||||
if (result.name.empty()) {
|
||||
result.name = sort->name();
|
||||
}
|
||||
const auto& buffer_assignment = ir_emitter_context_->buffer_assignment();
|
||||
auto& slice = result.extra_slice;
|
||||
TF_ASSIGN_OR_RETURN(slice.buffer_slice,
|
||||
buffer_assignment.GetUniqueSlice(sort, {}));
|
||||
slice.written = true;
|
||||
slice.shape = sort->shape();
|
||||
|
||||
result.thunk_info = GetThunkInfo(sort);
|
||||
|
||||
return EmitMlirSort(result);
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) {
|
||||
const auto& buffer_assignment = ir_emitter_context_->buffer_assignment();
|
||||
auto sort_op = mlir::cast<mlir::lmhlo::SortOp>(input.op);
|
||||
|
||||
int operand_count = sort_op.operands().size();
|
||||
std::vector<xla::Shape> operand_shapes(operand_count);
|
||||
std::vector<MlirBufferSlice> slices;
|
||||
std::vector<xla::Shape> output_shapes(sort_op.output().size());
|
||||
|
||||
for (int i = 0; i < operand_count; i++) {
|
||||
operand_shapes[i] =
|
||||
TypeToShape(sort_op.operands()[i].getType().cast<mlir::MemRefType>());
|
||||
}
|
||||
|
||||
// Craft n + 1 slices, where the first n are output parameters, and the last
|
||||
// is the on-device tuple storage. We don't need n operands because sorting
|
||||
// kernels are always in-place.
|
||||
for (int i = 0; i < operand_count; i++) {
|
||||
output_shapes[i] =
|
||||
TypeToShape(sort_op.output()[i].getType().cast<mlir::MemRefType>());
|
||||
MlirBufferSlice slice;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
slice.buffer_slice,
|
||||
GetAllocationSliceForMlir(sort_op.output()[i], buffer_assignment));
|
||||
slice.written = true;
|
||||
slice.shape = operand_shapes[i];
|
||||
slices.push_back(slice);
|
||||
}
|
||||
slices.push_back(input.extra_slice);
|
||||
|
||||
std::vector<std::unique_ptr<Thunk>> thunks;
|
||||
|
||||
Shape keys_shape = operand_shapes[0];
|
||||
int64 dimension_to_sort = sort_op.dimension().getSExtValue();
|
||||
for (int64 i = 0; i < operand_count; ++i) {
|
||||
Shape keys_shape = sort->operand(0)->shape();
|
||||
int64 dimension_to_sort = sort->dimensions(0);
|
||||
for (int64 i = 0; i < sort->operand_count(); ++i) {
|
||||
ShapeIndex shape_index =
|
||||
sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({});
|
||||
// We assume that the layout of all involved operands and outputs is the
|
||||
// same.
|
||||
TF_RET_CHECK(
|
||||
LayoutUtil::LayoutsInShapesEqual(keys_shape, operand_shapes[i]));
|
||||
TF_RET_CHECK(
|
||||
LayoutUtil::LayoutsInShapesEqual(keys_shape, output_shapes[i]));
|
||||
TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(keys_shape,
|
||||
sort->operand(i)->shape()));
|
||||
TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
|
||||
keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index)));
|
||||
|
||||
// If possible, we share buffers. If that is not possible, we need to copy
|
||||
// the values, because the emitter does the sorting in-place.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto destination_buffer,
|
||||
GetAllocationSliceForMlir(sort_op.output()[i], buffer_assignment));
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto source_address,
|
||||
GetAllocationSliceForMlir(sort_op.operands()[i], buffer_assignment));
|
||||
auto destination_buffer = GetAllocationSlice(*sort, shape_index);
|
||||
auto source_address = GetAllocationSlice(*sort->operand(i));
|
||||
if (destination_buffer != source_address) {
|
||||
// TODO(b/26783907): Figure out why we never seem to share buffers for
|
||||
// key/value sort.
|
||||
VLOG(2) << input.name << " requires initial D2D copy for operand " << i;
|
||||
VLOG(2) << sort->name() << " requires initial D2D copy for operand " << i;
|
||||
thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
|
||||
Thunk::ThunkInfo(),
|
||||
/*source_address=*/source_address,
|
||||
/*destination_buffer=*/destination_buffer,
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(operand_shapes[i])));
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape())));
|
||||
}
|
||||
}
|
||||
|
||||
uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
|
||||
int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound);
|
||||
VLOG(2) << input.name << " requires " << num_stages << " stages.";
|
||||
VLOG(2) << sort->name() << " requires " << num_stages << " stages.";
|
||||
CHECK_GE(1ULL << num_stages, dimension_to_sort_bound);
|
||||
CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound);
|
||||
|
||||
@ -1513,10 +1360,10 @@ Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) {
|
||||
// we have not enough threads, or not enough shared memory. Also it does not
|
||||
// give a speedup if the tile size is < 128.
|
||||
int64 total_shared_memory_needed = 0;
|
||||
for (int64 i = 0; i < operand_count; ++i) {
|
||||
for (int64 i = 0; i < sort->operand_count(); ++i) {
|
||||
total_shared_memory_needed +=
|
||||
kTileSize *
|
||||
ShapeUtil::ByteSizeOfPrimitiveType(operand_shapes[i].element_type());
|
||||
kTileSize * ShapeUtil::ByteSizeOfPrimitiveType(
|
||||
sort->operand(i)->shape().element_type());
|
||||
}
|
||||
bool no_tiling =
|
||||
kTileSize < 128 ||
|
||||
@ -1529,7 +1376,7 @@ Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) {
|
||||
"kTileSize=%d < 128, "
|
||||
"kThreadsPerBlock=%d > threads_per_block_limit=%d, "
|
||||
"total_shared_memory_needed=%d > shared_memory_per_block=%d",
|
||||
input.name, (no_tiling ? "won't" : "will"), kTileSize, kThreadsPerBlock,
|
||||
sort->name(), (no_tiling ? "won't" : "will"), kTileSize, kThreadsPerBlock,
|
||||
ir_emitter_context_->gpu_device_info().threads_per_block_limit,
|
||||
total_shared_memory_needed,
|
||||
ir_emitter_context_->gpu_device_info().shared_memory_per_block);
|
||||
@ -1537,38 +1384,37 @@ Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) {
|
||||
uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock);
|
||||
LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock);
|
||||
VLOG(2) << absl::StreamFormat("%s launch dims: %d blocks, %d threads/block",
|
||||
input.name, num_blocks, kThreadsPerBlock);
|
||||
sort->name(), num_blocks, kThreadsPerBlock);
|
||||
|
||||
std::vector<llvm_ir::IrArray> ir_arrays;
|
||||
auto emit_kernel = [&](absl::Span<const int64> xor_masks) {
|
||||
VLOG(2) << absl::StreamFormat(
|
||||
"%s uses kernel for xor masks [%s]", input.name,
|
||||
"%s uses kernel for xor masks [%s]", sort->name(),
|
||||
absl::StrJoin(xor_masks, ", ", [](std::string* out, int64 xor_mask) {
|
||||
absl::StrAppendFormat(out, "0x%x", xor_mask);
|
||||
}));
|
||||
thunks.push_back(BuildKernelThunkForMlir(input.name, Thunk::ThunkInfo(),
|
||||
slices, &ir_arrays));
|
||||
thunks.push_back(
|
||||
BuildKernelThunk(sort, /*implements_whole_instruction=*/false));
|
||||
LaunchDimensions launch_dimensions = xor_masks.size() > 1
|
||||
? tiled_launch_dimensions
|
||||
: standard_launch_dimensions;
|
||||
UpdateLaunchDimensions(launch_dimensions, thunks.back().get(),
|
||||
ir_emitter_context_->llvm_module());
|
||||
std::vector<IrArray> values_arrays;
|
||||
values_arrays.reserve(operand_count);
|
||||
for (int64 i = 0; i < operand_count; ++i) {
|
||||
values_arrays.push_back(ir_arrays[i]);
|
||||
values_arrays.reserve(sort->operand_count());
|
||||
for (int64 i = 0; i < sort->operand_count(); ++i) {
|
||||
ShapeIndex shape_index =
|
||||
sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({});
|
||||
values_arrays.push_back(GetIrArray(*sort, *sort, shape_index));
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
const HloComputation* comparator,
|
||||
GetOrCreateSubComputationFromRegion(&sort_op.comparator()));
|
||||
return llvm_ir::EmitSortInPlace(
|
||||
dimension_to_sort, values_arrays, IrName(input.name), xor_masks, &b_,
|
||||
dimension_to_sort, values_arrays, IrName(sort), xor_masks, &b_,
|
||||
launch_dimensions,
|
||||
xor_masks.size() > 1 ? num_iterations_in_sort_dim
|
||||
: standard_num_iterations_in_sort_dim,
|
||||
kTileSize,
|
||||
[&](absl::Span<llvm::Value* const> operands, llvm::Value* output) {
|
||||
return EmitCallToNestedComputation(*comparator, operands, output);
|
||||
return EmitCallToNestedComputation(*sort->to_apply(), operands,
|
||||
output);
|
||||
});
|
||||
};
|
||||
std::vector<int64> xor_masks;
|
||||
@ -1595,18 +1441,17 @@ Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) {
|
||||
TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
|
||||
}
|
||||
VLOG(2) << absl::StreamFormat(
|
||||
"%s requires %d thunks (including any D2D copies)", input.name,
|
||||
"%s requires %d thunks (including any D2D copies)", sort->name(),
|
||||
thunks.size());
|
||||
|
||||
AddThunkToThunkSequence(
|
||||
absl::make_unique<SequentialThunk>(input.thunk_info, std::move(thunks)));
|
||||
if (operand_count > 1) {
|
||||
AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
|
||||
GetThunkInfo(sort), std::move(thunks)));
|
||||
if (sort->operand_count() > 1) {
|
||||
// Emit the tuple as part of the last stage of sorting.
|
||||
// We are currently in the block sorted.in_bounds.after.
|
||||
b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
|
||||
llvm_ir::EmitTuple(
|
||||
ir_arrays[operand_count],
|
||||
absl::MakeSpan(ir_arrays).subspan(0, ir_arrays.size() - 1), &b_);
|
||||
llvm_ir::EmitTuple(GetIrArray(*sort, *sort),
|
||||
ConstructIrArrayForOutputs(*sort), &b_);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -1744,6 +1589,24 @@ Status IrEmitterUnnested::HandleAfterAll(HloInstruction* after_all) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Describes how to access a particular subshape for an HLO. For instance if
|
||||
// `.hlo_index` is {1} and `.gte_index` is {3, 4} then buffer for `.instr` at
|
||||
// ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo) is found
|
||||
// at `.buffer_slice`[3][4]. That is, `.slice` is a void***, which we
|
||||
// dereference twice -- first at index 3, and then at index 4 -- to get the
|
||||
// address of our buffer.
|
||||
struct HloBufferSlice {
|
||||
const HloInstruction* instr;
|
||||
ShapeIndex hlo_index;
|
||||
|
||||
// The root buffer to look at.
|
||||
BufferAllocation::Slice buffer_slice;
|
||||
|
||||
// Describes how to dereference starting at that buffer to get to the buffer
|
||||
// in question.
|
||||
ShapeIndex gte_index;
|
||||
};
|
||||
|
||||
// Figures out how to access the buffers for all subshapes of hlo's operands and
|
||||
// for hlo itself (i.e. all the buffers produced by HLO).
|
||||
//
|
||||
@ -1852,22 +1715,22 @@ static std::vector<HloBufferSlice> GetHloBufferSlices(
|
||||
return result;
|
||||
}
|
||||
|
||||
std::unique_ptr<KernelThunk>
|
||||
IrEmitterUnnested::BuildKernelThunkFromBufferSlices(
|
||||
absl::string_view name, Thunk::ThunkInfo thunk_info,
|
||||
absl::Span<const BufferSlice* const> slices,
|
||||
std::function<void(const BufferSlice*, llvm::Value*)>
|
||||
bind_slice_to_ir_value) {
|
||||
const auto& buffer_assn = ir_emitter_context_->buffer_assignment();
|
||||
std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
|
||||
const HloInstruction* inst, bool implements_whole_instruction) {
|
||||
const BufferAssignment& buffer_assn =
|
||||
ir_emitter_context_->buffer_assignment();
|
||||
|
||||
std::vector<HloBufferSlice> hlo_slices =
|
||||
GetHloBufferSlices(inst, buffer_assn);
|
||||
|
||||
// Figure out which buffer allocations need to be passed as arguments to our
|
||||
// kernel. This is simply all of the allocations referenced in slices,
|
||||
// kernel. This is simply all of the allocations referenced in hlo_slices,
|
||||
// plus the XLA temp buffer (if we have it). We always include the temp
|
||||
// buffer because even if the kernel itself doesn't use it, a nested
|
||||
// subcomputation within the kernel (e.g. a kMap's computation) might.
|
||||
std::unordered_set<const BufferAllocation*> buffers_needed;
|
||||
for (auto* slice : slices) {
|
||||
buffers_needed.insert(slice->buffer_slice.allocation());
|
||||
for (const auto& hlo_buffer_slice : hlo_slices) {
|
||||
buffers_needed.insert(hlo_buffer_slice.buffer_slice.allocation());
|
||||
}
|
||||
absl::optional<const BufferAllocation*> temp_buffer;
|
||||
for (const BufferAllocation& alloc : buffer_assn.Allocations()) {
|
||||
@ -1896,7 +1759,7 @@ IrEmitterUnnested::BuildKernelThunkFromBufferSlices(
|
||||
return a->index() < b->index();
|
||||
});
|
||||
|
||||
llvm::Function* kernel = BuildKernelPrototype(name, non_constant_buffers);
|
||||
llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers);
|
||||
|
||||
// Build a map from a BufferAllocation to the corresponding argument in our
|
||||
// kernel.
|
||||
@ -1930,19 +1793,24 @@ IrEmitterUnnested::BuildKernelThunkFromBufferSlices(
|
||||
|
||||
// For each buffer our kernel might want to touch, bind it to a value derived
|
||||
// from our kernel args.
|
||||
for (auto* slice : slices) {
|
||||
const BufferAllocation::Slice& buffer_slice = slice->buffer_slice;
|
||||
const ShapeIndex& gte_index = slice->gte_index;
|
||||
for (const auto& hlo_buffer_slice : hlo_slices) {
|
||||
const HloInstruction* instr = hlo_buffer_slice.instr;
|
||||
const ShapeIndex& index = hlo_buffer_slice.hlo_index;
|
||||
const BufferAllocation::Slice& slice = hlo_buffer_slice.buffer_slice;
|
||||
const ShapeIndex& gte_index = hlo_buffer_slice.gte_index;
|
||||
|
||||
VLOG(3) << "Buffer for " << instr->ToString() << " at " << index.ToString()
|
||||
<< " is found in slice " << slice.ToString() << " at GTE index "
|
||||
<< gte_index.ToString();
|
||||
|
||||
llvm::Value* loc;
|
||||
if (buffer_slice.allocation()->is_constant()) {
|
||||
if (slice.allocation()->is_constant()) {
|
||||
loc = ir_emitter_context_->llvm_module()->getGlobalVariable(
|
||||
llvm_ir::ConstantBufferAllocationToGlobalName(
|
||||
*buffer_slice.allocation()));
|
||||
llvm_ir::ConstantBufferAllocationToGlobalName(*slice.allocation()));
|
||||
CHECK_NE(loc, nullptr);
|
||||
} else {
|
||||
loc = InBoundsGEP(kernel_args.at(buffer_slice.allocation()),
|
||||
{b_.getInt64(buffer_slice.offset())});
|
||||
loc = InBoundsGEP(kernel_args.at(slice.allocation()),
|
||||
{b_.getInt64(slice.offset())});
|
||||
}
|
||||
|
||||
// If gte_index is nonempty, we have to dereference `loc` to get to the
|
||||
@ -1954,7 +1822,7 @@ IrEmitterUnnested::BuildKernelThunkFromBufferSlices(
|
||||
loc = Load(InBoundsGEP(loc, {b_.getInt64(idx)}));
|
||||
}
|
||||
|
||||
bind_slice_to_ir_value(slice, loc);
|
||||
bindings_.BindHloToIrValue(*instr, loc, index);
|
||||
}
|
||||
|
||||
// Bind the temp buffer so that nested subcomputations can find it if they
|
||||
@ -1966,66 +1834,9 @@ IrEmitterUnnested::BuildKernelThunkFromBufferSlices(
|
||||
llvm::ConstantPointerNull::get(b_.getInt8PtrTy()));
|
||||
}
|
||||
|
||||
return absl::make_unique<KernelThunk>(thunk_info, non_constant_buffers,
|
||||
std::string(kernel->getName()));
|
||||
}
|
||||
|
||||
std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
|
||||
const HloInstruction* inst, bool implements_whole_instruction) {
|
||||
std::vector<HloBufferSlice> hlo_slices =
|
||||
GetHloBufferSlices(inst, ir_emitter_context_->buffer_assignment());
|
||||
|
||||
std::vector<BufferSlice*> slice_ptrs;
|
||||
slice_ptrs.reserve(hlo_slices.size());
|
||||
for (auto& slice : hlo_slices) {
|
||||
slice_ptrs.push_back(&slice);
|
||||
}
|
||||
|
||||
return BuildKernelThunkFromBufferSlices(
|
||||
inst->name(),
|
||||
return absl::make_unique<KernelThunk>(
|
||||
implements_whole_instruction ? GetThunkInfo(inst) : Thunk::ThunkInfo(),
|
||||
slice_ptrs, [this](const BufferSlice* slice, llvm::Value* value) {
|
||||
const HloBufferSlice* hlo_buffer_slice =
|
||||
static_cast<const HloBufferSlice*>(slice);
|
||||
const HloInstruction* instr = hlo_buffer_slice->instr;
|
||||
const ShapeIndex& index = hlo_buffer_slice->hlo_index;
|
||||
VLOG(3) << "Buffer for " << instr->ToString() << " at "
|
||||
<< index.ToString() << " is found in slice "
|
||||
<< hlo_buffer_slice->buffer_slice.ToString() << " at GTE index "
|
||||
<< hlo_buffer_slice->gte_index.ToString();
|
||||
|
||||
bindings_.BindHloToIrValue(*instr, value, index);
|
||||
});
|
||||
}
|
||||
|
||||
std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunkForMlir(
|
||||
absl::string_view name, Thunk::ThunkInfo thunk_info,
|
||||
absl::Span<const MlirBufferSlice> slices,
|
||||
std::vector<llvm_ir::IrArray>* ir_arrays) {
|
||||
absl::flat_hash_set<BufferAllocation::Slice> buffers_written;
|
||||
std::vector<const BufferSlice*> slice_ptrs;
|
||||
slice_ptrs.reserve(slices.size());
|
||||
for (auto& slice : slices) {
|
||||
slice_ptrs.push_back(&slice);
|
||||
if (slice.written) {
|
||||
buffers_written.insert(slice.buffer_slice);
|
||||
}
|
||||
}
|
||||
|
||||
ir_arrays->clear();
|
||||
return BuildKernelThunkFromBufferSlices(
|
||||
name, thunk_info, slice_ptrs,
|
||||
[&](const BufferSlice* slice, llvm::Value* value) {
|
||||
const auto& mlir_slice = static_cast<const MlirBufferSlice&>(*slice);
|
||||
|
||||
llvm_ir::IrArray ir_array(
|
||||
CastToTypedValue(mlir_slice.shape, value, &b_), mlir_slice.shape);
|
||||
if (!buffers_written.contains(slice->buffer_slice)) {
|
||||
ir_array.MarkInvariantOverWholeProgram(&value->getContext());
|
||||
}
|
||||
|
||||
ir_arrays->push_back(ir_array);
|
||||
});
|
||||
non_constant_buffers, std::string(kernel->getName()));
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
|
||||
@ -2232,7 +2043,7 @@ Status CheckConditionalBuffersShareAllocation(
|
||||
|
||||
} // namespace
|
||||
|
||||
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildWhileThunk(
|
||||
std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk(
|
||||
const HloInstruction* hlo) {
|
||||
// Check that all while-related buffers share an allocation.
|
||||
TF_CHECK_OK(CheckWhileBuffersShareAllocation(
|
||||
@ -2240,26 +2051,24 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildWhileThunk(
|
||||
|
||||
// Generate thunk sequence for while 'condition'.
|
||||
HloComputation* condition = hlo->while_condition();
|
||||
TF_ASSIGN_OR_RETURN(auto ir_emitter_condition,
|
||||
IrEmitterUnnested::Create(hlo_module_config_, condition,
|
||||
ir_emitter_context_));
|
||||
TF_RETURN_IF_ERROR(condition->Accept(ir_emitter_condition.get()));
|
||||
IrEmitterUnnested ir_emitter_condition(hlo_module_config_, condition,
|
||||
ir_emitter_context_);
|
||||
TF_CHECK_OK(condition->Accept(&ir_emitter_condition));
|
||||
|
||||
// Generate thunk sequence for while 'body'.
|
||||
HloComputation* body = hlo->while_body();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto ir_emitter_body,
|
||||
IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_));
|
||||
TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get()));
|
||||
IrEmitterUnnested ir_emitter_body(hlo_module_config_, body,
|
||||
ir_emitter_context_);
|
||||
TF_CHECK_OK(body->Accept(&ir_emitter_body));
|
||||
|
||||
return std::unique_ptr<Thunk>(new WhileThunk(
|
||||
return absl::make_unique<WhileThunk>(
|
||||
GetThunkInfo(hlo),
|
||||
GetAllocationSlice(*condition->root_instruction()), // cond result
|
||||
ir_emitter_condition->ConsumeThunkSequence(),
|
||||
ir_emitter_body->ConsumeThunkSequence()));
|
||||
ir_emitter_condition.ConsumeThunkSequence(),
|
||||
ir_emitter_body.ConsumeThunkSequence());
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildForThunk(
|
||||
std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
|
||||
const HloInstruction* hlo, const int64 loop_limit) {
|
||||
// Check that all while-related buffers share an allocation.
|
||||
TF_CHECK_OK(CheckWhileBuffersShareAllocation(
|
||||
@ -2267,16 +2076,15 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildForThunk(
|
||||
|
||||
// Generate thunk sequence for while 'body' (will be used a For loop body).
|
||||
HloComputation* body = hlo->while_body();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto ir_emitter_body,
|
||||
IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_));
|
||||
TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get()));
|
||||
IrEmitterUnnested ir_emitter_body(hlo_module_config_, body,
|
||||
ir_emitter_context_);
|
||||
TF_CHECK_OK(body->Accept(&ir_emitter_body));
|
||||
|
||||
return std::unique_ptr<Thunk>(new ForThunk(
|
||||
GetThunkInfo(hlo), loop_limit, ir_emitter_body->ConsumeThunkSequence()));
|
||||
return absl::make_unique<ForThunk>(GetThunkInfo(hlo), loop_limit,
|
||||
ir_emitter_body.ConsumeThunkSequence());
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildConditionalThunk(
|
||||
std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
|
||||
const HloInstruction* hlo) {
|
||||
// Check that the buffers used in conditional are shared with the operands and
|
||||
// result appropriately.
|
||||
@ -2288,17 +2096,15 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildConditionalThunk(
|
||||
for (int j = 0; j < hlo->branch_count(); ++j) {
|
||||
branch_operands.emplace_back(GetAllocationSlice(*hlo->operand(j + 1)));
|
||||
HloComputation* branch_computation = hlo->branch_computation(j);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto ir_emitter,
|
||||
IrEmitterUnnested::Create(hlo_module_config_, branch_computation,
|
||||
ir_emitter_context_));
|
||||
TF_CHECK_OK(branch_computation->Accept(ir_emitter.get()));
|
||||
branch_thunks.push_back(std::move(*ir_emitter->ConsumeThunkSequence()));
|
||||
IrEmitterUnnested ir_emitter(hlo_module_config_, branch_computation,
|
||||
ir_emitter_context_);
|
||||
TF_CHECK_OK(branch_computation->Accept(&ir_emitter));
|
||||
branch_thunks.push_back(std::move(*ir_emitter.ConsumeThunkSequence()));
|
||||
}
|
||||
|
||||
return std::unique_ptr<Thunk>(new ConditionalThunk(
|
||||
return absl::make_unique<ConditionalThunk>(
|
||||
GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)), branch_operands,
|
||||
std::move(branch_thunks)));
|
||||
std::move(branch_thunks));
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
|
||||
|
@ -17,7 +17,6 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_
|
||||
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
|
||||
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
|
||||
@ -29,40 +28,6 @@ limitations under the License.
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
||||
struct BufferSlice {
|
||||
// The root buffer to look at.
|
||||
BufferAllocation::Slice buffer_slice;
|
||||
|
||||
// Describes how to dereference starting at that buffer to get to the buffer
|
||||
// in question.
|
||||
ShapeIndex gte_index;
|
||||
};
|
||||
|
||||
// Describes how to access a particular subshape for an HLO. For instance if
|
||||
// `.hlo_index` is {1} and `.gte_index` is {3, 4} then buffer for `.instr` at
|
||||
// ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo) is
|
||||
// found at `.buffer_slice`[3][4]. That is, `.slice` is a void***, which we
|
||||
// dereference twice -- first at index 3, and then at index 4 -- to get the
|
||||
// address of our buffer.
|
||||
struct HloBufferSlice : public BufferSlice {
|
||||
const HloInstruction* instr;
|
||||
ShapeIndex hlo_index;
|
||||
};
|
||||
|
||||
struct MlirBufferSlice : public BufferSlice {
|
||||
// The buffer is modified by the kernel.
|
||||
bool written;
|
||||
|
||||
Shape shape;
|
||||
};
|
||||
|
||||
struct MlirEmitterInput {
|
||||
mlir::Operation* op;
|
||||
absl::string_view name;
|
||||
Thunk::ThunkInfo thunk_info;
|
||||
MlirBufferSlice extra_slice;
|
||||
};
|
||||
|
||||
// Emits LLVM IR for an "unnested computation".
|
||||
//
|
||||
// An unnested computation is an HloComputation which you run by executing one
|
||||
@ -124,14 +89,12 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
const string& loop_name, llvm::Value* tile_height,
|
||||
llvm::Value* tile_width, KernelSupportLibrary* ksl)>;
|
||||
|
||||
IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
|
||||
const HloComputation* hlo_computation,
|
||||
IrEmitterContext* ir_emitter_context);
|
||||
IrEmitterUnnested(const IrEmitterUnnested&) = delete;
|
||||
IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete;
|
||||
|
||||
static StatusOr<std::unique_ptr<IrEmitterUnnested>> Create(
|
||||
const HloModuleConfig& hlo_module_config,
|
||||
const HloComputation* hlo_computation,
|
||||
IrEmitterContext* ir_emitter_context);
|
||||
|
||||
// Transfers the ownship of thunk_sequence_ out.
|
||||
std::unique_ptr<ThunkSequence> ConsumeThunkSequence() {
|
||||
return std::make_unique<ThunkSequence>(std::move(thunk_sequence_));
|
||||
@ -161,7 +124,6 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
Status HandleScatter(HloInstruction* scatter) override;
|
||||
Status HandleSelect(HloInstruction* select) override;
|
||||
Status HandleSort(HloInstruction* sort) override;
|
||||
Status EmitMlirSort(MlirEmitterInput input);
|
||||
Status HandleTriangularSolve(HloInstruction* hlo) override;
|
||||
Status HandleTupleSelect(HloInstruction* tuple_select) override;
|
||||
Status HandleAllReduce(HloInstruction* crs) override;
|
||||
@ -186,10 +148,6 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
Status Postprocess(HloInstruction* hlo) override;
|
||||
|
||||
private:
|
||||
IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
|
||||
const HloComputation* hlo_computation,
|
||||
IrEmitterContext* ir_emitter_context);
|
||||
|
||||
// Add a owning Thunk object to the thunk sequence.
|
||||
void AddThunkToThunkSequence(std::unique_ptr<Thunk> thunk) override {
|
||||
thunk_sequence_.emplace_back(std::move(thunk));
|
||||
@ -306,7 +264,8 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
// Builds the prototype of the IR kernel for `inst` and adds it to the module.
|
||||
// This kernel takes as arguments pointers to the given buffer allocations.
|
||||
llvm::Function* BuildKernelPrototype(
|
||||
absl::string_view name, absl::Span<const BufferAllocation* const> args);
|
||||
const HloInstruction& inst,
|
||||
absl::Span<const BufferAllocation* const> args);
|
||||
|
||||
// Helper for writing extra outputs from inside a reduce kernel.
|
||||
Status EmitExtraOutputsForReduce(
|
||||
@ -531,12 +490,6 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
HloComputation* reducer, llvm::Type* element_type,
|
||||
llvm::Value* partial_result_address);
|
||||
|
||||
std::unique_ptr<KernelThunk> BuildKernelThunkFromBufferSlices(
|
||||
absl::string_view name, Thunk::ThunkInfo thunk_info,
|
||||
absl::Span<const BufferSlice* const> slices,
|
||||
std::function<void(const BufferSlice*, llvm::Value*)>
|
||||
bind_slice_to_ir_value);
|
||||
|
||||
// Returns a KernelThunk that invokes the kernel emitted for `inst`. The
|
||||
// caller needs to make sure `inst` outlives the lifetime of the returned
|
||||
// Thunk object. 'implements_whole_instruction' specifies whether this
|
||||
@ -545,11 +498,6 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
std::unique_ptr<KernelThunk> BuildKernelThunk(
|
||||
const HloInstruction* inst, bool implements_whole_instruction);
|
||||
|
||||
std::unique_ptr<KernelThunk> BuildKernelThunkForMlir(
|
||||
absl::string_view name, Thunk::ThunkInfo thunk_info,
|
||||
absl::Span<const MlirBufferSlice> slices,
|
||||
std::vector<llvm_ir::IrArray>* ir_arrays);
|
||||
|
||||
// Returns a thunk that, given a reduce or select-and-scatter op,
|
||||
// initializes its memory to the appropriate initial value.
|
||||
StatusOr<std::unique_ptr<Thunk>> BuildInitializerThunk(
|
||||
@ -557,18 +505,17 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
|
||||
// Returns a WhileThunk that invokes thunk sequences for 'condition' and
|
||||
// 'body' sub-computations of while instruction 'hlo'.
|
||||
StatusOr<std::unique_ptr<Thunk>> BuildWhileThunk(const HloInstruction* hlo);
|
||||
std::unique_ptr<Thunk> BuildWhileThunk(const HloInstruction* hlo);
|
||||
|
||||
// Returns a ForThunk which executes 'loop_limit' invocations of a thunk
|
||||
// sequence from the 'body' sub-computation of the while instruction 'hlo'.
|
||||
StatusOr<std::unique_ptr<Thunk>> BuildForThunk(const HloInstruction* hlo,
|
||||
const int64 loop_limit);
|
||||
std::unique_ptr<Thunk> BuildForThunk(const HloInstruction* hlo,
|
||||
const int64 loop_limit);
|
||||
|
||||
// Returns a ConditionalThunk which executes the thunk sequence for the
|
||||
// 'branch_computation' corresponding to the predicate/branch_index of the
|
||||
// given conditional instruction.
|
||||
StatusOr<std::unique_ptr<Thunk>> BuildConditionalThunk(
|
||||
const HloInstruction* hlo);
|
||||
std::unique_ptr<Thunk> BuildConditionalThunk(const HloInstruction* hlo);
|
||||
|
||||
// Emits current thread id with the given type.
|
||||
//
|
||||
@ -598,9 +545,6 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
absl::optional<int64> thread_id_filter = absl::nullopt,
|
||||
absl::optional<int64> block_id_filter = absl::nullopt);
|
||||
|
||||
StatusOr<const HloComputation*> GetOrCreateSubComputationFromRegion(
|
||||
mlir::Region* region);
|
||||
|
||||
// Returns the last generated thunk.
|
||||
Thunk* LastThunk() const { return thunk_sequence_.back().get(); }
|
||||
|
||||
@ -611,14 +555,6 @@ class IrEmitterUnnested : public IrEmitter,
|
||||
|
||||
// The HloComputation that this IrEmitter emits code for.
|
||||
const HloComputation* hlo_computation_;
|
||||
|
||||
mlir::OwningModuleRef mlir_scratch_module_;
|
||||
|
||||
// This is for cache-purpose only. It has no significant semantics.
|
||||
mlir::LhloDialectEmitter lhlo_scratch_emitter_;
|
||||
|
||||
absl::flat_hash_map<const mlir::Region*, std::unique_ptr<HloModule>>
|
||||
scratch_nested_computations_;
|
||||
};
|
||||
|
||||
} // namespace gpu
|
||||
|
@@ -458,35 +458,6 @@ xla_test(
],
)

tf_cc_test(
name = "sorting_test",
srcs = [
"sorting_test.cc",
],
tags = tf_cuda_tests_tags() + [
"no_rocm",
],
deps = [
":gpu_codegen_test",
"//tensorflow/compiler/xla:debug_options_flags",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_proto_cc",
"//tensorflow/compiler/xla/service:gpu_plugin",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/service/gpu:gpu_executable",
"//tensorflow/compiler/xla/tests:filecheck",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:llvm_irgen_test_base",
"//tensorflow/core:lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/stream_executor/lib",
"@com_google_absl//absl/memory",
],
)

tf_cc_binary(
name = "hlo_to_llvm_ir",
srcs = ["hlo_to_llvm_ir.cc"],
@ -8,162 +8,162 @@ compare {
|
||||
ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
|
||||
}
|
||||
|
||||
// CHECK: define void @sort(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]])
|
||||
// CHECK: define void @sort(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC1:%.*]])
|
||||
// CHECK-NEXT: entry:
|
||||
// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0
|
||||
// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0
|
||||
// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
|
||||
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP4]] to i64
|
||||
// CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
|
||||
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP5]] to i64
|
||||
// CHECK-NEXT: [[TMP6:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
|
||||
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP6]], [[THREAD_ID]]
|
||||
// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0
|
||||
// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1]], i64 0
|
||||
// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
|
||||
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64
|
||||
// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
|
||||
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64
|
||||
// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
|
||||
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]]
|
||||
// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
|
||||
// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]])
|
||||
// CHECK-NEXT: [[TMP7:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
|
||||
// CHECK-NEXT: [[TMP8:%.*]] = urem i64 [[TMP7]], 2
|
||||
// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
|
||||
// CHECK-NEXT: [[TMP10:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
|
||||
// CHECK-NEXT: br i1 [[TMP10]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
|
||||
// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
|
||||
// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2
|
||||
// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
|
||||
// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
|
||||
// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
|
||||
// CHECK: sort.in_bounds-after:
|
||||
// CHECK-NEXT: ret void
|
||||
// CHECK: sort.in_bounds-true:
|
||||
// CHECK-NEXT: [[TMP11:%.*]] = mul i64 [[TMP8]], 2
|
||||
// CHECK-NEXT: [[TMP12:%.*]] = xor i64 [[TMP11]], 1
|
||||
// CHECK-NEXT: [[TMP13:%.*]] = icmp slt i64 [[TMP11]], [[TMP12]]
|
||||
// CHECK-NEXT: [[TMP14:%.*]] = icmp slt i64 [[TMP12]], 3
|
||||
// CHECK-NEXT: [[TMP15:%.*]] = and i1 [[TMP13]], [[TMP14]]
|
||||
// CHECK-NEXT: br i1 [[TMP15]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
|
||||
// CHECK-NEXT: [[TMP7:%.*]] = mul i64 [[TMP4]], 2
|
||||
// CHECK-NEXT: [[TMP8:%.*]] = xor i64 [[TMP7]], 1
|
||||
// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], [[TMP8]]
|
||||
// CHECK-NEXT: [[TMP10:%.*]] = icmp slt i64 [[TMP8]], 3
|
||||
// CHECK-NEXT: [[TMP11:%.*]] = and i1 [[TMP9]], [[TMP10]]
|
||||
// CHECK-NEXT: br i1 [[TMP11]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
|
||||
// CHECK: smaller_comparison_index-after:
|
||||
// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]]
|
||||
// CHECK: smaller_comparison_index-true:
|
||||
// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP12]]
|
||||
// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]]
|
||||
// CHECK-NEXT: call void @region_0_4(float* [[TMP16]], float* [[TMP17]], i8* [[COMPARE_RETURN_BUFFER]])
|
||||
// CHECK-NEXT: [[TMP18:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
|
||||
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP18]], 0
|
||||
// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP8]]
|
||||
// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
|
||||
// CHECK-NEXT: call void @compare(float* [[TMP12]], float* [[TMP13]], i8* [[COMPARE_RETURN_BUFFER]])
|
||||
// CHECK-NEXT: [[TMP14:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
|
||||
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP14]], 0
|
||||
// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]]
|
||||
// CHECK: is_smaller_than-after:
|
||||
// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]]
|
||||
// CHECK: is_smaller_than-true:
|
||||
// CHECK-NEXT: [[TMP19:%.*]] = load float, float* [[TMP16]], align 4
|
||||
// CHECK-NEXT: [[TMP20:%.*]] = load float, float* [[TMP17]], align 4
|
||||
// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]]
|
||||
// CHECK-NEXT: store float [[TMP19]], float* [[TMP21]], align 4
|
||||
// CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP12]]
|
||||
// CHECK-NEXT: store float [[TMP20]], float* [[TMP22]], align 4
|
||||
// CHECK-NEXT: [[TMP15:%.*]] = load float, float* [[TMP12]], align 4
|
||||
// CHECK-NEXT: [[TMP16:%.*]] = load float, float* [[TMP13]], align 4
|
||||
// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
|
||||
// CHECK-NEXT: store float [[TMP15]], float* [[TMP17]], align 4
|
||||
// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP8]]
|
||||
// CHECK-NEXT: store float [[TMP16]], float* [[TMP18]], align 4
|
||||
// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]]
|
||||
|
||||
// CHECK: define internal void @region_0_4(float* dereferenceable(4) [[P_0_LHS_TYPED:%.*]], float* dereferenceable(4) [[P_0_RHS_TYPED:%.*]], i8* dereferenceable(1) [[OUTPUT_ARG:%.*]])
|
||||
// CHECK: define internal void @compare(float* dereferenceable(4) [[P_0_LHS_TYPED:%.*]], float* dereferenceable(4) [[P_0_RHS_TYPED:%.*]], i8* dereferenceable(1) [[OUTPUT_ARG:%.*]])
|
||||
// CHECK-NEXT: entry:
|
||||
// CHECK-NEXT: [[COMPARE_3_TYPED:%.*]] = alloca i8, align 1
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = load float, float* [[ARG_0_1_TYPED:%.*]], align 4
|
||||
// CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[ARG_1_2_TYPED:%.*]], align 4
|
||||
// CHECK-NEXT: [[LT_TYPED:%.*]] = alloca i8, align 1
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = load float, float* [[P_0_LHS_TYPED]], align 4
|
||||
// CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[P_0_RHS_TYPED]], align 4
|
||||
// CHECK-NEXT: [[TMP2:%.*]] = fcmp olt float [[TMP0]], [[TMP1]]
|
||||
// CHECK-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i8
|
||||
// CHECK-NEXT: store i8 [[TMP3]], i8* [[COMPARE_3_TYPED]], align 1
|
||||
// CHECK-NEXT: [[LOAD_RET_VALUE:%.*]] = load i8, i8* [[COMPARE_3_TYPED]], align 1
|
||||
// CHECK-NEXT: store i8 [[LOAD_RET_VALUE]], i8* [[OUTPUT_ARG:%.*]], align 1
|
||||
// CHECK-NEXT: store i8 [[TMP3]], i8* [[LT_TYPED]], align 1
|
||||
// CHECK-NEXT: [[LOAD_RET_VALUE:%.*]] = load i8, i8* [[LT_TYPED]], align 1
|
||||
// CHECK-NEXT: store i8 [[LOAD_RET_VALUE]], i8* [[OUTPUT_ARG]], align 1
|
||||
// CHECK-NEXT: ret void
|
||||
|
||||
// CHECK: define void @sort__1(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]]) {
|
||||
// CHECK: define void @sort__1(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC1:%.*]]) {
|
||||
// CHECK-NEXT: entry:
|
||||
// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0
|
||||
// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0
|
||||
// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
|
||||
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP4]] to i64
|
||||
// CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
|
||||
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP5]] to i64
|
||||
// CHECK-NEXT: [[TMP6:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
|
||||
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP6]], [[THREAD_ID]]
|
||||
// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0
|
||||
// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1]], i64 0
|
||||
// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
|
||||
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64
|
||||
// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
|
||||
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64
|
||||
// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
|
||||
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]]
|
||||
// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
|
||||
// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]])
|
||||
// CHECK-NEXT: [[TMP7:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
|
||||
// CHECK-NEXT: [[TMP8:%.*]] = urem i64 [[TMP7]], 2
|
||||
// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
|
||||
// CHECK-NEXT: [[TMP10:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
|
||||
// CHECK-NEXT: br i1 [[TMP10]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
|
||||
// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
|
||||
// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2
|
||||
// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
|
||||
// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
|
||||
// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
|
||||
// CHECK: sort.in_bounds-after:
|
||||
// CHECK-NEXT: ret void
|
||||
// CHECK: sort.in_bounds-true:
|
||||
// CHECK-NEXT: [[TMP11:%.*]] = xor i64 [[TMP8]], 3
|
||||
// CHECK-NEXT: [[TMP12:%.*]] = icmp slt i64 [[TMP8]], [[TMP11]]
|
||||
// CHECK-NEXT: [[TMP13:%.*]] = icmp slt i64 [[TMP11]], 3
|
||||
// CHECK-NEXT: [[TMP14:%.*]] = and i1 [[TMP12]], [[TMP13]]
|
||||
// CHECK-NEXT: br i1 [[TMP14]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
|
||||
// CHECK-NEXT: [[TMP7:%.*]] = xor i64 [[TMP4]], 3
|
||||
// CHECK-NEXT: [[TMP8:%.*]] = icmp slt i64 [[TMP4]], [[TMP7]]
|
||||
// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], 3
|
||||
// CHECK-NEXT: [[TMP10:%.*]] = and i1 [[TMP8]], [[TMP9]]
|
||||
// CHECK-NEXT: br i1 [[TMP10]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
|
||||
// CHECK: smaller_comparison_index-after:
|
||||
// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]]
|
||||
// CHECK: smaller_comparison_index-true:
|
||||
// CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]]
|
||||
// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP8]]
|
||||
// CHECK-NEXT: call void @region_0_4(float* [[TMP15]], float* [[TMP16]], i8* [[COMPARE_RETURN_BUFFER]])
|
||||
// CHECK-NEXT: [[TMP17:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
|
||||
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP17]], 0
|
||||
// CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
|
||||
// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP4]]
|
||||
// CHECK-NEXT: call void @compare(float* [[TMP11]], float* [[TMP12]], i8* [[COMPARE_RETURN_BUFFER]])
|
||||
// CHECK-NEXT: [[TMP13:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
|
||||
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP13]], 0
|
||||
// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]]
|
||||
// CHECK: is_smaller_than-after:
|
||||
// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]]
|
||||
// CHECK: is_smaller_than-true:
|
||||
// CHECK-NEXT: [[TMP18:%.*]] = load float, float* [[TMP15]], align 4
|
||||
// CHECK-NEXT: [[TMP19:%.*]] = load float, float* [[TMP16]], align 4
|
||||
// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP8]]
|
||||
// CHECK-NEXT: store float [[TMP18]], float* [[TMP20]], align 4
|
||||
// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]]
|
||||
// CHECK-NEXT: store float [[TMP19]], float* [[TMP21]], align 4
|
||||
// CHECK-NEXT: [[TMP14:%.*]] = load float, float* [[TMP11]], align 4
|
||||
// CHECK-NEXT: [[TMP15:%.*]] = load float, float* [[TMP12]], align 4
|
||||
// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP4]]
|
||||
// CHECK-NEXT: store float [[TMP14]], float* [[TMP16]], align 4
|
||||
// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
|
||||
// CHECK-NEXT: store float [[TMP15]], float* [[TMP17]], align 4
|
||||
// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]]
|
||||
|
||||
// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]]) {
|
||||
// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC1:%.*]]) {
|
||||
// CHECK-NEXT: entry:
|
||||
// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0
|
||||
// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0
|
||||
// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
|
||||
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP4]] to i64
|
||||
// CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
|
||||
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP5]] to i64
|
||||
// CHECK-NEXT: [[TMP6:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
|
||||
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP6]], [[THREAD_ID]]
|
||||
// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0
|
||||
// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0
|
||||
// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
|
||||
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64
|
||||
// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
|
||||
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64
|
||||
// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
|
||||
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]]
|
||||
// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
|
||||
// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]])
|
||||
// CHECK-NEXT: [[TMP7:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
|
||||
// CHECK-NEXT: [[TMP8:%.*]] = urem i64 [[TMP7]], 2
|
||||
// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
|
||||
// CHECK-NEXT: [[TMP10:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
|
||||
// CHECK-NEXT: br i1 [[TMP10]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
|
||||
// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
|
||||
// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2
|
||||
// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
|
||||
// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
|
||||
// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
|
||||
// CHECK: sort.in_bounds-after:
|
||||
// CHECK-NEXT: ret void
|
||||
// CHECK: sort.in_bounds-true:
|
||||
// CHECK-NEXT: [[TMP11:%.*]] = mul i64 [[TMP8]], 2
|
||||
// CHECK-NEXT: [[TMP12:%.*]] = xor i64 [[TMP11]], 1
|
||||
// CHECK-NEXT: [[TMP13:%.*]] = icmp slt i64 [[TMP11]], [[TMP12]]
|
||||
// CHECK-NEXT: [[TMP14:%.*]] = icmp slt i64 [[TMP12]], 3
|
||||
// CHECK-NEXT: [[TMP15:%.*]] = and i1 [[TMP13]], [[TMP14]]
|
||||
// CHECK-NEXT: br i1 [[TMP15]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
|
||||
// CHECK-NEXT: [[TMP7:%.*]] = mul i64 [[TMP4]], 2
|
||||
// CHECK-NEXT: [[TMP8:%.*]] = xor i64 [[TMP7]], 1
|
||||
// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], [[TMP8]]
|
||||
// CHECK-NEXT: [[TMP10:%.*]] = icmp slt i64 [[TMP8]], 3
|
||||
// CHECK-NEXT: [[TMP11:%.*]] = and i1 [[TMP9]], [[TMP10]]
|
||||
// CHECK-NEXT: br i1 [[TMP11]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
|
||||
// CHECK: smaller_comparison_index-after:
|
||||
// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]]
|
||||
// CHECK: smaller_comparison_index-true:
|
||||
// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP12]]
|
||||
// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]]
|
||||
// CHECK-NEXT: call void @region_0_4(float* [[TMP16]], float* [[TMP17]], i8* [[COMPARE_RETURN_BUFFER]])
|
||||
// CHECK-NEXT: [[TMP18:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
|
||||
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP18]], 0
|
||||
// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP8]]
|
||||
// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
|
||||
// CHECK-NEXT: call void @compare(float* [[TMP12]], float* [[TMP13]], i8* [[COMPARE_RETURN_BUFFER]])
|
||||
// CHECK-NEXT: [[TMP14:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
|
||||
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP14]], 0
|
||||
// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]]
|
||||
// CHECK: is_smaller_than-after:
|
||||
// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]]
|
||||
// CHECK: is_smaller_than-true:
|
||||
// CHECK-NEXT: [[TMP19:%.*]] = load float, float* [[TMP16]], align 4
|
||||
// CHECK-NEXT: [[TMP20:%.*]] = load float, float* [[TMP17]], align 4
|
||||
// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]]
|
||||
// CHECK-NEXT: store float [[TMP19]], float* [[TMP21]], align 4
|
||||
// CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP12]]
|
||||
// CHECK-NEXT: store float [[TMP20]], float* [[TMP22]], align 4
|
||||
// CHECK-NEXT: [[TMP15:%.*]] = load float, float* [[TMP12]], align 4
|
||||
// CHECK-NEXT: [[TMP16:%.*]] = load float, float* [[TMP13]], align 4
|
||||
// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
|
||||
// CHECK-NEXT: store float [[TMP15]], float* [[TMP17]], align 4
|
||||
// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP8]]
|
||||
// CHECK-NEXT: store float [[TMP16]], float* [[TMP18]], align 4
|
||||
// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]]
|
||||
ENTRY main {
|
||||
x = f32[2, 3] parameter(0)
|
||||
@ -182,198 +182,210 @@ compare {
|
||||
ROOT lt = pred[] compare(p.1.lhs, p.1.rhs), direction=LT
|
||||
}
|
||||
|
||||
// CHECK: define void @sort(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]])
|
||||
// CHECK: define void @sort(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC2:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC3:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]])
|
||||
// CHECK-NEXT: entry:
|
||||
// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0
|
||||
// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x i32]]*
|
||||
// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0
|
||||
// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0
|
||||
// CHECK-NEXT: [[TMP5:%.*]] = bitcast i8* [[TMP4]] to [2 x i8*]*
|
||||
// CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
|
||||
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP6]] to i64
|
||||
// CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
|
||||
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP7]] to i64
|
||||
// CHECK-NEXT: [[TMP8:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
|
||||
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP8]], [[THREAD_ID]]
|
||||
// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4]], i64 0
|
||||
// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x i8*]*
|
||||
// CHECK-NEXT: [[SORT_RAW1:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0
|
||||
// CHECK-NEXT: [[SORT_TYPED2:%.*]] = bitcast i8* [[SORT_RAW1]] to [2 x [3 x i32]]*
|
||||
// CHECK-NEXT: [[SORT_RAW3:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1]], i64 0
|
||||
// CHECK-NEXT: [[SORT_TYPED4:%.*]] = bitcast i8* [[SORT_RAW3]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC2]], i64 0
|
||||
// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x i32]]*
|
||||
// CHECK-NEXT: [[Y_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC3]], i64 0
|
||||
// CHECK-NEXT: [[Y_TYPED:%.*]] = bitcast i8* [[Y_RAW]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
|
||||
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64
|
||||
// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
|
||||
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64
|
||||
// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
|
||||
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]]
|
||||
// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
|
||||
// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]])
|
||||
// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
|
||||
// CHECK-NEXT: [[TMP10:%.*]] = urem i64 [[TMP9]], 2
|
||||
// CHECK-NEXT: [[TMP11:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
|
||||
// CHECK-NEXT: [[TMP12:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
|
||||
// CHECK-NEXT: br i1 [[TMP12]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
|
||||
// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
|
||||
// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2
|
||||
// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
|
||||
// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
|
||||
// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
|
||||
// CHECK: sort.in_bounds-after:
|
||||
// CHECK-NEXT: ret void
|
||||
// CHECK: sort.in_bounds-true:
|
||||
// CHECK-NEXT: [[TMP13:%.*]] = mul i64 [[TMP10]], 2
|
||||
// CHECK-NEXT: [[TMP14:%.*]] = xor i64 [[TMP13]], 1
|
||||
// CHECK-NEXT: [[TMP15:%.*]] = icmp slt i64 [[TMP13]], [[TMP14]]
|
||||
// CHECK-NEXT: [[TMP16:%.*]] = icmp slt i64 [[TMP14]], 3
|
||||
// CHECK-NEXT: [[TMP17:%.*]] = and i1 [[TMP15]], [[TMP16]]
|
||||
// CHECK-NEXT: br i1 [[TMP17]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
|
||||
// CHECK-NEXT: [[TMP7:%.*]] = mul i64 [[TMP4]], 2
|
||||
// CHECK-NEXT: [[TMP8:%.*]] = xor i64 [[TMP7]], 1
|
||||
// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], [[TMP8]]
|
||||
// CHECK-NEXT: [[TMP10:%.*]] = icmp slt i64 [[TMP8]], 3
|
||||
// CHECK-NEXT: [[TMP11:%.*]] = and i1 [[TMP9]], [[TMP10]]
|
||||
// CHECK-NEXT: br i1 [[TMP11]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
|
||||
// CHECK: smaller_comparison_index-after:
|
||||
// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]]
|
||||
// CHECK: smaller_comparison_index-true:
|
||||
// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP14]]
|
||||
// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP13]]
|
||||
// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP14]]
|
||||
// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP13]]
|
||||
// CHECK-NEXT: call void @region_0_6(i32* [[TMP18]], i32* [[TMP19]], float* [[TMP20]], float* [[TMP21]], i8* [[COMPARE_RETURN_BUFFER]])
|
||||
// CHECK-NEXT: [[TMP22:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
|
||||
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP22]], 0
|
||||
// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP8]]
|
||||
// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
|
||||
// CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP8]]
|
||||
// CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
|
||||
// CHECK-NEXT: call void @compare(i32* [[TMP12]], i32* [[TMP13]], float* [[TMP14]], float* [[TMP15]], i8* [[COMPARE_RETURN_BUFFER]])
|
||||
// CHECK-NEXT: [[TMP16:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
|
||||
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP16]], 0
|
||||
// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]]
|
||||
// CHECK: is_smaller_than-after:
|
||||
// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]]
|
||||
// CHECK: is_smaller_than-true:
|
||||
// CHECK-NEXT: [[TMP23:%.*]] = load i32, i32* [[TMP18]], align 4
|
||||
// CHECK-NEXT: [[TMP24:%.*]] = load i32, i32* [[TMP19]], align 4
|
||||
// CHECK-NEXT: [[TMP25:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP13]]
|
||||
// CHECK-NEXT: store i32 [[TMP23]], i32* [[TMP25]], align 4
|
||||
// CHECK-NEXT: [[TMP26:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP14]]
|
||||
// CHECK-NEXT: store i32 [[TMP24]], i32* [[TMP26]], align 4
|
||||
// CHECK-NEXT: [[TMP27:%.*]] = load float, float* [[TMP20]], align 4
|
||||
// CHECK-NEXT: [[TMP28:%.*]] = load float, float* [[TMP21]], align 4
|
||||
// CHECK-NEXT: [[TMP29:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP13]]
|
||||
// CHECK-NEXT: store float [[TMP27]], float* [[TMP29]], align 4
|
||||
// CHECK-NEXT: [[TMP30:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP14]]
|
||||
// CHECK-NEXT: store float [[TMP28]], float* [[TMP30]], align 4
|
||||
// CHECK-NEXT: [[TMP17:%.*]] = load i32, i32* [[TMP12]], align 4
|
||||
// CHECK-NEXT: [[TMP18:%.*]] = load i32, i32* [[TMP13]], align 4
|
||||
// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
|
||||
// CHECK-NEXT: store i32 [[TMP17]], i32* [[TMP19]], align 4
|
||||
// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP8]]
|
||||
// CHECK-NEXT: store i32 [[TMP18]], i32* [[TMP20]], align 4
|
||||
// CHECK-NEXT: [[TMP21:%.*]] = load float, float* [[TMP14]], align 4
|
||||
// CHECK-NEXT: [[TMP22:%.*]] = load float, float* [[TMP15]], align 4
|
||||
// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
|
||||
// CHECK-NEXT: store float [[TMP21]], float* [[TMP23]], align 4
|
||||
// CHECK-NEXT: [[TMP24:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP8]]
|
||||
// CHECK-NEXT: store float [[TMP22]], float* [[TMP24]], align 4
|
||||
// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]]
|
||||
|
||||
// CHECK: define internal void @region_0_6(i32* dereferenceable(4) [[P_0_LHS_TYPED:%.*]], i32* dereferenceable(4) [[P_0_RHS_TYPED:%.*]], float* dereferenceable(4) [[P_1_LHS_TYPED:%.*]], float* dereferenceable(4) [[P_1_RHS_TYPED:%.*]], i8* dereferenceable(1) [[OUTPUT_ARG:%.*]])
|
||||
// CHECK: define internal void @compare(i32* dereferenceable(4) [[P_0_LHS_TYPED:%.*]], i32* dereferenceable(4) [[P_0_RHS_TYPED:%.*]], float* dereferenceable(4) [[P_1_LHS_TYPED:%.*]], float* dereferenceable(4) [[P_1_RHS_TYPED:%.*]], i8* dereferenceable(1) [[OUTPUT_ARG:%.*]])
|
||||
// CHECK-NEXT: entry:
|
||||
// CHECK-NEXT: [[COMPARE_5_TYPED:%.*]] = alloca i8, align 1
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = load float, float* [[ARG_2_3_TYPED:%.*]], align 4
|
||||
// CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[ARG_3_4_TYPED:%.*]], align 4
|
||||
// CHECK-NEXT: [[LT_TYPED:%.*]] = alloca i8, align 1
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = load float, float* [[P_1_LHS_TYPED]], align 4
|
||||
// CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[P_1_RHS_TYPED]], align 4
|
||||
// CHECK-NEXT: [[TMP2:%.*]] = fcmp olt float [[TMP0]], [[TMP1]]
|
||||
// CHECK-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i8
|
||||
// CHECK-NEXT: store i8 [[TMP3]], i8* [[COMPARE_5_TYPED]], align 1
|
||||
// CHECK-NEXT: [[LOAD_RET_VALUE:%.*]] = load i8, i8* [[COMPARE_5_TYPED]], align 1
|
||||
// CHECK-NEXT: store i8 [[LOAD_RET_VALUE]], i8* [[OUTPUT_ARG:%.*]], align 1
|
||||
// CHECK-NEXT: store i8 [[TMP3]], i8* [[LT_TYPED]], align 1
|
||||
// CHECK-NEXT: [[LOAD_RET_VALUE:%.*]] = load i8, i8* [[LT_TYPED]], align 1
|
||||
// CHECK-NEXT: store i8 [[LOAD_RET_VALUE]], i8* [[OUTPUT_ARG]], align 1
|
||||
// CHECK-NEXT: ret void
|
||||
|
||||
// CHECK: define void @sort__1(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]])
|
||||
// CHECK: define void @sort__1(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC2:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC3:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]])
|
||||
// CHECK-NEXT: entry:
|
||||
// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0
|
||||
// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x i32]]*
|
||||
// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0
|
||||
// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0
|
||||
// CHECK-NEXT: [[TMP5:%.*]] = bitcast i8* [[TMP4]] to [2 x i8*]*
|
||||
// CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
|
||||
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP6]] to i64
|
||||
// CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
|
||||
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP7]] to i64
|
||||
// CHECK-NEXT: [[TMP8:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
|
||||
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP8]], [[THREAD_ID]]
|
||||
// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0
|
||||
// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x i8*]*
|
||||
// CHECK-NEXT: [[SORT_RAW1:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0
|
||||
// CHECK-NEXT: [[SORT_TYPED2:%.*]] = bitcast i8* [[SORT_RAW1]] to [2 x [3 x i32]]*
|
||||
// CHECK-NEXT: [[SORT_RAW3:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0
|
||||
// CHECK-NEXT: [[SORT_TYPED4:%.*]] = bitcast i8* [[SORT_RAW3]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC2:%.*]], i64 0
|
||||
// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x i32]]*
|
||||
// CHECK-NEXT: [[Y_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC3:%.*]], i64 0
|
||||
// CHECK-NEXT: [[Y_TYPED:%.*]] = bitcast i8* [[Y_RAW]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
|
||||
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64
|
||||
// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
|
||||
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64
|
||||
// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
|
||||
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]]
|
||||
// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
|
||||
// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]])
|
||||
// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
|
||||
// CHECK-NEXT: [[TMP10:%.*]] = urem i64 [[TMP9]], 2
|
||||
// CHECK-NEXT: [[TMP11:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
|
||||
// CHECK-NEXT: [[TMP12:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
|
||||
// CHECK-NEXT: br i1 [[TMP12]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
|
||||
// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
|
||||
// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2
|
||||
// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
|
||||
// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
|
||||
// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
|
||||
// CHECK: sort.in_bounds-after:
|
||||
// CHECK-NEXT: ret void
|
||||
// CHECK: sort.in_bounds-true:
|
||||
// CHECK-NEXT: [[TMP13:%.*]] = xor i64 [[TMP10]], 3
|
||||
// CHECK-NEXT: [[TMP14:%.*]] = icmp slt i64 [[TMP10]], [[TMP13]]
|
||||
// CHECK-NEXT: [[TMP15:%.*]] = icmp slt i64 [[TMP13]], 3
|
||||
// CHECK-NEXT: [[TMP16:%.*]] = and i1 [[TMP14]], [[TMP15]]
|
||||
// CHECK-NEXT: br i1 [[TMP16]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
|
||||
// CHECK-NEXT: [[TMP7:%.*]] = xor i64 [[TMP4]], 3
|
||||
// CHECK-NEXT: [[TMP8:%.*]] = icmp slt i64 [[TMP4]], [[TMP7]]
|
||||
// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], 3
|
||||
// CHECK-NEXT: [[TMP10:%.*]] = and i1 [[TMP8]], [[TMP9]]
|
||||
// CHECK-NEXT: br i1 [[TMP10]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
|
||||
// CHECK: smaller_comparison_index-after:
|
||||
// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]]
|
||||
// CHECK: smaller_comparison_index-true:
|
||||
// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP13]]
|
||||
// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP10]]
|
||||
// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP13]]
|
||||
// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP10]]
|
||||
// CHECK-NEXT: call void @region_0_6(i32* [[TMP17]], i32* [[TMP18]], float* [[TMP19]], float* [[TMP20]], i8* [[COMPARE_RETURN_BUFFER]])
|
||||
// CHECK-NEXT: [[TMP21:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
|
||||
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP21]], 0
|
||||
// CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
|
||||
// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP4]]
|
||||
// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
|
||||
// CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP4]]
|
||||
// CHECK-NEXT: call void @compare(i32* [[TMP11]], i32* [[TMP12]], float* [[TMP13]], float* [[TMP14]], i8* [[COMPARE_RETURN_BUFFER]])
|
||||
// CHECK-NEXT: [[TMP15:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
|
||||
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP15]], 0
|
||||
// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]]
|
||||
// CHECK: is_smaller_than-after:
|
||||
// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]]
|
||||
// CHECK: is_smaller_than-true:
|
||||
// CHECK-NEXT: [[TMP22:%.*]] = load i32, i32* [[TMP17]], align 4
|
||||
// CHECK-NEXT: [[TMP23:%.*]] = load i32, i32* [[TMP18]], align 4
|
||||
// CHECK-NEXT: [[TMP24:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP10]]
|
||||
// CHECK-NEXT: store i32 [[TMP22]], i32* [[TMP24]], align 4
|
||||
// CHECK-NEXT: [[TMP25:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP13]]
|
||||
// CHECK-NEXT: store i32 [[TMP23]], i32* [[TMP25]], align 4
|
||||
// CHECK-NEXT: [[TMP26:%.*]] = load float, float* [[TMP19]], align 4
|
||||
// CHECK-NEXT: [[TMP27:%.*]] = load float, float* [[TMP20]], align 4
|
||||
// CHECK-NEXT: [[TMP28:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP10]]
|
||||
// CHECK-NEXT: store float [[TMP26]], float* [[TMP28]], align 4
|
||||
// CHECK-NEXT: [[TMP29:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP13]]
|
||||
// CHECK-NEXT: store float [[TMP27]], float* [[TMP29]], align 4
|
||||
// CHECK-NEXT: [[TMP16:%.*]] = load i32, i32* [[TMP11]], align 4
|
||||
// CHECK-NEXT: [[TMP17:%.*]] = load i32, i32* [[TMP12]], align 4
|
||||
// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP4]]
|
||||
// CHECK-NEXT: store i32 [[TMP16]], i32* [[TMP18]], align 4
|
||||
// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
|
||||
// CHECK-NEXT: store i32 [[TMP17]], i32* [[TMP19]], align 4
|
||||
// CHECK-NEXT: [[TMP20:%.*]] = load float, float* [[TMP13]], align 4
|
||||
// CHECK-NEXT: [[TMP21:%.*]] = load float, float* [[TMP14]], align 4
|
||||
// CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP4]]
|
||||
// CHECK-NEXT: store float [[TMP20]], float* [[TMP22]], align 4
|
||||
// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP7]]
|
||||
// CHECK-NEXT: store float [[TMP21]], float* [[TMP23]], align 4
|
||||
// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]]
|
||||
|
||||
// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]])
|
||||
// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC2:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC3:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]])
|
||||
// CHECK-NEXT: entry:
|
||||
// CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0
|
||||
// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x i32]]*
|
||||
// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0
|
||||
// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0
|
||||
// CHECK-NEXT: [[TMP5:%.*]] = bitcast i8* [[TMP4]] to [2 x i8*]*
|
||||
// CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
|
||||
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP6]] to i64
|
||||
// CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
|
||||
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP7]] to i64
|
||||
// CHECK-NEXT: [[TMP8:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
|
||||
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP8]], [[THREAD_ID]]
|
||||
// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0
|
||||
// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x i8*]*
|
||||
// CHECK-NEXT: [[SORT_RAW1:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0
|
||||
// CHECK-NEXT: [[SORT_TYPED2:%.*]] = bitcast i8* [[SORT_RAW1]] to [2 x [3 x i32]]*
|
||||
// CHECK-NEXT: [[SORT_RAW3:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0
|
||||
// CHECK-NEXT: [[SORT_TYPED4:%.*]] = bitcast i8* [[SORT_RAW3]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC2:%.*]], i64 0
|
||||
// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x i32]]*
|
||||
// CHECK-NEXT: [[Y_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC3:%.*]], i64 0
|
||||
// CHECK-NEXT: [[Y_TYPED:%.*]] = bitcast i8* [[Y_RAW]] to [2 x [3 x float]]*
|
||||
// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6
|
||||
// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64
|
||||
// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7
|
||||
// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64
|
||||
// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4
|
||||
// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]]
|
||||
// CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
|
||||
// CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]])
|
||||
// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
|
||||
// CHECK-NEXT: [[TMP10:%.*]] = urem i64 [[TMP9]], 2
|
||||
// CHECK-NEXT: [[TMP11:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
|
||||
// CHECK-NEXT: [[TMP12:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
|
||||
// CHECK-NEXT: br i1 [[TMP12]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
|
||||
// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1
|
||||
// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2
|
||||
// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2
|
||||
// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4
|
||||
// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]]
|
||||
// CHECK: sort.in_bounds-after:
|
||||
// CHECK-NEXT: [[TMP13:%.*]] = bitcast [2 x [3 x i32]]* [[TMP1]] to i8*
|
||||
// CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[TMP5]], i64 0, i64 0
|
||||
// CHECK-NEXT: store i8* [[TMP13]], i8** [[TMP14]], align 8
|
||||
// CHECK-NEXT: [[TMP15:%.*]] = bitcast [2 x [3 x float]]* [[TMP3]] to i8*
|
||||
// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[TMP5]], i64 0, i64 1
|
||||
// CHECK-NEXT: store i8* [[TMP15]], i8** [[TMP16]], align 8
|
||||
// CHECK-NEXT: [[TMP7:%.*]] = bitcast [2 x [3 x i32]]* [[SORT_TYPED2]] to i8*
|
||||
// CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[SORT_TYPED]], i64 0, i64 0
|
||||
// CHECK-NEXT: store i8* [[TMP7]], i8** [[TMP8]], align 8
|
||||
// CHECK-NEXT: [[TMP9:%.*]] = bitcast [2 x [3 x float]]* [[SORT_TYPED4]] to i8*
|
||||
// CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[SORT_TYPED]], i64 0, i64 1
|
||||
// CHECK-NEXT: store i8* [[TMP9]], i8** [[TMP10]], align 8
|
||||
// CHECK-NEXT: ret void
|
||||
// CHECK: sort.in_bounds-true:
|
||||
// CHECK-NEXT: [[TMP17:%.*]] = mul i64 [[TMP10]], 2
|
||||
// CHECK-NEXT: [[TMP18:%.*]] = xor i64 [[TMP17]], 1
|
||||
// CHECK-NEXT: [[TMP19:%.*]] = icmp slt i64 [[TMP17]], [[TMP18]]
|
||||
// CHECK-NEXT: [[TMP20:%.*]] = icmp slt i64 [[TMP18]], 3
|
||||
// CHECK-NEXT: [[TMP21:%.*]] = and i1 [[TMP19]], [[TMP20]]
|
||||
// CHECK-NEXT: br i1 [[TMP21]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
|
||||
// CHECK-NEXT: [[TMP11:%.*]] = mul i64 [[TMP4]], 2
|
||||
// CHECK-NEXT: [[TMP12:%.*]] = xor i64 [[TMP11]], 1
|
||||
// CHECK-NEXT: [[TMP13:%.*]] = icmp slt i64 [[TMP11]], [[TMP12]]
|
||||
// CHECK-NEXT: [[TMP14:%.*]] = icmp slt i64 [[TMP12]], 3
|
||||
// CHECK-NEXT: [[TMP15:%.*]] = and i1 [[TMP13]], [[TMP14]]
|
||||
// CHECK-NEXT: br i1 [[TMP15]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]]
|
||||
// CHECK: smaller_comparison_index-after:
|
||||
// CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]]
|
||||
// CHECK: smaller_comparison_index-true:
|
||||
// CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP18]]
|
||||
// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP17]]
|
||||
// CHECK-NEXT: [[TMP24:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP18]]
|
||||
// CHECK-NEXT: [[TMP25:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP17]]
|
||||
// CHECK-NEXT: call void @region_0_6(i32* [[TMP22]], i32* [[TMP23]], float* [[TMP24]], float* [[TMP25]], i8* [[COMPARE_RETURN_BUFFER]])
|
||||
// CHECK-NEXT: [[TMP26:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
|
||||
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP26]], 0
|
||||
// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP12]]
|
||||
// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP11]]
|
||||
// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP12]]
|
||||
// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP11]]
|
||||
// CHECK-NEXT: call void @compare(i32* [[TMP16]], i32* [[TMP17]], float* [[TMP18]], float* [[TMP19]], i8* [[COMPARE_RETURN_BUFFER]])
|
||||
// CHECK-NEXT: [[TMP20:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1
|
||||
// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP20]], 0
|
||||
// CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]]
|
||||
// CHECK: is_smaller_than-after:
|
||||
// CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]]
|
||||
// CHECK: is_smaller_than-true:
|
||||
// CHECK-NEXT: [[TMP27:%.*]] = load i32, i32* [[TMP22]], align 4
|
||||
// CHECK-NEXT: [[TMP28:%.*]] = load i32, i32* [[TMP23]], align 4
|
||||
// CHECK-NEXT: [[TMP29:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP17]]
|
||||
// CHECK-NEXT: store i32 [[TMP27]], i32* [[TMP29]], align 4
|
||||
// CHECK-NEXT: [[TMP30:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP18]]
|
||||
// CHECK-NEXT: store i32 [[TMP28]], i32* [[TMP30]], align 4
|
||||
// CHECK-NEXT: [[TMP31:%.*]] = load float, float* [[TMP24]], align 4
|
||||
// CHECK-NEXT: [[TMP32:%.*]] = load float, float* [[TMP25]], align 4
|
||||
// CHECK-NEXT: [[TMP33:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP17]]
|
||||
// CHECK-NEXT: store float [[TMP31]], float* [[TMP33]], align 4
|
||||
// CHECK-NEXT: [[TMP34:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP18]]
|
||||
// CHECK-NEXT: store float [[TMP32]], float* [[TMP34]], align 4
|
||||
// CHECK-NEXT: [[TMP21:%.*]] = load i32, i32* [[TMP16]], align 4
|
||||
// CHECK-NEXT: [[TMP22:%.*]] = load i32, i32* [[TMP17]], align 4
|
||||
// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP11]]
|
||||
// CHECK-NEXT: store i32 [[TMP21]], i32* [[TMP23]], align 4
|
||||
// CHECK-NEXT: [[TMP24:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP12]]
|
||||
// CHECK-NEXT: store i32 [[TMP22]], i32* [[TMP24]], align 4
|
||||
// CHECK-NEXT: [[TMP25:%.*]] = load float, float* [[TMP18]], align 4
|
||||
// CHECK-NEXT: [[TMP26:%.*]] = load float, float* [[TMP19]], align 4
|
||||
// CHECK-NEXT: [[TMP27:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP11]]
|
||||
// CHECK-NEXT: store float [[TMP25]], float* [[TMP27]], align 4
|
||||
// CHECK-NEXT: [[TMP28:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP12]]
|
||||
// CHECK-NEXT: store float [[TMP26]], float* [[TMP28]], align 4
|
||||
// CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]]
|
||||
ENTRY main {
|
||||
x = s32[2, 3] parameter(0)
@ -1,71 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <utility>

#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h"
#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/filecheck.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {
namespace gpu {

namespace {

class SortingTest : public GpuCodegenTest {
 protected:
  HloModuleConfig ConfigWithoutLayoutAssignment() {
    HloModuleConfig config;
    auto debug_options = HloTestBase::GetDebugOptionsForTest();
    // Disable layout_assignment to use the preassigned layouts.
    debug_options.add_xla_disable_hlo_passes("layout-assignment");
    config.set_debug_options(debug_options);
    return config;
  }
};

TEST_F(SortingTest, Regression1) {
  const char* hlo_text = R"(
HloModule TestModule

compare {
  p.0.lhs = f32[] parameter(0)
  p.0.rhs = f32[] parameter(1)
  ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
}

ENTRY TestComputation {
  x = f32[3, 2]{1, 0} parameter(0)
  x.copy = f32[3, 2]{0, 1} copy(x)
  ROOT sort = f32[3, 2]{0, 1} sort(x.copy), dimensions={1}, to_apply=compare
}

)";

  EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_text, ErrorSpec{1e-5, 1e-5}));
}

}  // namespace
}  // namespace gpu
}  // namespace xla
@ -415,10 +415,9 @@ llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper,
  return inst;
}

string IrName(absl::string_view a) {
  std::string s(a);
  s.erase(std::remove(s.begin(), s.end(), '%'), s.end());
  return s;
string IrName(string a) {
  a.erase(std::remove(a.begin(), a.end(), '%'), a.end());
  return a;
}

string IrName(absl::string_view a, absl::string_view b) {

@ -87,7 +87,7 @@ string DumpModuleToString(const llvm::Module& module);
// - joining all of the nonempty inputs by '.', and then
// - removing all '%'s.
//
string IrName(absl::string_view a);
string IrName(string a);
string IrName(absl::string_view a, absl::string_view b);
string IrName(const HloInstruction* a, absl::string_view b = "");