From 0572b205b847917062091e4377110f5431c60d2b Mon Sep 17 00:00:00 2001 From: Sanjoy Das Date: Tue, 11 Aug 2020 12:08:22 -0700 Subject: [PATCH] Roll back XLA/GPU LHLO sort emitter again It breaks an internal msan enabled test. PiperOrigin-RevId: 326072372 Change-Id: I245525cefa4da88097725662c75ccb213a328f19 --- tensorflow/compiler/mlir/xla/hlo_utils.cc | 3 - .../non_identity_layouts.hlotxt | 2 +- .../xla/transforms/mhlo_to_lhlo_with_xla.cc | 11 +- .../xla/transforms/mhlo_to_lhlo_with_xla.h | 3 +- tensorflow/compiler/xla/service/gpu/BUILD | 10 - .../compiler/xla/service/gpu/gpu_compiler.cc | 24 +- .../xla/service/gpu/hlo_to_ir_bindings.cc | 20 +- .../xla/service/gpu/hlo_to_ir_bindings.h | 4 - .../xla/service/gpu/ir_emitter_context.h | 7 +- .../xla/service/gpu/ir_emitter_unnested.cc | 416 ++++----------- .../xla/service/gpu/ir_emitter_unnested.h | 82 +-- .../compiler/xla/service/gpu/tests/BUILD | 29 - .../xla/service/gpu/tests/sorting.hlo | 504 +++++++++--------- .../xla/service/gpu/tests/sorting_test.cc | 71 --- .../compiler/xla/service/llvm_ir/llvm_util.cc | 7 +- .../compiler/xla/service/llvm_ir/llvm_util.h | 2 +- 16 files changed, 403 insertions(+), 792 deletions(-) delete mode 100644 tensorflow/compiler/xla/service/gpu/tests/sorting_test.cc diff --git a/tensorflow/compiler/mlir/xla/hlo_utils.cc b/tensorflow/compiler/mlir/xla/hlo_utils.cc index 18b4265d786..cf78c81908d 100644 --- a/tensorflow/compiler/mlir/xla/hlo_utils.cc +++ b/tensorflow/compiler/mlir/xla/hlo_utils.cc @@ -83,9 +83,6 @@ StatusOr> GetPermutationIfAvailable( strides[dim] = accumulated_stride; accumulated_stride *= shape.dimensions(dim); } - if (accumulated_stride == 0) { - return llvm::SmallVector{}; - } return llvm::SmallVector{ makeStridedLinearLayoutMap(strides, /*offset=*/0, builder.getContext())}; } diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/non_identity_layouts.hlotxt b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/non_identity_layouts.hlotxt index a83e36cff64..3630d2d45e4 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/non_identity_layouts.hlotxt +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/non_identity_layouts.hlotxt @@ -8,6 +8,6 @@ HloModule TestModule ENTRY TestComputation { x = f32[3, 2]{1,0} parameter(0) - // CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}}) {name = "copy.1"} : (memref<3x2xf32>, memref<3x2xf32, #[[MAP]]>) -> () + // CHECK: "lmhlo.copy"(%{{.*}}, %{{.*}}) : (memref<3x2xf32>, memref<3x2xf32, #[[MAP]]>) -> () ROOT x.copy = f32[3, 2]{0,1} copy(x) } diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc index 6ce91599fb1..832bad2dcc8 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassOptions.h" // from @llvm-project #include "mlir/Translation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" @@ -181,10 +182,7 @@ template StatusOr LhloDialectEmitter::CreateOpWithoutAttrs( HloInstruction* instr) { Location loc = getLocation(instr); - std::pair attrs[] = { - {Identifier::get("name", builder_.getContext()), - builder_.getStringAttr(instr->name())}, - }; + ArrayRef> attrs; ArrayRef rets{}; llvm::SmallVector operands; @@ -254,14 +252,15 @@ Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) { return Status::OK(); } -StatusOr LhloDialectEmitter::EmitSortOp(HloInstruction* instr) { +StatusOr LhloDialectEmitter::EmitSortOp( + HloInstruction* instr) { TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs(instr)); auto* sort_instr = ::xla::Cast<::xla::HloSortInstruction>(instr); sort.dimensionAttr(builder_.getI64IntegerAttr(sort_instr->sort_dimension())); sort.is_stableAttr(builder_.getBoolAttr(sort_instr->is_stable())); TF_RETURN_IF_ERROR(::xla::HloFunctionImporter::ImportAsRegion( *sort_instr->called_computations()[0], &sort.comparator(), &builder_)); - return sort; + return sort.getOperation(); } Status LhloDialectEmitter::HandleSort(HloInstruction* instr) { diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h index 4000fa01970..bdc977616b1 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h @@ -19,7 +19,6 @@ limitations under the License. 
#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -42,7 +41,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { builder_(module.getContext()), i8_type_(builder_.getIntegerType(8)) {} - ::xla::StatusOr EmitSortOp(::xla::HloInstruction* instr); + ::xla::StatusOr EmitSortOp(::xla::HloInstruction* instr); private: template diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index a19f9965fc7..074fbd92b27 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -254,11 +254,6 @@ cc_library( ":target_util", ":thunk", ":thunk_emitter", - "//tensorflow/compiler/mlir/hlo:lhlo", - "//tensorflow/compiler/mlir/xla:hlo_utils", - "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla", - "//tensorflow/compiler/mlir/xla:mlir_hlo_to_hlo", - "//tensorflow/compiler/mlir/xla:type_to_shape", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -296,8 +291,6 @@ cc_library( "@com_google_absl//absl/types:span", "@llvm-project//llvm:Core", "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:StandardOps", ], ) @@ -1166,7 +1159,6 @@ cc_library( ":target_constants", ":tree_reduction_rewriter", ":variadic_op_splitter", - "//tensorflow/compiler/mlir/xla:mhlo_to_lhlo_with_xla", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -1225,8 +1217,6 @@ cc_library( "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", "@llvm-project//llvm:Core", - "@llvm-project//mlir:AllPassesAndDialectsNoRegistration", - "@llvm-project//mlir:IR", ], ) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index b796737e601..f5bf7476059 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -29,8 +29,6 @@ limitations under the License. 
#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Verifier.h" -#include "mlir/IR/Module.h" // from @llvm-project -#include "mlir/InitAllDialects.h" // from @llvm-project #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" #include "tensorflow/compiler/xla/service/all_reduce_combiner.h" @@ -518,22 +516,15 @@ static Status CompileModuleToLlvmIrImpl( DumpHloModuleIfEnabled(*hlo_module, **buffer_assignment, "after_optimizations"); - mlir::registerAllDialects(); - mlir::MLIRContext mlir_context; - IrEmitterContext ir_emitter_context( hlo_module, buffer_assignment->get(), platform_name, gpu_device_info, - cuda_compute_capability, profile_index_map, &mlir_context, - llvm_module->get()); + cuda_compute_capability, profile_index_map, llvm_module->get()); HloComputation* entry_computation = hlo_module->entry_computation(); + IrEmitterUnnested ir_emitter(hlo_module->config(), entry_computation, + &ir_emitter_context); - TF_ASSIGN_OR_RETURN( - auto ir_emitter, - IrEmitterUnnested::Create(hlo_module->config(), entry_computation, - &ir_emitter_context)); - - TF_RETURN_IF_ERROR(ir_emitter->EmitConstantGlobals()); + TF_RETURN_IF_ERROR(ir_emitter.EmitConstantGlobals()); { XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission"); @@ -542,10 +533,9 @@ static Status CompileModuleToLlvmIrImpl( ThunkSequence thunk_sequence; absl::Span order = hlo_schedule->ThunkLaunchOrder(); for (HloInstruction* instruction : order) { - TF_RETURN_IF_ERROR(instruction->Visit(ir_emitter.get())); - TF_RETURN_IF_ERROR(ir_emitter->Postprocess(instruction)); - std::unique_ptr thunks = - ir_emitter->ConsumeThunkSequence(); + TF_RETURN_IF_ERROR(instruction->Visit(&ir_emitter)); + TF_RETURN_IF_ERROR(ir_emitter.Postprocess(instruction)); + std::unique_ptr thunks = ir_emitter.ConsumeThunkSequence(); // The invariants between each input HloInstruction* and output Thunk* are // not all explicitly checked, but at least we can document them here: diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 332db83b6ad..5d38d1b727c 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -117,11 +117,11 @@ static bool HasMeaningfulName(llvm::Value* value) { return false; } -llvm::Value* CastToTypedValue(const Shape& shape, llvm::Value* ir_value, - llvm::IRBuilder<>* b) { - llvm::Type* pointee_type = - llvm_ir::ShapeToIrType(shape, b->GetInsertBlock()->getModule()); - +llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, + ShapeIndexView shape_index, + llvm::Value* ir_value) { + llvm::Type* pointee_type = llvm_ir::ShapeToIrType( + ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_); llvm::Type* dest_type = pointee_type->getPointerTo(); llvm::Value* typed_ir_value; @@ -129,17 +129,9 @@ llvm::Value* CastToTypedValue(const Shape& shape, llvm::Value* ir_value, typed_ir_value = llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast( llvm::cast(ir_value), dest_type); } else { - typed_ir_value = b->CreatePointerBitCastOrAddrSpaceCast( + typed_ir_value = b_->CreatePointerBitCastOrAddrSpaceCast( ir_value, pointee_type->getPointerTo()); } - return typed_ir_value; -} - -llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, - ShapeIndexView shape_index, - llvm::Value* ir_value) { - auto typed_ir_value = CastToTypedValue( - 
ShapeUtil::GetSubshape(hlo.shape(), shape_index), ir_value, b_); if (!HasMeaningfulName(ir_value)) { ir_value->setName(llvm_ir::IrName(&hlo, "raw")); } diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index 3813ec6c949..5eef6727801 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -116,10 +116,6 @@ class HloToIrBindings { llvm::Value* temp_buffer_base_ = nullptr; }; -// Converts `ir_value` with type i8* to a typed LLVM Value* based on `shape`. -llvm::Value* CastToTypedValue(const Shape& shape, llvm::Value* ir_value, - llvm::IRBuilder<>* b); - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h index 7d5a8d032e6..9c43f80dc60 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_context.h @@ -17,7 +17,6 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_CONTEXT_H_ #include "llvm/IR/Module.h" -#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h" #include "tensorflow/compiler/xla/service/hlo_execution_profile.h" @@ -35,15 +34,13 @@ class IrEmitterContext { const HloModule* hlo_module, const BufferAssignment* buffer_assignment, std::string platform_name, GpuDeviceInfo gpu_device_info, absl::optional cuda_compute_capability, - const HloProfileIndexMap* profile_index_map, - mlir::MLIRContext* mlir_context, llvm::Module* llvm_module) + const HloProfileIndexMap* profile_index_map, llvm::Module* llvm_module) : hlo_module_(hlo_module), buffer_assignment_(buffer_assignment), platform_name_(std::move(platform_name)), gpu_device_info_(gpu_device_info), cuda_compute_capability_(cuda_compute_capability), profile_index_map_(profile_index_map), - mlir_context_(mlir_context), llvm_module_(llvm_module) {} // Disallow copy and assign. IrEmitterContext(const IrEmitterContext&) = delete; @@ -60,7 +57,6 @@ class IrEmitterContext { return cuda_compute_capability_; } const HloProfileIndexMap* profile_index_map() { return profile_index_map_; } - mlir::MLIRContext* mlir_context() { return mlir_context_; } llvm::Module* llvm_module() { return llvm_module_; } NameUniquer* name_uniquer() { return &name_uniquer_; } @@ -71,7 +67,6 @@ class IrEmitterContext { GpuDeviceInfo gpu_device_info_; absl::optional cuda_compute_capability_; const HloProfileIndexMap* profile_index_map_; - mlir::MLIRContext* mlir_context_; llvm::Module* llvm_module_; NameUniquer name_uniquer_; }; diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index f88c70b1a33..61b78b6004d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -37,13 +37,6 @@ limitations under the License. 
#include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/Function.h" // from @llvm-project -#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" -#include "tensorflow/compiler/mlir/xla/hlo_utils.h" -#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" -#include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" @@ -151,86 +144,13 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, llvm::ConstantAsMetadata::get(threads_per_block_ir_value)})); } -const BufferAllocation* GetAllocation( - mlir::BlockArgument func_arg, const BufferAssignment& buffer_assignment) { - auto func_op = - mlir::cast(func_arg.getParentRegion()->getParentOp()); - int64 allocation_index = func_op - .getArgAttrOfType( - func_arg.getArgNumber(), "lmhlo.alloc") - .getValue() - .getSExtValue(); - return &buffer_assignment.GetAllocation(allocation_index); -} - -StatusOr GetAllocationSliceForMlir( - mlir::Value v, const BufferAssignment& buffer_assignment) { - int64 size = v.getType().cast().getSizeInBits() / 8; - - if (auto arg = v.dyn_cast()) { - return BufferAllocation::Slice(GetAllocation(arg, buffer_assignment), 0, - size); - } - - // We match two patterns here: - // * v = ViewOp(arg); - // * v = StaticMemRefCastOp(ViewOp(arg)); - if (mlir::Operation* op = v.getDefiningOp()) { - if (auto cast = mlir::dyn_cast(op)) { - mlir::Value source = cast.getViewSource(); - op = source.getDefiningOp(); - if (!op) { - return Unimplemented("StaticMemRefCastOp has to wrap an op"); - } - } - if (auto view = mlir::dyn_cast(op)) { - return BufferAllocation::Slice( - GetAllocation(view.source().cast(), - buffer_assignment), - mlir::cast(view.byte_shift().getDefiningOp()) - .value() - .cast() - .getValue() - .getSExtValue(), - size); - } - return Unimplemented("StaticMemRefCastOp has to wrap a ViewOp"); - } - - return Unimplemented( - "Operand has to be in the form of ViewOp(arg) or " - "StaticMemRefCastOp(ViewOp(arg))"); -} - -absl::string_view GetHloName(mlir::Operation* op) { - if (auto attr = op->getAttrOfType("name")) { - auto ref = attr.getValue(); - return absl::string_view(ref.data(), ref.size()); - } - return ""; -} - } // namespace IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config, const HloComputation* hlo_computation, IrEmitterContext* ir_emitter_context) : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false), - hlo_computation_(hlo_computation), - mlir_scratch_module_(mlir::ModuleOp::create( - mlir::Builder(ir_emitter_context->mlir_context()).getUnknownLoc())), - lhlo_scratch_emitter_(ir_emitter_context_->buffer_assignment(), - *hlo_computation, mlir_scratch_module_.get()) {} - -StatusOr> IrEmitterUnnested::Create( - const HloModuleConfig& hlo_module_config, - const HloComputation* hlo_computation, - IrEmitterContext* ir_emitter_context) { - auto emitter = std::unique_ptr(new IrEmitterUnnested( - hlo_module_config, hlo_computation, ir_emitter_context)); - TF_RETURN_IF_ERROR(emitter->lhlo_scratch_emitter_.Initialize()); - return std::move(emitter); -} + hlo_computation_(hlo_computation) {} Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { bindings_.UnbindAllLocalIrValues(); @@ -238,11 
+158,12 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { } llvm::Function* IrEmitterUnnested::BuildKernelPrototype( - absl::string_view name, absl::Span args) { + const HloInstruction& inst, + absl::Span args) { // Compute the kernel name. The opcode string may contain "-" which cannot be // in a PTX function name, so sanitize the name before uniquifying it. string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName( - llvm_ir::SanitizeFunctionName(std::string(name))); + llvm_ir::SanitizeFunctionName(inst.name())); // Create the kernel and add it to the module. llvm::Module* module = ir_emitter_context_->llvm_module(); @@ -438,8 +359,7 @@ Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { } Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) { - TF_ASSIGN_OR_RETURN(auto thunk, BuildConditionalThunk(conditional)); - AddThunkToThunkSequence(std::move(thunk)); + AddThunkToThunkSequence(BuildConditionalThunk(conditional)); return Status::OK(); } @@ -1118,13 +1038,10 @@ Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) { // Build ForThunk for conformant while loops, otherwise build WhileThunk. auto config = xla_while->backend_config(); if (config.ok() && config.ValueOrDie().has_known_trip_count()) { - TF_ASSIGN_OR_RETURN( - auto thunk, + AddThunkToThunkSequence( BuildForThunk(xla_while, config.ValueOrDie().known_trip_count().n())); - AddThunkToThunkSequence(std::move(thunk)); } else { - TF_ASSIGN_OR_RETURN(auto thunk, BuildWhileThunk(xla_while)); - AddThunkToThunkSequence(std::move(thunk)); + AddThunkToThunkSequence(BuildWhileThunk(xla_while)); } return Status::OK(); } @@ -1347,109 +1264,39 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { return IrEmitter::HandleSelect(select); } -StatusOr -IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region) { - std::unique_ptr& module = scratch_nested_computations_[region]; - if (module == nullptr) { - xla::XlaComputation xla_computation; - TF_RETURN_IF_ERROR(ConvertRegionToComputation(region, &xla_computation)); - TF_ASSIGN_OR_RETURN(auto program_shape, xla_computation.GetProgramShape()); - TF_ASSIGN_OR_RETURN( - module, HloModule::CreateFromProto(xla_computation.proto(), - HloModuleConfig(program_shape))); - } - return module->entry_computation(); -} - Status IrEmitterUnnested::HandleSort(HloInstruction* sort) { - MlirEmitterInput result; - - TF_ASSIGN_OR_RETURN(auto sort_op, lhlo_scratch_emitter_.EmitSortOp(sort)); - result.op = sort_op; - result.name = GetHloName(sort_op); - // The name in sort op has no semantics, and it's for debug only. If the name - // doesn't exist, we should use a namer (e.g. count-based). - // TODO(timshen): use a namer instead of relying on the HloInstruction names. 
- if (result.name.empty()) { - result.name = sort->name(); - } - const auto& buffer_assignment = ir_emitter_context_->buffer_assignment(); - auto& slice = result.extra_slice; - TF_ASSIGN_OR_RETURN(slice.buffer_slice, - buffer_assignment.GetUniqueSlice(sort, {})); - slice.written = true; - slice.shape = sort->shape(); - - result.thunk_info = GetThunkInfo(sort); - - return EmitMlirSort(result); -} - -Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) { - const auto& buffer_assignment = ir_emitter_context_->buffer_assignment(); - auto sort_op = mlir::cast(input.op); - - int operand_count = sort_op.operands().size(); - std::vector operand_shapes(operand_count); - std::vector slices; - std::vector output_shapes(sort_op.output().size()); - - for (int i = 0; i < operand_count; i++) { - operand_shapes[i] = - TypeToShape(sort_op.operands()[i].getType().cast()); - } - - // Craft n + 1 slices, where the first n are output parameters, and the last - // is the on-device tuple storage. We don't need n operands because sorting - // kernels are always in-place. - for (int i = 0; i < operand_count; i++) { - output_shapes[i] = - TypeToShape(sort_op.output()[i].getType().cast()); - MlirBufferSlice slice; - TF_ASSIGN_OR_RETURN( - slice.buffer_slice, - GetAllocationSliceForMlir(sort_op.output()[i], buffer_assignment)); - slice.written = true; - slice.shape = operand_shapes[i]; - slices.push_back(slice); - } - slices.push_back(input.extra_slice); - std::vector> thunks; - - Shape keys_shape = operand_shapes[0]; - int64 dimension_to_sort = sort_op.dimension().getSExtValue(); - for (int64 i = 0; i < operand_count; ++i) { + Shape keys_shape = sort->operand(0)->shape(); + int64 dimension_to_sort = sort->dimensions(0); + for (int64 i = 0; i < sort->operand_count(); ++i) { + ShapeIndex shape_index = + sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({}); // We assume that the layout of all involved operands and outputs is the // same. - TF_RET_CHECK( - LayoutUtil::LayoutsInShapesEqual(keys_shape, operand_shapes[i])); - TF_RET_CHECK( - LayoutUtil::LayoutsInShapesEqual(keys_shape, output_shapes[i])); + TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(keys_shape, + sort->operand(i)->shape())); + TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( + keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index))); // If possible, we share buffers. If that is not possible, we need to copy // the values, because the emitter does the sorting in-place. - TF_ASSIGN_OR_RETURN( - auto destination_buffer, - GetAllocationSliceForMlir(sort_op.output()[i], buffer_assignment)); - TF_ASSIGN_OR_RETURN( - auto source_address, - GetAllocationSliceForMlir(sort_op.operands()[i], buffer_assignment)); + auto destination_buffer = GetAllocationSlice(*sort, shape_index); + auto source_address = GetAllocationSlice(*sort->operand(i)); if (destination_buffer != source_address) { // TODO(b/26783907): Figure out why we never seem to share buffers for // key/value sort. 
- VLOG(2) << input.name << " requires initial D2D copy for operand " << i; + VLOG(2) << sort->name() << " requires initial D2D copy for operand " << i; thunks.push_back(absl::make_unique( Thunk::ThunkInfo(), /*source_address=*/source_address, /*destination_buffer=*/destination_buffer, - /*mem_size=*/ShapeUtil::ByteSizeOf(operand_shapes[i]))); + /*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()))); } } uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort); int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound); - VLOG(2) << input.name << " requires " << num_stages << " stages."; + VLOG(2) << sort->name() << " requires " << num_stages << " stages."; CHECK_GE(1ULL << num_stages, dimension_to_sort_bound); CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound); @@ -1513,10 +1360,10 @@ Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) { // we have not enough threads, or not enough shared memory. Also it does not // give a speedup if the tile size is < 128. int64 total_shared_memory_needed = 0; - for (int64 i = 0; i < operand_count; ++i) { + for (int64 i = 0; i < sort->operand_count(); ++i) { total_shared_memory_needed += - kTileSize * - ShapeUtil::ByteSizeOfPrimitiveType(operand_shapes[i].element_type()); + kTileSize * ShapeUtil::ByteSizeOfPrimitiveType( + sort->operand(i)->shape().element_type()); } bool no_tiling = kTileSize < 128 || @@ -1529,7 +1376,7 @@ Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) { "kTileSize=%d < 128, " "kThreadsPerBlock=%d > threads_per_block_limit=%d, " "total_shared_memory_needed=%d > shared_memory_per_block=%d", - input.name, (no_tiling ? "won't" : "will"), kTileSize, kThreadsPerBlock, + sort->name(), (no_tiling ? "won't" : "will"), kTileSize, kThreadsPerBlock, ir_emitter_context_->gpu_device_info().threads_per_block_limit, total_shared_memory_needed, ir_emitter_context_->gpu_device_info().shared_memory_per_block); @@ -1537,38 +1384,37 @@ Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) { uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock); LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock); VLOG(2) << absl::StreamFormat("%s launch dims: %d blocks, %d threads/block", - input.name, num_blocks, kThreadsPerBlock); + sort->name(), num_blocks, kThreadsPerBlock); - std::vector ir_arrays; auto emit_kernel = [&](absl::Span xor_masks) { VLOG(2) << absl::StreamFormat( - "%s uses kernel for xor masks [%s]", input.name, + "%s uses kernel for xor masks [%s]", sort->name(), absl::StrJoin(xor_masks, ", ", [](std::string* out, int64 xor_mask) { absl::StrAppendFormat(out, "0x%x", xor_mask); })); - thunks.push_back(BuildKernelThunkForMlir(input.name, Thunk::ThunkInfo(), - slices, &ir_arrays)); + thunks.push_back( + BuildKernelThunk(sort, /*implements_whole_instruction=*/false)); LaunchDimensions launch_dimensions = xor_masks.size() > 1 ? tiled_launch_dimensions : standard_launch_dimensions; UpdateLaunchDimensions(launch_dimensions, thunks.back().get(), ir_emitter_context_->llvm_module()); std::vector values_arrays; - values_arrays.reserve(operand_count); - for (int64 i = 0; i < operand_count; ++i) { - values_arrays.push_back(ir_arrays[i]); + values_arrays.reserve(sort->operand_count()); + for (int64 i = 0; i < sort->operand_count(); ++i) { + ShapeIndex shape_index = + sort->operand_count() > 1 ? 
ShapeIndex({i}) : ShapeIndex({}); + values_arrays.push_back(GetIrArray(*sort, *sort, shape_index)); } - TF_ASSIGN_OR_RETURN( - const HloComputation* comparator, - GetOrCreateSubComputationFromRegion(&sort_op.comparator())); return llvm_ir::EmitSortInPlace( - dimension_to_sort, values_arrays, IrName(input.name), xor_masks, &b_, + dimension_to_sort, values_arrays, IrName(sort), xor_masks, &b_, launch_dimensions, xor_masks.size() > 1 ? num_iterations_in_sort_dim : standard_num_iterations_in_sort_dim, kTileSize, [&](absl::Span operands, llvm::Value* output) { - return EmitCallToNestedComputation(*comparator, operands, output); + return EmitCallToNestedComputation(*sort->to_apply(), operands, + output); }); }; std::vector xor_masks; @@ -1595,18 +1441,17 @@ Status IrEmitterUnnested::EmitMlirSort(MlirEmitterInput input) { TF_RETURN_IF_ERROR(emit_kernel(xor_masks)); } VLOG(2) << absl::StreamFormat( - "%s requires %d thunks (including any D2D copies)", input.name, + "%s requires %d thunks (including any D2D copies)", sort->name(), thunks.size()); - AddThunkToThunkSequence( - absl::make_unique(input.thunk_info, std::move(thunks))); - if (operand_count > 1) { + AddThunkToThunkSequence(absl::make_unique( + GetThunkInfo(sort), std::move(thunks))); + if (sort->operand_count() > 1) { // Emit the tuple as part of the last stage of sorting. // We are currently in the block sorted.in_bounds.after. b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple( - ir_arrays[operand_count], - absl::MakeSpan(ir_arrays).subspan(0, ir_arrays.size() - 1), &b_); + llvm_ir::EmitTuple(GetIrArray(*sort, *sort), + ConstructIrArrayForOutputs(*sort), &b_); } return Status::OK(); } @@ -1744,6 +1589,24 @@ Status IrEmitterUnnested::HandleAfterAll(HloInstruction* after_all) { return Status::OK(); } +// Describes how to access a particular subshape for an HLO. For instance if +// `.hlo_index` is {1} and `.gte_index` is {3, 4} then buffer for `.instr` at +// ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo) is found +// at `.buffer_slice`[3][4]. That is, `.slice` is a void***, which we +// dereference twice -- first at index 3, and then at index 4 -- to get the +// address of our buffer. +struct HloBufferSlice { + const HloInstruction* instr; + ShapeIndex hlo_index; + + // The root buffer to look at. + BufferAllocation::Slice buffer_slice; + + // Describes how to dereference starting at that buffer to get to the buffer + // in question. + ShapeIndex gte_index; +}; + // Figures out how to access the buffers for all subshapes of hlo's operands and // for hlo itself (i.e. all the buffers produced by HLO). // @@ -1852,22 +1715,22 @@ static std::vector GetHloBufferSlices( return result; } -std::unique_ptr -IrEmitterUnnested::BuildKernelThunkFromBufferSlices( - absl::string_view name, Thunk::ThunkInfo thunk_info, - absl::Span slices, - std::function - bind_slice_to_ir_value) { - const auto& buffer_assn = ir_emitter_context_->buffer_assignment(); +std::unique_ptr IrEmitterUnnested::BuildKernelThunk( + const HloInstruction* inst, bool implements_whole_instruction) { + const BufferAssignment& buffer_assn = + ir_emitter_context_->buffer_assignment(); + + std::vector hlo_slices = + GetHloBufferSlices(inst, buffer_assn); // Figure out which buffer allocations need to be passed as arguments to our - // kernel. This is simply all of the allocations referenced in slices, + // kernel. This is simply all of the allocations referenced in hlo_slices, // plus the XLA temp buffer (if we have it). 
We always include the temp // buffer because even if the kernel itself doesn't use it, a nested // subcomputation within the kernel (e.g. a kMap's computation) might. std::unordered_set buffers_needed; - for (auto* slice : slices) { - buffers_needed.insert(slice->buffer_slice.allocation()); + for (const auto& hlo_buffer_slice : hlo_slices) { + buffers_needed.insert(hlo_buffer_slice.buffer_slice.allocation()); } absl::optional temp_buffer; for (const BufferAllocation& alloc : buffer_assn.Allocations()) { @@ -1896,7 +1759,7 @@ IrEmitterUnnested::BuildKernelThunkFromBufferSlices( return a->index() < b->index(); }); - llvm::Function* kernel = BuildKernelPrototype(name, non_constant_buffers); + llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers); // Build a map from a BufferAllocation to the corresponding argument in our // kernel. @@ -1930,19 +1793,24 @@ IrEmitterUnnested::BuildKernelThunkFromBufferSlices( // For each buffer our kernel might want to touch, bind it to a value derived // from our kernel args. - for (auto* slice : slices) { - const BufferAllocation::Slice& buffer_slice = slice->buffer_slice; - const ShapeIndex& gte_index = slice->gte_index; + for (const auto& hlo_buffer_slice : hlo_slices) { + const HloInstruction* instr = hlo_buffer_slice.instr; + const ShapeIndex& index = hlo_buffer_slice.hlo_index; + const BufferAllocation::Slice& slice = hlo_buffer_slice.buffer_slice; + const ShapeIndex& gte_index = hlo_buffer_slice.gte_index; + + VLOG(3) << "Buffer for " << instr->ToString() << " at " << index.ToString() + << " is found in slice " << slice.ToString() << " at GTE index " + << gte_index.ToString(); llvm::Value* loc; - if (buffer_slice.allocation()->is_constant()) { + if (slice.allocation()->is_constant()) { loc = ir_emitter_context_->llvm_module()->getGlobalVariable( - llvm_ir::ConstantBufferAllocationToGlobalName( - *buffer_slice.allocation())); + llvm_ir::ConstantBufferAllocationToGlobalName(*slice.allocation())); CHECK_NE(loc, nullptr); } else { - loc = InBoundsGEP(kernel_args.at(buffer_slice.allocation()), - {b_.getInt64(buffer_slice.offset())}); + loc = InBoundsGEP(kernel_args.at(slice.allocation()), + {b_.getInt64(slice.offset())}); } // If gte_index is nonempty, we have to dereference `loc` to get to the @@ -1954,7 +1822,7 @@ IrEmitterUnnested::BuildKernelThunkFromBufferSlices( loc = Load(InBoundsGEP(loc, {b_.getInt64(idx)})); } - bind_slice_to_ir_value(slice, loc); + bindings_.BindHloToIrValue(*instr, loc, index); } // Bind the temp buffer so that nested subcomputations can find it if they @@ -1966,66 +1834,9 @@ IrEmitterUnnested::BuildKernelThunkFromBufferSlices( llvm::ConstantPointerNull::get(b_.getInt8PtrTy())); } - return absl::make_unique(thunk_info, non_constant_buffers, - std::string(kernel->getName())); -} - -std::unique_ptr IrEmitterUnnested::BuildKernelThunk( - const HloInstruction* inst, bool implements_whole_instruction) { - std::vector hlo_slices = - GetHloBufferSlices(inst, ir_emitter_context_->buffer_assignment()); - - std::vector slice_ptrs; - slice_ptrs.reserve(hlo_slices.size()); - for (auto& slice : hlo_slices) { - slice_ptrs.push_back(&slice); - } - - return BuildKernelThunkFromBufferSlices( - inst->name(), + return absl::make_unique( implements_whole_instruction ? 
GetThunkInfo(inst) : Thunk::ThunkInfo(), - slice_ptrs, [this](const BufferSlice* slice, llvm::Value* value) { - const HloBufferSlice* hlo_buffer_slice = - static_cast(slice); - const HloInstruction* instr = hlo_buffer_slice->instr; - const ShapeIndex& index = hlo_buffer_slice->hlo_index; - VLOG(3) << "Buffer for " << instr->ToString() << " at " - << index.ToString() << " is found in slice " - << hlo_buffer_slice->buffer_slice.ToString() << " at GTE index " - << hlo_buffer_slice->gte_index.ToString(); - - bindings_.BindHloToIrValue(*instr, value, index); - }); -} - -std::unique_ptr IrEmitterUnnested::BuildKernelThunkForMlir( - absl::string_view name, Thunk::ThunkInfo thunk_info, - absl::Span slices, - std::vector* ir_arrays) { - absl::flat_hash_set buffers_written; - std::vector slice_ptrs; - slice_ptrs.reserve(slices.size()); - for (auto& slice : slices) { - slice_ptrs.push_back(&slice); - if (slice.written) { - buffers_written.insert(slice.buffer_slice); - } - } - - ir_arrays->clear(); - return BuildKernelThunkFromBufferSlices( - name, thunk_info, slice_ptrs, - [&](const BufferSlice* slice, llvm::Value* value) { - const auto& mlir_slice = static_cast(*slice); - - llvm_ir::IrArray ir_array( - CastToTypedValue(mlir_slice.shape, value, &b_), mlir_slice.shape); - if (!buffers_written.contains(slice->buffer_slice)) { - ir_array.MarkInvariantOverWholeProgram(&value->getContext()); - } - - ir_arrays->push_back(ir_array); - }); + non_constant_buffers, std::string(kernel->getName())); } StatusOr> IrEmitterUnnested::BuildInitializerThunk( @@ -2232,7 +2043,7 @@ Status CheckConditionalBuffersShareAllocation( } // namespace -StatusOr> IrEmitterUnnested::BuildWhileThunk( +std::unique_ptr IrEmitterUnnested::BuildWhileThunk( const HloInstruction* hlo) { // Check that all while-related buffers share an allocation. TF_CHECK_OK(CheckWhileBuffersShareAllocation( @@ -2240,26 +2051,24 @@ StatusOr> IrEmitterUnnested::BuildWhileThunk( // Generate thunk sequence for while 'condition'. HloComputation* condition = hlo->while_condition(); - TF_ASSIGN_OR_RETURN(auto ir_emitter_condition, - IrEmitterUnnested::Create(hlo_module_config_, condition, - ir_emitter_context_)); - TF_RETURN_IF_ERROR(condition->Accept(ir_emitter_condition.get())); + IrEmitterUnnested ir_emitter_condition(hlo_module_config_, condition, + ir_emitter_context_); + TF_CHECK_OK(condition->Accept(&ir_emitter_condition)); // Generate thunk sequence for while 'body'. HloComputation* body = hlo->while_body(); - TF_ASSIGN_OR_RETURN( - auto ir_emitter_body, - IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_)); - TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get())); + IrEmitterUnnested ir_emitter_body(hlo_module_config_, body, + ir_emitter_context_); + TF_CHECK_OK(body->Accept(&ir_emitter_body)); - return std::unique_ptr(new WhileThunk( + return absl::make_unique( GetThunkInfo(hlo), GetAllocationSlice(*condition->root_instruction()), // cond result - ir_emitter_condition->ConsumeThunkSequence(), - ir_emitter_body->ConsumeThunkSequence())); + ir_emitter_condition.ConsumeThunkSequence(), + ir_emitter_body.ConsumeThunkSequence()); } -StatusOr> IrEmitterUnnested::BuildForThunk( +std::unique_ptr IrEmitterUnnested::BuildForThunk( const HloInstruction* hlo, const int64 loop_limit) { // Check that all while-related buffers share an allocation. TF_CHECK_OK(CheckWhileBuffersShareAllocation( @@ -2267,16 +2076,15 @@ StatusOr> IrEmitterUnnested::BuildForThunk( // Generate thunk sequence for while 'body' (will be used a For loop body). 
HloComputation* body = hlo->while_body(); - TF_ASSIGN_OR_RETURN( - auto ir_emitter_body, - IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_)); - TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get())); + IrEmitterUnnested ir_emitter_body(hlo_module_config_, body, + ir_emitter_context_); + TF_CHECK_OK(body->Accept(&ir_emitter_body)); - return std::unique_ptr(new ForThunk( - GetThunkInfo(hlo), loop_limit, ir_emitter_body->ConsumeThunkSequence())); + return absl::make_unique(GetThunkInfo(hlo), loop_limit, + ir_emitter_body.ConsumeThunkSequence()); } -StatusOr> IrEmitterUnnested::BuildConditionalThunk( +std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( const HloInstruction* hlo) { // Check that the buffers used in conditional are shared with the operands and // result appropriately. @@ -2288,17 +2096,15 @@ StatusOr> IrEmitterUnnested::BuildConditionalThunk( for (int j = 0; j < hlo->branch_count(); ++j) { branch_operands.emplace_back(GetAllocationSlice(*hlo->operand(j + 1))); HloComputation* branch_computation = hlo->branch_computation(j); - TF_ASSIGN_OR_RETURN( - auto ir_emitter, - IrEmitterUnnested::Create(hlo_module_config_, branch_computation, - ir_emitter_context_)); - TF_CHECK_OK(branch_computation->Accept(ir_emitter.get())); - branch_thunks.push_back(std::move(*ir_emitter->ConsumeThunkSequence())); + IrEmitterUnnested ir_emitter(hlo_module_config_, branch_computation, + ir_emitter_context_); + TF_CHECK_OK(branch_computation->Accept(&ir_emitter)); + branch_thunks.push_back(std::move(*ir_emitter.ConsumeThunkSequence())); } - return std::unique_ptr(new ConditionalThunk( + return absl::make_unique( GetThunkInfo(hlo), GetAllocationSlice(*hlo->operand(0)), branch_operands, - std::move(branch_thunks))); + std::move(branch_thunks)); } Status IrEmitterUnnested::EmitTargetElementLoopInThunk( diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h index b9146dd8fae..019fcdf21db 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -17,7 +17,6 @@ limitations under the License. #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ #include "absl/container/inlined_vector.h" -#include "tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" @@ -29,40 +28,6 @@ limitations under the License. namespace xla { namespace gpu { -struct BufferSlice { - // The root buffer to look at. - BufferAllocation::Slice buffer_slice; - - // Describes how to dereference starting at that buffer to get to the buffer - // in question. - ShapeIndex gte_index; -}; - -// Describes how to access a particular subshape for an HLO. For instance if -// `.hlo_index` is {1} and `.gte_index` is {3, 4} then buffer for `.instr` at -// ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo) is -// found at `.buffer_slice`[3][4]. That is, `.slice` is a void***, which we -// dereference twice -- first at index 3, and then at index 4 -- to get the -// address of our buffer. -struct HloBufferSlice : public BufferSlice { - const HloInstruction* instr; - ShapeIndex hlo_index; -}; - -struct MlirBufferSlice : public BufferSlice { - // The buffer is modified by the kernel. 
- bool written; - - Shape shape; -}; - -struct MlirEmitterInput { - mlir::Operation* op; - absl::string_view name; - Thunk::ThunkInfo thunk_info; - MlirBufferSlice extra_slice; -}; - // Emits LLVM IR for an "unnested computation". // // An unnested computation is an HloComputation which you run by executing one @@ -124,14 +89,12 @@ class IrEmitterUnnested : public IrEmitter, const string& loop_name, llvm::Value* tile_height, llvm::Value* tile_width, KernelSupportLibrary* ksl)>; + IrEmitterUnnested(const HloModuleConfig& hlo_module_config, + const HloComputation* hlo_computation, + IrEmitterContext* ir_emitter_context); IrEmitterUnnested(const IrEmitterUnnested&) = delete; IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete; - static StatusOr> Create( - const HloModuleConfig& hlo_module_config, - const HloComputation* hlo_computation, - IrEmitterContext* ir_emitter_context); - // Transfers the ownship of thunk_sequence_ out. std::unique_ptr ConsumeThunkSequence() { return std::make_unique(std::move(thunk_sequence_)); @@ -161,7 +124,6 @@ class IrEmitterUnnested : public IrEmitter, Status HandleScatter(HloInstruction* scatter) override; Status HandleSelect(HloInstruction* select) override; Status HandleSort(HloInstruction* sort) override; - Status EmitMlirSort(MlirEmitterInput input); Status HandleTriangularSolve(HloInstruction* hlo) override; Status HandleTupleSelect(HloInstruction* tuple_select) override; Status HandleAllReduce(HloInstruction* crs) override; @@ -186,10 +148,6 @@ class IrEmitterUnnested : public IrEmitter, Status Postprocess(HloInstruction* hlo) override; private: - IrEmitterUnnested(const HloModuleConfig& hlo_module_config, - const HloComputation* hlo_computation, - IrEmitterContext* ir_emitter_context); - // Add a owning Thunk object to the thunk sequence. void AddThunkToThunkSequence(std::unique_ptr thunk) override { thunk_sequence_.emplace_back(std::move(thunk)); @@ -306,7 +264,8 @@ class IrEmitterUnnested : public IrEmitter, // Builds the prototype of the IR kernel for `inst` and adds it to the module. // This kernel takes as arguments pointers to the given buffer allocations. llvm::Function* BuildKernelPrototype( - absl::string_view name, absl::Span args); + const HloInstruction& inst, + absl::Span args); // Helper for writing extra outputs from inside a reduce kernel. Status EmitExtraOutputsForReduce( @@ -531,12 +490,6 @@ class IrEmitterUnnested : public IrEmitter, HloComputation* reducer, llvm::Type* element_type, llvm::Value* partial_result_address); - std::unique_ptr BuildKernelThunkFromBufferSlices( - absl::string_view name, Thunk::ThunkInfo thunk_info, - absl::Span slices, - std::function - bind_slice_to_ir_value); - // Returns a KernelThunk that invokes the kernel emitted for `inst`. The // caller needs to make sure `inst` outlives the lifetime of the returned // Thunk object. 'implements_whole_instruction' specifies whether this @@ -545,11 +498,6 @@ class IrEmitterUnnested : public IrEmitter, std::unique_ptr BuildKernelThunk( const HloInstruction* inst, bool implements_whole_instruction); - std::unique_ptr BuildKernelThunkForMlir( - absl::string_view name, Thunk::ThunkInfo thunk_info, - absl::Span slices, - std::vector* ir_arrays); - // Returns a thunk that, given a reduce or select-and-scatter op, // initializes its memory to the appropriate initial value. 
StatusOr> BuildInitializerThunk( @@ -557,18 +505,17 @@ class IrEmitterUnnested : public IrEmitter, // Returns a WhileThunk that invokes thunk sequences for 'condition' and // 'body' sub-computations of while instruction 'hlo'. - StatusOr> BuildWhileThunk(const HloInstruction* hlo); + std::unique_ptr BuildWhileThunk(const HloInstruction* hlo); // Returns a ForThunk which executes 'loop_limit' invocations of a thunk // sequence from the 'body' sub-computation of the while instruction 'hlo'. - StatusOr> BuildForThunk(const HloInstruction* hlo, - const int64 loop_limit); + std::unique_ptr BuildForThunk(const HloInstruction* hlo, + const int64 loop_limit); // Returns a ConditionalThunk which executes the thunk sequence for the // 'branch_computation' corresponding to the predicate/branch_index of the // given conditional instruction. - StatusOr> BuildConditionalThunk( - const HloInstruction* hlo); + std::unique_ptr BuildConditionalThunk(const HloInstruction* hlo); // Emits current thread id with the given type. // @@ -598,9 +545,6 @@ class IrEmitterUnnested : public IrEmitter, absl::optional thread_id_filter = absl::nullopt, absl::optional block_id_filter = absl::nullopt); - StatusOr GetOrCreateSubComputationFromRegion( - mlir::Region* region); - // Returns the last generated thunk. Thunk* LastThunk() const { return thunk_sequence_.back().get(); } @@ -611,14 +555,6 @@ class IrEmitterUnnested : public IrEmitter, // The HloComputation that this IrEmitter emits code for. const HloComputation* hlo_computation_; - - mlir::OwningModuleRef mlir_scratch_module_; - - // This is for cache-purpose only. It has no significant semantics. - mlir::LhloDialectEmitter lhlo_scratch_emitter_; - - absl::flat_hash_map> - scratch_nested_computations_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD index 809b277317f..a2bddd2d0d7 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/BUILD +++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD @@ -458,35 +458,6 @@ xla_test( ], ) -tf_cc_test( - name = "sorting_test", - srcs = [ - "sorting_test.cc", - ], - tags = tf_cuda_tests_tags() + [ - "no_rocm", - ], - deps = [ - ":gpu_codegen_test", - "//tensorflow/compiler/xla:debug_options_flags", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:xla_proto_cc", - "//tensorflow/compiler/xla/service:gpu_plugin", - "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_module_config", - "//tensorflow/compiler/xla/service:hlo_parser", - "//tensorflow/compiler/xla/service/gpu:gpu_executable", - "//tensorflow/compiler/xla/tests:filecheck", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:llvm_irgen_test_base", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/stream_executor/lib", - "@com_google_absl//absl/memory", - ], -) - tf_cc_binary( name = "hlo_to_llvm_ir", srcs = ["hlo_to_llvm_ir.cc"], diff --git a/tensorflow/compiler/xla/service/gpu/tests/sorting.hlo b/tensorflow/compiler/xla/service/gpu/tests/sorting.hlo index 4d29a8df116..272c9a25769 100644 --- a/tensorflow/compiler/xla/service/gpu/tests/sorting.hlo +++ b/tensorflow/compiler/xla/service/gpu/tests/sorting.hlo @@ -8,162 +8,162 @@ compare { ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT } -// CHECK: define void @sort(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]]) +// CHECK: define void @sort(i8* noalias align 64 dereferenceable(24) 
[[ALLOC0:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC1:%.*]]) // CHECK-NEXT: entry: // CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 -// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 -// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x float]]* -// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0 -// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]* -// CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 -// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP4]] to i64 -// CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 -// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP5]] to i64 -// CHECK-NEXT: [[TMP6:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 -// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP6]], [[THREAD_ID]] +// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0 +// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x [3 x float]]* +// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1]], i64 0 +// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64 +// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]] // CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 // CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) -// CHECK-NEXT: [[TMP7:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 -// CHECK-NEXT: [[TMP8:%.*]] = urem i64 [[TMP7]], 2 -// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 -// CHECK-NEXT: [[TMP10:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 -// CHECK-NEXT: br i1 [[TMP10]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2 +// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] // CHECK: sort.in_bounds-after: // CHECK-NEXT: ret void // CHECK: sort.in_bounds-true: -// CHECK-NEXT: [[TMP11:%.*]] = mul i64 [[TMP8]], 2 -// CHECK-NEXT: [[TMP12:%.*]] = xor i64 [[TMP11]], 1 -// CHECK-NEXT: [[TMP13:%.*]] = icmp slt i64 [[TMP11]], [[TMP12]] -// CHECK-NEXT: [[TMP14:%.*]] = icmp slt i64 [[TMP12]], 3 -// CHECK-NEXT: [[TMP15:%.*]] = and i1 [[TMP13]], [[TMP14]] -// CHECK-NEXT: br i1 [[TMP15]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK-NEXT: [[TMP7:%.*]] = mul i64 [[TMP4]], 2 +// CHECK-NEXT: [[TMP8:%.*]] = xor i64 [[TMP7]], 1 +// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], [[TMP8]] +// CHECK-NEXT: [[TMP10:%.*]] = icmp slt i64 [[TMP8]], 3 +// CHECK-NEXT: [[TMP11:%.*]] = and i1 [[TMP9]], [[TMP10]] +// CHECK-NEXT: br i1 [[TMP11]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] // CHECK: smaller_comparison_index-after: // CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] // CHECK: smaller_comparison_index-true: -// CHECK-NEXT: 
[[TMP16:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP12]] -// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]] -// CHECK-NEXT: call void @region_0_4(float* [[TMP16]], float* [[TMP17]], i8* [[COMPARE_RETURN_BUFFER]]) -// CHECK-NEXT: [[TMP18:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 -// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP18]], 0 +// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP8]] +// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: call void @compare(float* [[TMP12]], float* [[TMP13]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP14:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP14]], 0 // CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] // CHECK: is_smaller_than-after: // CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] // CHECK: is_smaller_than-true: -// CHECK-NEXT: [[TMP19:%.*]] = load float, float* [[TMP16]], align 4 -// CHECK-NEXT: [[TMP20:%.*]] = load float, float* [[TMP17]], align 4 -// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]] -// CHECK-NEXT: store float [[TMP19]], float* [[TMP21]], align 4 -// CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP12]] -// CHECK-NEXT: store float [[TMP20]], float* [[TMP22]], align 4 +// CHECK-NEXT: [[TMP15:%.*]] = load float, float* [[TMP12]], align 4 +// CHECK-NEXT: [[TMP16:%.*]] = load float, float* [[TMP13]], align 4 +// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: store float [[TMP15]], float* [[TMP17]], align 4 +// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP8]] +// CHECK-NEXT: store float [[TMP16]], float* [[TMP18]], align 4 // CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] -// CHECK: define internal void @region_0_4(float* dereferenceable(4) [[P_0_LHS_TYPED:%.*]], float* dereferenceable(4) [[P_0_RHS_TYPED:%.*]], i8* dereferenceable(1) [[OUTPUT_ARG:%.*]]) +// CHECK: define internal void @compare(float* dereferenceable(4) [[P_0_LHS_TYPED:%.*]], float* dereferenceable(4) [[P_0_RHS_TYPED:%.*]], i8* dereferenceable(1) [[OUTPUT_ARG:%.*]]) // CHECK-NEXT: entry: -// CHECK-NEXT: [[COMPARE_3_TYPED:%.*]] = alloca i8, align 1 -// CHECK-NEXT: [[TMP0:%.*]] = load float, float* [[ARG_0_1_TYPED:%.*]], align 4 -// CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[ARG_1_2_TYPED:%.*]], align 4 +// CHECK-NEXT: [[LT_TYPED:%.*]] = alloca i8, align 1 +// CHECK-NEXT: [[TMP0:%.*]] = load float, float* [[P_0_LHS_TYPED]], align 4 +// CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[P_0_RHS_TYPED]], align 4 // CHECK-NEXT: [[TMP2:%.*]] = fcmp olt float [[TMP0]], [[TMP1]] // CHECK-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i8 -// CHECK-NEXT: store i8 [[TMP3]], i8* [[COMPARE_3_TYPED]], align 1 -// CHECK-NEXT: [[LOAD_RET_VALUE:%.*]] = load i8, i8* [[COMPARE_3_TYPED]], align 1 -// CHECK-NEXT: store i8 
[[LOAD_RET_VALUE]], i8* [[OUTPUT_ARG:%.*]], align 1 +// CHECK-NEXT: store i8 [[TMP3]], i8* [[LT_TYPED]], align 1 +// CHECK-NEXT: [[LOAD_RET_VALUE:%.*]] = load i8, i8* [[LT_TYPED]], align 1 +// CHECK-NEXT: store i8 [[LOAD_RET_VALUE]], i8* [[OUTPUT_ARG]], align 1 // CHECK-NEXT: ret void -// CHECK: define void @sort__1(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]]) { +// CHECK: define void @sort__1(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC1:%.*]]) { // CHECK-NEXT: entry: // CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 -// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 -// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x float]]* -// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0 -// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]* -// CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 -// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP4]] to i64 -// CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 -// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP5]] to i64 -// CHECK-NEXT: [[TMP6:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 -// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP6]], [[THREAD_ID]] +// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0 +// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x [3 x float]]* +// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1]], i64 0 +// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64 +// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]] // CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 // CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) -// CHECK-NEXT: [[TMP7:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 -// CHECK-NEXT: [[TMP8:%.*]] = urem i64 [[TMP7]], 2 -// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 -// CHECK-NEXT: [[TMP10:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 -// CHECK-NEXT: br i1 [[TMP10]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2 +// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] // CHECK: sort.in_bounds-after: // CHECK-NEXT: ret void // CHECK: sort.in_bounds-true: -// CHECK-NEXT: [[TMP11:%.*]] = xor i64 [[TMP8]], 3 -// CHECK-NEXT: [[TMP12:%.*]] = icmp slt i64 [[TMP8]], [[TMP11]] -// CHECK-NEXT: [[TMP13:%.*]] = icmp slt i64 [[TMP11]], 3 -// CHECK-NEXT: [[TMP14:%.*]] = and i1 [[TMP12]], [[TMP13]] -// CHECK-NEXT: br i1 [[TMP14]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK-NEXT: [[TMP7:%.*]] = xor i64 [[TMP4]], 3 +// CHECK-NEXT: [[TMP8:%.*]] = icmp slt i64 [[TMP4]], [[TMP7]] +// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 
[[TMP7]], 3 +// CHECK-NEXT: [[TMP10:%.*]] = and i1 [[TMP8]], [[TMP9]] +// CHECK-NEXT: br i1 [[TMP10]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] // CHECK: smaller_comparison_index-after: // CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] // CHECK: smaller_comparison_index-true: -// CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]] -// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP8]] -// CHECK-NEXT: call void @region_0_4(float* [[TMP15]], float* [[TMP16]], i8* [[COMPARE_RETURN_BUFFER]]) -// CHECK-NEXT: [[TMP17:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 -// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP17]], 0 +// CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP4]] +// CHECK-NEXT: call void @compare(float* [[TMP11]], float* [[TMP12]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP13:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP13]], 0 // CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] // CHECK: is_smaller_than-after: // CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] // CHECK: is_smaller_than-true: -// CHECK-NEXT: [[TMP18:%.*]] = load float, float* [[TMP15]], align 4 -// CHECK-NEXT: [[TMP19:%.*]] = load float, float* [[TMP16]], align 4 -// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP8]] -// CHECK-NEXT: store float [[TMP18]], float* [[TMP20]], align 4 -// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]] -// CHECK-NEXT: store float [[TMP19]], float* [[TMP21]], align 4 +// CHECK-NEXT: [[TMP14:%.*]] = load float, float* [[TMP11]], align 4 +// CHECK-NEXT: [[TMP15:%.*]] = load float, float* [[TMP12]], align 4 +// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP4]] +// CHECK-NEXT: store float [[TMP14]], float* [[TMP16]], align 4 +// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: store float [[TMP15]], float* [[TMP17]], align 4 // CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] -// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]]) { +// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC1:%.*]]) { // CHECK-NEXT: entry: // CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 -// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 -// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x float]]* -// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0 -// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]* -// CHECK-NEXT: [[TMP4:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 -// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 
[[TMP4]] to i64 -// CHECK-NEXT: [[TMP5:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 -// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP5]] to i64 -// CHECK-NEXT: [[TMP6:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 -// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP6]], [[THREAD_ID]] +// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 +// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x [3 x float]]* +// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0 +// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64 +// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]] // CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 // CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) -// CHECK-NEXT: [[TMP7:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 -// CHECK-NEXT: [[TMP8:%.*]] = urem i64 [[TMP7]], 2 -// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 -// CHECK-NEXT: [[TMP10:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 -// CHECK-NEXT: br i1 [[TMP10]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2 +// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] // CHECK: sort.in_bounds-after: // CHECK-NEXT: ret void // CHECK: sort.in_bounds-true: -// CHECK-NEXT: [[TMP11:%.*]] = mul i64 [[TMP8]], 2 -// CHECK-NEXT: [[TMP12:%.*]] = xor i64 [[TMP11]], 1 -// CHECK-NEXT: [[TMP13:%.*]] = icmp slt i64 [[TMP11]], [[TMP12]] -// CHECK-NEXT: [[TMP14:%.*]] = icmp slt i64 [[TMP12]], 3 -// CHECK-NEXT: [[TMP15:%.*]] = and i1 [[TMP13]], [[TMP14]] -// CHECK-NEXT: br i1 [[TMP15]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK-NEXT: [[TMP7:%.*]] = mul i64 [[TMP4]], 2 +// CHECK-NEXT: [[TMP8:%.*]] = xor i64 [[TMP7]], 1 +// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], [[TMP8]] +// CHECK-NEXT: [[TMP10:%.*]] = icmp slt i64 [[TMP8]], 3 +// CHECK-NEXT: [[TMP11:%.*]] = and i1 [[TMP9]], [[TMP10]] +// CHECK-NEXT: br i1 [[TMP11]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] // CHECK: smaller_comparison_index-after: // CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] // CHECK: smaller_comparison_index-true: -// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP12]] -// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]] -// CHECK-NEXT: call void @region_0_4(float* [[TMP16]], float* [[TMP17]], i8* [[COMPARE_RETURN_BUFFER]]) -// CHECK-NEXT: [[TMP18:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 -// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP18]], 0 +// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* 
[[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP8]] +// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: call void @compare(float* [[TMP12]], float* [[TMP13]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP14:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP14]], 0 // CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] // CHECK: is_smaller_than-after: // CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] // CHECK: is_smaller_than-true: -// CHECK-NEXT: [[TMP19:%.*]] = load float, float* [[TMP16]], align 4 -// CHECK-NEXT: [[TMP20:%.*]] = load float, float* [[TMP17]], align 4 -// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP11]] -// CHECK-NEXT: store float [[TMP19]], float* [[TMP21]], align 4 -// CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP1]], i64 0, i64 [[TMP9]], i64 [[TMP12]] -// CHECK-NEXT: store float [[TMP20]], float* [[TMP22]], align 4 +// CHECK-NEXT: [[TMP15:%.*]] = load float, float* [[TMP12]], align 4 +// CHECK-NEXT: [[TMP16:%.*]] = load float, float* [[TMP13]], align 4 +// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: store float [[TMP15]], float* [[TMP17]], align 4 +// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED]], i64 0, i64 [[TMP5]], i64 [[TMP8]] +// CHECK-NEXT: store float [[TMP16]], float* [[TMP18]], align 4 // CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] ENTRY main { x = f32[2, 3] parameter(0) @@ -182,198 +182,210 @@ compare { ROOT lt = pred[] compare(p.1.lhs, p.1.rhs), direction=LT } -// CHECK: define void @sort(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]]) +// CHECK: define void @sort(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC2:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC3:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]]) // CHECK-NEXT: entry: // CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 -// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 -// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x i32]]* -// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0 -// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]* -// CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0 -// CHECK-NEXT: [[TMP5:%.*]] = bitcast i8* [[TMP4]] to [2 x i8*]* -// CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 -// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP6]] to i64 -// CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 -// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP7]] to i64 -// CHECK-NEXT: [[TMP8:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 -// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP8]], [[THREAD_ID]] +// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4]], i64 0 +// 
CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x i8*]* +// CHECK-NEXT: [[SORT_RAW1:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0]], i64 0 +// CHECK-NEXT: [[SORT_TYPED2:%.*]] = bitcast i8* [[SORT_RAW1]] to [2 x [3 x i32]]* +// CHECK-NEXT: [[SORT_RAW3:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1]], i64 0 +// CHECK-NEXT: [[SORT_TYPED4:%.*]] = bitcast i8* [[SORT_RAW3]] to [2 x [3 x float]]* +// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC2]], i64 0 +// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x i32]]* +// CHECK-NEXT: [[Y_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC3]], i64 0 +// CHECK-NEXT: [[Y_TYPED:%.*]] = bitcast i8* [[Y_RAW]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64 +// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]] // CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 // CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) -// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 -// CHECK-NEXT: [[TMP10:%.*]] = urem i64 [[TMP9]], 2 -// CHECK-NEXT: [[TMP11:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 -// CHECK-NEXT: [[TMP12:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 -// CHECK-NEXT: br i1 [[TMP12]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2 +// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] // CHECK: sort.in_bounds-after: // CHECK-NEXT: ret void // CHECK: sort.in_bounds-true: -// CHECK-NEXT: [[TMP13:%.*]] = mul i64 [[TMP10]], 2 -// CHECK-NEXT: [[TMP14:%.*]] = xor i64 [[TMP13]], 1 -// CHECK-NEXT: [[TMP15:%.*]] = icmp slt i64 [[TMP13]], [[TMP14]] -// CHECK-NEXT: [[TMP16:%.*]] = icmp slt i64 [[TMP14]], 3 -// CHECK-NEXT: [[TMP17:%.*]] = and i1 [[TMP15]], [[TMP16]] -// CHECK-NEXT: br i1 [[TMP17]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK-NEXT: [[TMP7:%.*]] = mul i64 [[TMP4]], 2 +// CHECK-NEXT: [[TMP8:%.*]] = xor i64 [[TMP7]], 1 +// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], [[TMP8]] +// CHECK-NEXT: [[TMP10:%.*]] = icmp slt i64 [[TMP8]], 3 +// CHECK-NEXT: [[TMP11:%.*]] = and i1 [[TMP9]], [[TMP10]] +// CHECK-NEXT: br i1 [[TMP11]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] // CHECK: smaller_comparison_index-after: // CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] // CHECK: smaller_comparison_index-true: -// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP14]] -// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP13]] -// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP14]] -// CHECK-NEXT: [[TMP21:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 
[[TMP11]], i64 [[TMP13]] -// CHECK-NEXT: call void @region_0_6(i32* [[TMP18]], i32* [[TMP19]], float* [[TMP20]], float* [[TMP21]], i8* [[COMPARE_RETURN_BUFFER]]) -// CHECK-NEXT: [[TMP22:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 -// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP22]], 0 +// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP8]] +// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP8]] +// CHECK-NEXT: [[TMP15:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: call void @compare(i32* [[TMP12]], i32* [[TMP13]], float* [[TMP14]], float* [[TMP15]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP16:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP16]], 0 // CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] // CHECK: is_smaller_than-after: // CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] // CHECK: is_smaller_than-true: -// CHECK-NEXT: [[TMP23:%.*]] = load i32, i32* [[TMP18]], align 4 -// CHECK-NEXT: [[TMP24:%.*]] = load i32, i32* [[TMP19]], align 4 -// CHECK-NEXT: [[TMP25:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP13]] -// CHECK-NEXT: store i32 [[TMP23]], i32* [[TMP25]], align 4 -// CHECK-NEXT: [[TMP26:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP14]] -// CHECK-NEXT: store i32 [[TMP24]], i32* [[TMP26]], align 4 -// CHECK-NEXT: [[TMP27:%.*]] = load float, float* [[TMP20]], align 4 -// CHECK-NEXT: [[TMP28:%.*]] = load float, float* [[TMP21]], align 4 -// CHECK-NEXT: [[TMP29:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP13]] -// CHECK-NEXT: store float [[TMP27]], float* [[TMP29]], align 4 -// CHECK-NEXT: [[TMP30:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP14]] -// CHECK-NEXT: store float [[TMP28]], float* [[TMP30]], align 4 +// CHECK-NEXT: [[TMP17:%.*]] = load i32, i32* [[TMP12]], align 4 +// CHECK-NEXT: [[TMP18:%.*]] = load i32, i32* [[TMP13]], align 4 +// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: store i32 [[TMP17]], i32* [[TMP19]], align 4 +// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP8]] +// CHECK-NEXT: store i32 [[TMP18]], i32* [[TMP20]], align 4 +// CHECK-NEXT: [[TMP21:%.*]] = load float, float* [[TMP14]], align 4 +// CHECK-NEXT: [[TMP22:%.*]] = load float, float* [[TMP15]], align 4 +// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: store float [[TMP21]], float* [[TMP23]], align 4 +// CHECK-NEXT: [[TMP24:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP8]] +// CHECK-NEXT: store float [[TMP22]], float* 
[[TMP24]], align 4 // CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] -// CHECK: define internal void @region_0_6(i32* dereferenceable(4) [[P_0_LHS_TYPED:%.*]], i32* dereferenceable(4) [[P_0_RHS_TYPED:%.*]], float* dereferenceable(4) [[P_1_LHS_TYPED:%.*]], float* dereferenceable(4) [[P_1_RHS_TYPED:%.*]], i8* dereferenceable(1) [[OUTPUT_ARG:%.*]]) +// CHECK: define internal void @compare(i32* dereferenceable(4) [[P_0_LHS_TYPED:%.*]], i32* dereferenceable(4) [[P_0_RHS_TYPED:%.*]], float* dereferenceable(4) [[P_1_LHS_TYPED:%.*]], float* dereferenceable(4) [[P_1_RHS_TYPED:%.*]], i8* dereferenceable(1) [[OUTPUT_ARG:%.*]]) // CHECK-NEXT: entry: -// CHECK-NEXT: [[COMPARE_5_TYPED:%.*]] = alloca i8, align 1 -// CHECK-NEXT: [[TMP0:%.*]] = load float, float* [[ARG_2_3_TYPED:%.*]], align 4 -// CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[ARG_3_4_TYPED:%.*]], align 4 +// CHECK-NEXT: [[LT_TYPED:%.*]] = alloca i8, align 1 +// CHECK-NEXT: [[TMP0:%.*]] = load float, float* [[P_1_LHS_TYPED]], align 4 +// CHECK-NEXT: [[TMP1:%.*]] = load float, float* [[P_1_RHS_TYPED]], align 4 // CHECK-NEXT: [[TMP2:%.*]] = fcmp olt float [[TMP0]], [[TMP1]] // CHECK-NEXT: [[TMP3:%.*]] = zext i1 [[TMP2]] to i8 -// CHECK-NEXT: store i8 [[TMP3]], i8* [[COMPARE_5_TYPED]], align 1 -// CHECK-NEXT: [[LOAD_RET_VALUE:%.*]] = load i8, i8* [[COMPARE_5_TYPED]], align 1 -// CHECK-NEXT: store i8 [[LOAD_RET_VALUE]], i8* [[OUTPUT_ARG:%.*]], align 1 +// CHECK-NEXT: store i8 [[TMP3]], i8* [[LT_TYPED]], align 1 +// CHECK-NEXT: [[LOAD_RET_VALUE:%.*]] = load i8, i8* [[LT_TYPED]], align 1 +// CHECK-NEXT: store i8 [[LOAD_RET_VALUE]], i8* [[OUTPUT_ARG]], align 1 // CHECK-NEXT: ret void -// CHECK: define void @sort__1(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]]) +// CHECK: define void @sort__1(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC2:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC3:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]]) // CHECK-NEXT: entry: // CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 -// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 -// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x i32]]* -// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0 -// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]* -// CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0 -// CHECK-NEXT: [[TMP5:%.*]] = bitcast i8* [[TMP4]] to [2 x i8*]* -// CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 -// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP6]] to i64 -// CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 -// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP7]] to i64 -// CHECK-NEXT: [[TMP8:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 -// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP8]], [[THREAD_ID]] +// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0 +// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x i8*]* +// CHECK-NEXT: [[SORT_RAW1:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 +// CHECK-NEXT: [[SORT_TYPED2:%.*]] = bitcast i8* [[SORT_RAW1]] to [2 x [3 x i32]]* +// CHECK-NEXT: [[SORT_RAW3:%.*]] = getelementptr 
inbounds i8, i8* [[ALLOC1:%.*]], i64 0 +// CHECK-NEXT: [[SORT_TYPED4:%.*]] = bitcast i8* [[SORT_RAW3]] to [2 x [3 x float]]* +// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC2:%.*]], i64 0 +// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x i32]]* +// CHECK-NEXT: [[Y_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC3:%.*]], i64 0 +// CHECK-NEXT: [[Y_TYPED:%.*]] = bitcast i8* [[Y_RAW]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64 +// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]] // CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 // CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) -// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 -// CHECK-NEXT: [[TMP10:%.*]] = urem i64 [[TMP9]], 2 -// CHECK-NEXT: [[TMP11:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 -// CHECK-NEXT: [[TMP12:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 -// CHECK-NEXT: br i1 [[TMP12]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2 +// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] // CHECK: sort.in_bounds-after: // CHECK-NEXT: ret void // CHECK: sort.in_bounds-true: -// CHECK-NEXT: [[TMP13:%.*]] = xor i64 [[TMP10]], 3 -// CHECK-NEXT: [[TMP14:%.*]] = icmp slt i64 [[TMP10]], [[TMP13]] -// CHECK-NEXT: [[TMP15:%.*]] = icmp slt i64 [[TMP13]], 3 -// CHECK-NEXT: [[TMP16:%.*]] = and i1 [[TMP14]], [[TMP15]] -// CHECK-NEXT: br i1 [[TMP16]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK-NEXT: [[TMP7:%.*]] = xor i64 [[TMP4]], 3 +// CHECK-NEXT: [[TMP8:%.*]] = icmp slt i64 [[TMP4]], [[TMP7]] +// CHECK-NEXT: [[TMP9:%.*]] = icmp slt i64 [[TMP7]], 3 +// CHECK-NEXT: [[TMP10:%.*]] = and i1 [[TMP8]], [[TMP9]] +// CHECK-NEXT: br i1 [[TMP10]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] // CHECK: smaller_comparison_index-after: // CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] // CHECK: smaller_comparison_index-true: -// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP13]] -// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP10]] -// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP13]] -// CHECK-NEXT: [[TMP20:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP10]] -// CHECK-NEXT: call void @region_0_6(i32* [[TMP17]], i32* [[TMP18]], float* [[TMP19]], float* [[TMP20]], i8* [[COMPARE_RETURN_BUFFER]]) -// CHECK-NEXT: [[TMP21:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 -// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP21]], 0 +// CHECK-NEXT: [[TMP11:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x 
i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: [[TMP12:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP4]] +// CHECK-NEXT: [[TMP13:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP4]] +// CHECK-NEXT: call void @compare(i32* [[TMP11]], i32* [[TMP12]], float* [[TMP13]], float* [[TMP14]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP15:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP15]], 0 // CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] // CHECK: is_smaller_than-after: // CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] // CHECK: is_smaller_than-true: -// CHECK-NEXT: [[TMP22:%.*]] = load i32, i32* [[TMP17]], align 4 -// CHECK-NEXT: [[TMP23:%.*]] = load i32, i32* [[TMP18]], align 4 -// CHECK-NEXT: [[TMP24:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP10]] -// CHECK-NEXT: store i32 [[TMP22]], i32* [[TMP24]], align 4 -// CHECK-NEXT: [[TMP25:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP13]] -// CHECK-NEXT: store i32 [[TMP23]], i32* [[TMP25]], align 4 -// CHECK-NEXT: [[TMP26:%.*]] = load float, float* [[TMP19]], align 4 -// CHECK-NEXT: [[TMP27:%.*]] = load float, float* [[TMP20]], align 4 -// CHECK-NEXT: [[TMP28:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP10]] -// CHECK-NEXT: store float [[TMP26]], float* [[TMP28]], align 4 -// CHECK-NEXT: [[TMP29:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP13]] -// CHECK-NEXT: store float [[TMP27]], float* [[TMP29]], align 4 +// CHECK-NEXT: [[TMP16:%.*]] = load i32, i32* [[TMP11]], align 4 +// CHECK-NEXT: [[TMP17:%.*]] = load i32, i32* [[TMP12]], align 4 +// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP4]] +// CHECK-NEXT: store i32 [[TMP16]], i32* [[TMP18]], align 4 +// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: store i32 [[TMP17]], i32* [[TMP19]], align 4 +// CHECK-NEXT: [[TMP20:%.*]] = load float, float* [[TMP13]], align 4 +// CHECK-NEXT: [[TMP21:%.*]] = load float, float* [[TMP14]], align 4 +// CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP4]] +// CHECK-NEXT: store float [[TMP20]], float* [[TMP22]], align 4 +// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP7]] +// CHECK-NEXT: store float [[TMP21]], float* [[TMP23]], align 4 // CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] -// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) [[ALLOC1:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]]) +// CHECK: define void @sort__2(i8* noalias align 64 dereferenceable(24) [[ALLOC0:%.*]], i8* noalias align 64 dereferenceable(24) 
[[ALLOC1:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC2:%.*]], i8* noalias align 16 dereferenceable(24) [[ALLOC3:%.*]], i8* noalias align 64 dereferenceable(16) [[ALLOC4:%.*]]) // CHECK-NEXT: entry: // CHECK-NEXT: [[COMPARE_RETURN_BUFFER:%.*]] = alloca i8, align 1 -// CHECK-NEXT: [[TMP0:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 -// CHECK-NEXT: [[TMP1:%.*]] = bitcast i8* [[TMP0]] to [2 x [3 x i32]]* -// CHECK-NEXT: [[TMP2:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0 -// CHECK-NEXT: [[TMP3:%.*]] = bitcast i8* [[TMP2]] to [2 x [3 x float]]* -// CHECK-NEXT: [[TMP4:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0 -// CHECK-NEXT: [[TMP5:%.*]] = bitcast i8* [[TMP4]] to [2 x i8*]* -// CHECK-NEXT: [[TMP6:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 -// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP6]] to i64 -// CHECK-NEXT: [[TMP7:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 -// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP7]] to i64 -// CHECK-NEXT: [[TMP8:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 -// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP8]], [[THREAD_ID]] +// CHECK-NEXT: [[SORT_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC4:%.*]], i64 0 +// CHECK-NEXT: [[SORT_TYPED:%.*]] = bitcast i8* [[SORT_RAW]] to [2 x i8*]* +// CHECK-NEXT: [[SORT_RAW1:%.*]] = getelementptr inbounds i8, i8* [[ALLOC0:%.*]], i64 0 +// CHECK-NEXT: [[SORT_TYPED2:%.*]] = bitcast i8* [[SORT_RAW1]] to [2 x [3 x i32]]* +// CHECK-NEXT: [[SORT_RAW3:%.*]] = getelementptr inbounds i8, i8* [[ALLOC1:%.*]], i64 0 +// CHECK-NEXT: [[SORT_TYPED4:%.*]] = bitcast i8* [[SORT_RAW3]] to [2 x [3 x float]]* +// CHECK-NEXT: [[X_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC2:%.*]], i64 0 +// CHECK-NEXT: [[X_TYPED:%.*]] = bitcast i8* [[X_RAW]] to [2 x [3 x i32]]* +// CHECK-NEXT: [[Y_RAW:%.*]] = getelementptr inbounds i8, i8* [[ALLOC3:%.*]], i64 0 +// CHECK-NEXT: [[Y_TYPED:%.*]] = bitcast i8* [[Y_RAW]] to [2 x [3 x float]]* +// CHECK-NEXT: [[TMP0:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x(), !range !6 +// CHECK-NEXT: [[BLOCK_ID:%.*]] = zext i32 [[TMP0]] to i64 +// CHECK-NEXT: [[TMP1:%.*]] = call i32 @llvm.nvvm.read.ptx.sreg.tid.x(), !range !7 +// CHECK-NEXT: [[THREAD_ID:%.*]] = zext i32 [[TMP1]] to i64 +// CHECK-NEXT: [[TMP2:%.*]] = mul nuw nsw i64 [[BLOCK_ID]], 4 +// CHECK-NEXT: [[LINEAR_INDEX:%.*]] = add nuw nsw i64 [[TMP2]], [[THREAD_ID]] // CHECK-NEXT: [[LINEAR_INDEX_IN_RANGE:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 // CHECK-NEXT: call void @llvm.assume(i1 [[LINEAR_INDEX_IN_RANGE]]) -// CHECK-NEXT: [[TMP9:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 -// CHECK-NEXT: [[TMP10:%.*]] = urem i64 [[TMP9]], 2 -// CHECK-NEXT: [[TMP11:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 -// CHECK-NEXT: [[TMP12:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 -// CHECK-NEXT: br i1 [[TMP12]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] +// CHECK-NEXT: [[TMP3:%.*]] = udiv i64 [[LINEAR_INDEX]], 1 +// CHECK-NEXT: [[TMP4:%.*]] = urem i64 [[TMP3]], 2 +// CHECK-NEXT: [[TMP5:%.*]] = udiv i64 [[LINEAR_INDEX]], 2 +// CHECK-NEXT: [[TMP6:%.*]] = icmp ult i64 [[LINEAR_INDEX]], 4 +// CHECK-NEXT: br i1 [[TMP6]], label [[SORT_IN_BOUNDS_TRUE:%.*]], label [[SORT_IN_BOUNDS_AFTER:%.*]] // CHECK: sort.in_bounds-after: -// CHECK-NEXT: [[TMP13:%.*]] = bitcast [2 x [3 x i32]]* [[TMP1]] to i8* -// CHECK-NEXT: [[TMP14:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[TMP5]], i64 0, i64 0 -// CHECK-NEXT: store i8* [[TMP13]], i8** [[TMP14]], align 8 -// CHECK-NEXT: 
[[TMP15:%.*]] = bitcast [2 x [3 x float]]* [[TMP3]] to i8* -// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[TMP5]], i64 0, i64 1 -// CHECK-NEXT: store i8* [[TMP15]], i8** [[TMP16]], align 8 +// CHECK-NEXT: [[TMP7:%.*]] = bitcast [2 x [3 x i32]]* [[SORT_TYPED2]] to i8* +// CHECK-NEXT: [[TMP8:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[SORT_TYPED]], i64 0, i64 0 +// CHECK-NEXT: store i8* [[TMP7]], i8** [[TMP8]], align 8 +// CHECK-NEXT: [[TMP9:%.*]] = bitcast [2 x [3 x float]]* [[SORT_TYPED4]] to i8* +// CHECK-NEXT: [[TMP10:%.*]] = getelementptr inbounds [2 x i8*], [2 x i8*]* [[SORT_TYPED]], i64 0, i64 1 +// CHECK-NEXT: store i8* [[TMP9]], i8** [[TMP10]], align 8 // CHECK-NEXT: ret void // CHECK: sort.in_bounds-true: -// CHECK-NEXT: [[TMP17:%.*]] = mul i64 [[TMP10]], 2 -// CHECK-NEXT: [[TMP18:%.*]] = xor i64 [[TMP17]], 1 -// CHECK-NEXT: [[TMP19:%.*]] = icmp slt i64 [[TMP17]], [[TMP18]] -// CHECK-NEXT: [[TMP20:%.*]] = icmp slt i64 [[TMP18]], 3 -// CHECK-NEXT: [[TMP21:%.*]] = and i1 [[TMP19]], [[TMP20]] -// CHECK-NEXT: br i1 [[TMP21]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] +// CHECK-NEXT: [[TMP11:%.*]] = mul i64 [[TMP4]], 2 +// CHECK-NEXT: [[TMP12:%.*]] = xor i64 [[TMP11]], 1 +// CHECK-NEXT: [[TMP13:%.*]] = icmp slt i64 [[TMP11]], [[TMP12]] +// CHECK-NEXT: [[TMP14:%.*]] = icmp slt i64 [[TMP12]], 3 +// CHECK-NEXT: [[TMP15:%.*]] = and i1 [[TMP13]], [[TMP14]] +// CHECK-NEXT: br i1 [[TMP15]], label [[SMALLER_COMPARISON_INDEX_TRUE:%.*]], label [[SMALLER_COMPARISON_INDEX_AFTER:%.*]] // CHECK: smaller_comparison_index-after: // CHECK-NEXT: br label [[SORT_IN_BOUNDS_AFTER]] // CHECK: smaller_comparison_index-true: -// CHECK-NEXT: [[TMP22:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP18]] -// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP17]] -// CHECK-NEXT: [[TMP24:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP18]] -// CHECK-NEXT: [[TMP25:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP17]] -// CHECK-NEXT: call void @region_0_6(i32* [[TMP22]], i32* [[TMP23]], float* [[TMP24]], float* [[TMP25]], i8* [[COMPARE_RETURN_BUFFER]]) -// CHECK-NEXT: [[TMP26:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 -// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP26]], 0 +// CHECK-NEXT: [[TMP16:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP12]] +// CHECK-NEXT: [[TMP17:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP11]] +// CHECK-NEXT: [[TMP18:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP12]] +// CHECK-NEXT: [[TMP19:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP11]] +// CHECK-NEXT: call void @compare(i32* [[TMP16]], i32* [[TMP17]], float* [[TMP18]], float* [[TMP19]], i8* [[COMPARE_RETURN_BUFFER]]) +// CHECK-NEXT: [[TMP20:%.*]] = load i8, i8* [[COMPARE_RETURN_BUFFER]], align 1 +// CHECK-NEXT: [[BOOLEAN_PREDICATE:%.*]] = icmp ne i8 [[TMP20]], 0 // CHECK-NEXT: br i1 [[BOOLEAN_PREDICATE]], label [[IS_SMALLER_THAN_TRUE:%.*]], label [[IS_SMALLER_THAN_AFTER:%.*]] // CHECK: 
is_smaller_than-after: // CHECK-NEXT: br label [[SMALLER_COMPARISON_INDEX_AFTER]] // CHECK: is_smaller_than-true: -// CHECK-NEXT: [[TMP27:%.*]] = load i32, i32* [[TMP22]], align 4 -// CHECK-NEXT: [[TMP28:%.*]] = load i32, i32* [[TMP23]], align 4 -// CHECK-NEXT: [[TMP29:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP17]] -// CHECK-NEXT: store i32 [[TMP27]], i32* [[TMP29]], align 4 -// CHECK-NEXT: [[TMP30:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[TMP1]], i64 0, i64 [[TMP11]], i64 [[TMP18]] -// CHECK-NEXT: store i32 [[TMP28]], i32* [[TMP30]], align 4 -// CHECK-NEXT: [[TMP31:%.*]] = load float, float* [[TMP24]], align 4 -// CHECK-NEXT: [[TMP32:%.*]] = load float, float* [[TMP25]], align 4 -// CHECK-NEXT: [[TMP33:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP17]] -// CHECK-NEXT: store float [[TMP31]], float* [[TMP33]], align 4 -// CHECK-NEXT: [[TMP34:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[TMP3]], i64 0, i64 [[TMP11]], i64 [[TMP18]] -// CHECK-NEXT: store float [[TMP32]], float* [[TMP34]], align 4 +// CHECK-NEXT: [[TMP21:%.*]] = load i32, i32* [[TMP16]], align 4 +// CHECK-NEXT: [[TMP22:%.*]] = load i32, i32* [[TMP17]], align 4 +// CHECK-NEXT: [[TMP23:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP11]] +// CHECK-NEXT: store i32 [[TMP21]], i32* [[TMP23]], align 4 +// CHECK-NEXT: [[TMP24:%.*]] = getelementptr inbounds [2 x [3 x i32]], [2 x [3 x i32]]* [[SORT_TYPED2]], i64 0, i64 [[TMP5]], i64 [[TMP12]] +// CHECK-NEXT: store i32 [[TMP22]], i32* [[TMP24]], align 4 +// CHECK-NEXT: [[TMP25:%.*]] = load float, float* [[TMP18]], align 4 +// CHECK-NEXT: [[TMP26:%.*]] = load float, float* [[TMP19]], align 4 +// CHECK-NEXT: [[TMP27:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP11]] +// CHECK-NEXT: store float [[TMP25]], float* [[TMP27]], align 4 +// CHECK-NEXT: [[TMP28:%.*]] = getelementptr inbounds [2 x [3 x float]], [2 x [3 x float]]* [[SORT_TYPED4]], i64 0, i64 [[TMP5]], i64 [[TMP12]] +// CHECK-NEXT: store float [[TMP26]], float* [[TMP28]], align 4 // CHECK-NEXT: br label [[IS_SMALLER_THAN_AFTER]] ENTRY main { x = s32[2, 3] parameter(0) diff --git a/tensorflow/compiler/xla/service/gpu/tests/sorting_test.cc b/tensorflow/compiler/xla/service/gpu/tests/sorting_test.cc deleted file mode 100644 index 197a0c6cfeb..00000000000 --- a/tensorflow/compiler/xla/service/gpu/tests/sorting_test.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include - -#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" -#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/service/hlo_parser.h" -#include "tensorflow/compiler/xla/statusor.h" -#include "tensorflow/compiler/xla/tests/filecheck.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/xla.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/stream_executor/lib/statusor.h" - -namespace xla { -namespace gpu { - -namespace { - -class SortingTest : public GpuCodegenTest { - protected: - HloModuleConfig ConfigWithoutLayoutAssignment() { - HloModuleConfig config; - auto debug_options = HloTestBase::GetDebugOptionsForTest(); - // Disable layout_assignment to use the preassigned layouts. - debug_options.add_xla_disable_hlo_passes("layout-assignment"); - config.set_debug_options(debug_options); - return config; - } -}; - -TEST_F(SortingTest, Regression1) { - const char* hlo_text = R"( -HloModule TestModule - -compare { - p.0.lhs = f32[] parameter(0) - p.0.rhs = f32[] parameter(1) - ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT -} - -ENTRY TestComputation { - x = f32[3, 2]{1, 0} parameter(0) - x.copy = f32[3, 2]{0, 1} copy(x) - ROOT sort = f32[3, 2]{0, 1} sort(x.copy), dimensions={1}, to_apply=compare -} - -)"; - - EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_text, ErrorSpec{1e-5, 1e-5})); -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 2963d546380..b01ae2efe43 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -415,10 +415,9 @@ llvm::Instruction* AddRangeMetadata(int64 lower, int64 upper, return inst; } -string IrName(absl::string_view a) { - std::string s(a); - s.erase(std::remove(s.begin(), s.end(), '%'), s.end()); - return s; +string IrName(string a) { + a.erase(std::remove(a.begin(), a.end(), '%'), a.end()); + return a; } string IrName(absl::string_view a, absl::string_view b) { diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index c0a55e4da33..642965b6470 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -87,7 +87,7 @@ string DumpModuleToString(const llvm::Module& module); // - joining all of the nonempty inputs by '.', and then // - removing all '%'s. // -string IrName(absl::string_view a); +string IrName(string a); string IrName(absl::string_view a, absl::string_view b); string IrName(const HloInstruction* a, absl::string_view b = "");
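For context on the final two hunks: the restored IrName overload takes its argument by value and strips '%' characters with the erase-remove idiom, matching the header comment ("joining all of the nonempty inputs by '.', and then removing all '%'s"). A minimal standalone sketch of that documented behavior, using plain std::string and a hypothetical JoinIrNameSketch helper in place of the real XLA overloads, not the actual library code:

#include <algorithm>
#include <iostream>
#include <string>

// Sketch of the documented IrName behavior: drop every '%' so the result
// is usable as an LLVM IR value/function name component.
std::string IrNameSketch(std::string a) {
  a.erase(std::remove(a.begin(), a.end(), '%'), a.end());
  return a;
}

// Hypothetical two-argument variant mirroring the header comment: join the
// nonempty inputs with '.', then strip '%'.
std::string JoinIrNameSketch(const std::string& a, const std::string& b) {
  if (a.empty()) return IrNameSketch(b);
  if (b.empty()) return IrNameSketch(a);
  return IrNameSketch(a + "." + b);
}

int main() {
  std::cout << IrNameSketch("%sort.1") << "\n";               // prints "sort.1"
  std::cout << JoinIrNameSketch("%sort", "compare") << "\n";  // prints "sort.compare"
}

Taking the parameter by value (string rather than absl::string_view) lets callers move a temporary in and lets the function erase in place without an extra copy, which is presumably why the rollback restores the by-value signature.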