diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index fe83d017f4c..a08b72e3afb 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -44,6 +44,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/core/lib/core/bits.h"
 #include "tensorflow/core/lib/core/errors.h"
 
 namespace xla {
@@ -123,9 +124,136 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
   return Status::OK();
 }
 
-Status IrEmitter::HandleSort(HloInstruction*) {
-  // TODO(b/26783907): Implement sort on GPU.
-  return Unimplemented("sort");
+Status IrEmitter::HandleSort(HloInstruction* sort) {
+  auto keys = sort->operand(0);
+  auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
+  if (values != nullptr) {
+    // TODO(b/26783907): Also sort the values by their corresponding key.
+    return Unimplemented("Key/Value Sort is not implemented on GPU");
+  }
+  int dimension_to_sort = sort->dimensions(0);
+  const llvm_ir::IrArray& keys_array = GetIrArray(*keys, *sort);
+  const llvm_ir::IrArray& target_array = GetIrArray(*sort, *sort);
+
+  const Shape& keys_shape = keys->shape();
+
+  // TODO(b/26783907): This case can probably be avoided with the Algebraic
+  // Simplifier.
+  if (ShapeUtil::IsScalar(keys_shape)) {
+    return Status::OK();
+  }
+
+  // Create loop nests which loop through the operand dimensions. The sort
+  // dimension is handled in three separate innermost loops which perform the
+  // sorting.
+  llvm_ir::ForLoopNest loop_nest(IrName(sort), &ir_builder_);
+  llvm_ir::IrArray::Index keys_index = EmitOperandArrayLoopNest(
+      keys_array, dimension_to_sort, "keys", &loop_nest);
+
+  // 'compare_keys_index' is the index of the element that 'keys_index' should
+  // be compared to.
+  llvm_ir::IrArray::Index compare_keys_index(keys_index.GetType());
+  for (size_t dimension = 0; dimension < keys_index.size(); ++dimension) {
+    if (dimension != dimension_to_sort) {
+      compare_keys_index.push_back(keys_index[dimension]);
+    } else {
+      compare_keys_index.push_back(nullptr);
+    }
+  }
+
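Aside (illustration only, not part of the patch): the loop nest above fixes every dimension except 'dimension_to_sort', so each combination of the remaining coordinates becomes one independent 1-D sort, and 'keys_index' and 'compare_keys_index' only ever differ in the sort dimension. A minimal host-side sketch of that decomposition, with a made-up shape:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<int64_t> shape = {2, 3, 4};  // illustrative operand shape
  const int64_t dimension_to_sort = 1;

  // The ForLoopNest visits every coordinate of the non-sort dimensions; each
  // such coordinate is one independent 1-D sort problem.
  int64_t independent_sorts = 1;
  for (size_t d = 0; d < shape.size(); ++d) {
    if (static_cast<int64_t>(d) != dimension_to_sort) {
      independent_sorts *= shape[d];
    }
  }
  std::printf("independent sorts: %lld, each of length %lld\n",
              static_cast<long long>(independent_sorts),
              static_cast<long long>(shape[dimension_to_sort]));
  // Prints: independent sorts: 8, each of length 3
  return 0;
}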
+  // Create the sorting loops which do the sorting.
+  int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
+  std::unique_ptr<llvm_ir::ForLoop> stages_loop = loop_nest.AddLoop(
+      /*start_index=*/0,
+      /*end_index=*/
+      tensorflow::Log2Ceiling64(dimension_to_sort_bound),
+      /*suffix=*/"sort_stages");
+  std::unique_ptr<llvm_ir::ForLoop> mask_loop = loop_nest.AddLoop(
+      /*suffix=*/"mask",
+      /*start_index=*/keys_index.GetConstantWithIndexType(0),
+      /*end_index=*/stages_loop->GetIndVarValue());
+  std::unique_ptr<llvm_ir::ForLoop> compare_loop = loop_nest.AddLoop(
+      /*start_index=*/0,
+      /*end_index=*/dimension_to_sort_bound,
+      /*suffix=*/"compare");
+
+  // Naive C++ code for the inner loops (without parallelization):
+  //
+  // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
+  //     ++stage) {
+  //   int64 first_xor_mask = (1LL << (stage + 1)) - 1;
+  //   for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
+  //     int64 j = i ^ first_xor_mask;
+  //     if (i < j && j < dimension_to_sort_bound) {
+  //       int64 min_key = std::min(keys[i], keys[j]);
+  //       keys[j] = std::max(keys[i], keys[j]);
+  //       keys[i] = min_key;
+  //     }
+  //   }
+  //   for (int64 mask = 0; mask < stage; ++mask) {
+  //     int64 later_xor_mask = (1LL << (stage - (mask + 1)));
+  //     for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
+  //       int64 j = i ^ later_xor_mask;
+  //       if (i < j && j < dimension_to_sort_bound) {
+  //         int64 min_key = std::min(keys[i], keys[j]);
+  //         keys[j] = std::max(keys[i], keys[j]);
+  //         keys[i] = min_key;
+  //       }
+  //     }
+  //   }
+  // }
+  //
+  // This follows the algorithm described on Wikipedia:
+  // https://en.wikipedia.org/wiki/Bitonic_sorter
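Aside (illustration only, not part of the patch): a runnable host-side version of the naive algorithm sketched in the comment above, handy for checking that the stage/mask schedule also sorts non-power-of-two lengths. The local Log2Ceiling64 is a stand-in for tensorflow::Log2Ceiling64; everything else mirrors the pseudocode.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in for tensorflow::Log2Ceiling64 from core/lib/core/bits.h.
int64_t Log2Ceiling64(int64_t n) {
  int64_t log = 0;
  while ((int64_t{1} << log) < n) ++log;
  return log;
}

void BitonicSort(std::vector<int64_t>& keys) {
  const int64_t bound = static_cast<int64_t>(keys.size());
  // Compare-exchange: smaller index gets the min, larger index gets the max.
  auto compare_exchange = [&](int64_t i, int64_t j) {
    if (i < j && j < bound) {
      const int64_t min_key = std::min(keys[i], keys[j]);
      keys[j] = std::max(keys[i], keys[j]);
      keys[i] = min_key;
    }
  };
  for (int64_t stage = 0; stage < Log2Ceiling64(bound); ++stage) {
    const int64_t first_xor_mask = (int64_t{1} << (stage + 1)) - 1;
    for (int64_t i = 0; i < bound; ++i) {
      compare_exchange(i, i ^ first_xor_mask);
    }
    for (int64_t mask = 0; mask < stage; ++mask) {
      const int64_t later_xor_mask = int64_t{1} << (stage - (mask + 1));
      for (int64_t i = 0; i < bound; ++i) {
        compare_exchange(i, i ^ later_xor_mask);
      }
    }
  }
}

int main() {
  std::vector<int64_t> keys = {9, 1, 7, 3, 8, 2, 6};  // length 7: not a power of two
  std::vector<int64_t> expected = keys;
  std::sort(expected.begin(), expected.end());
  BitonicSort(keys);
  assert(keys == expected);
  return 0;
}

The 'i < j && j < bound' guard turns out-of-range partners into no-ops, which is why the network needs no padding to a power of two.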
+
+  SetToFirstInsertPoint(stages_loop->GetBodyBasicBlock(), &ir_builder_);
+  // The first xor mask of a stage is 2^(stage + 1) - 1.
+  auto first_xor_mask = ir_builder_.CreateSub(
+      ir_builder_.CreateShl(
+          keys_index.GetConstantWithIndexType(1),
+          ir_builder_.CreateAdd(stages_loop->GetIndVarValue(),
+                                keys_index.GetConstantWithIndexType(1))),
+      keys_index.GetConstantWithIndexType(1));
+  std::unique_ptr<llvm_ir::ForLoop> first_compare_loop =
+      llvm_ir::ForLoop::EmitForLoop(
+          /*prefix=*/"first_compare",
+          /*start_index=*/keys_index.GetConstantWithIndexType(0),
+          /*end_index=*/
+          keys_index.GetConstantWithIndexType(
+              keys_shape.dimensions(dimension_to_sort)),
+          /*step=*/keys_index.GetConstantWithIndexType(1),
+          /*ir_builder=*/&ir_builder_);
+
+  SetToFirstInsertPoint(first_compare_loop->GetBodyBasicBlock(), &ir_builder_);
+  // 'first_compare_loop' iterates through the 'dimension_to_sort'.
+  keys_index[dimension_to_sort] = first_compare_loop->GetIndVarValue();
+  compare_keys_index[dimension_to_sort] = ir_builder_.CreateXor(
+      first_compare_loop->GetIndVarValue(), first_xor_mask);
+  EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index,
+                  target_array);
+
+  SetToFirstInsertPoint(compare_loop->GetPreheaderBasicBlock(), &ir_builder_);
+  // The later masks of a stage are 2^(stage - (mask_loop_ind_var + 1)).
+  auto later_xor_mask = ir_builder_.CreateShl(
+      keys_index.GetConstantWithIndexType(1),
+      ir_builder_.CreateSub(
+          stages_loop->GetIndVarValue(),
+          ir_builder_.CreateAdd(mask_loop->GetIndVarValue(),
+                                keys_index.GetConstantWithIndexType(1))));
+
+  SetToFirstInsertPoint(compare_loop->GetBodyBasicBlock(), &ir_builder_);
+  // 'compare_loop' iterates through the 'dimension_to_sort'.
+  keys_index[dimension_to_sort] = compare_loop->GetIndVarValue();
+  compare_keys_index[dimension_to_sort] =
+      ir_builder_.CreateXor(compare_loop->GetIndVarValue(), later_xor_mask);
+  EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index,
+                  target_array);
+
+  // Set the IR builder insert point to the exit basic block of the outermost
+  // loop. This ensures later instructions are inserted after this loop nest.
+  ir_builder_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
+
+  return Status::OK();
 }
 
 Status IrEmitter::HandleSend(HloInstruction*) {
@@ -399,6 +527,44 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
   return Status::OK();
 }
 
+void IrEmitter::EmitCompareLoop(
+    int64 dimension_to_sort, const llvm_ir::IrArray::Index& keys_index,
+    const llvm_ir::IrArray::Index& compare_keys_index,
+    const llvm_ir::IrArray& keys_array) {
+  // TODO(b/26783907): parallelize this loop.
+
+  // if (is_smaller_index &&
+  //     compare_keys[dimension_to_sort] < dimension_to_sort_bound)
+  llvm::Value* is_smaller_index = ir_builder_.CreateICmpSLT(
+      keys_index[dimension_to_sort], compare_keys_index[dimension_to_sort]);
+  int64 dimension_to_sort_bound =
+      keys_array.GetShape().dimensions(dimension_to_sort);
+  auto if_data = llvm_ir::EmitIfThenElse(
+      ir_builder_.CreateAnd(
+          is_smaller_index,
+          ir_builder_.CreateICmpSLT(
+              compare_keys_index[dimension_to_sort],
+              keys_index.GetConstantWithIndexType(dimension_to_sort_bound))),
+      "smaller_comparison_index", &ir_builder_, /*emit_else=*/false);
+  SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
+  auto key1 = keys_array.EmitReadArrayElement(keys_index, &ir_builder_);
+  auto key2 = keys_array.EmitReadArrayElement(compare_keys_index, &ir_builder_);
+  auto key_type = keys_array.GetShape().element_type();
+  auto comparison =
+      primitive_util::IsFloatingPointType(key_type)
+          // TODO(b/26783907): Figure out how to handle NaNs.
+          ? ir_builder_.CreateFCmp(llvm::FCmpInst::FCMP_ULT, key1, key2)
+          : ir_builder_.CreateICmp(
+                primitive_util::IsSignedIntegralType(key_type)
+                    ? llvm::ICmpInst::ICMP_SLT
+                    : llvm::ICmpInst::ICMP_ULT,
+                key1, key2);
+  auto min_key = ir_builder_.CreateSelect(comparison, key1, key2);
+  auto max_key = ir_builder_.CreateSelect(comparison, key2, key1);
+  keys_array.EmitWriteArrayElement(keys_index, min_key, &ir_builder_);
+  keys_array.EmitWriteArrayElement(compare_keys_index, max_key, &ir_builder_);
+}
+
 Status IrEmitter::EmitAtomicOperationForNestedComputation(
     const HloComputation& computation, llvm::Value* output_address,
     llvm::Value* source_address) {
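Aside (illustration only, not part of the patch): a scalar model of the branch-free compare-exchange that EmitCompareLoop emits; the smaller index receives the minimum and the larger index the maximum via two selects. FCMP_ULT is an unordered comparison, so a NaN key currently compares as "less than", which is what the NaN TODO refers to. The helper below is hypothetical, not the emitted IR.

#include <cmath>
#include <cstdio>
#include <utility>

// Mimics: comparison = (key1 <u key2); min = select(comparison, key1, key2);
//         max = select(comparison, key2, key1);
std::pair<float, float> CompareExchange(float key1, float key2) {
  // FCMP_ULT is true if key1 < key2 OR either operand is NaN (unordered).
  const bool comparison = !(key1 >= key2);
  const float min_key = comparison ? key1 : key2;
  const float max_key = comparison ? key2 : key1;
  return {min_key, max_key};
}

int main() {
  auto normal = CompareExchange(3.0f, 1.0f);
  std::printf("%g %g\n", normal.first, normal.second);  // 1 3
  auto with_nan = CompareExchange(NAN, 1.0f);
  // The NaN lands in the "min" slot, so the order of NaN keys is not yet
  // well-defined.
  std::printf("%g %g\n", with_nan.first, with_nan.second);  // nan 1
  return 0;
}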
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
index d2dd335f10c..e9ad4a752bb 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h
@@ -198,6 +198,13 @@ class IrEmitter : public DfsHloVisitorWithDefault {
                                                 llvm::Value* output_address,
                                                 llvm::Value* source_address);
 
+  // A helper method for HandleSort(). It adds the inner comparison loop where
+  // we compare elements pointed to by 'keys_index' and 'compare_keys_index'.
+  void EmitCompareLoop(int64 dimension_to_sort,
+                       const llvm_ir::IrArray::Index& keys_index,
+                       const llvm_ir::IrArray::Index& compare_keys_index,
+                       const llvm_ir::IrArray& keys_array);
+
   StatusOr<llvm::Value*> ComputeNestedElement(
       const HloComputation& computation,
       tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements);
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index f2597da4b9d..70a227ca4a7 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2046,6 +2046,35 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
   return IrEmitter::HandleSelect(select);
 }
 
+Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
+  std::vector<std::unique_ptr<Thunk>> thunks;
+  auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
+  if (values != nullptr) {
+    // TODO(b/26783907): Also sort the values by their corresponding key.
+    return Unimplemented("Key/Value Sort is not implemented on GPU");
+  }
+
+  // First copy the operand to the output, so that we can sort in-place.
+  // TODO(b/26783907): Share buffer of output and operand when it is possible.
+  if (sort->operand(0)->IsConstant()) {
+    thunks.push_back(MakeUnique<HostToDeviceCopyThunk>(
+        /*source_address=*/sort->operand(0)->literal().untyped_data(),
+        /*destination_buffer=*/GetAllocationSlice(*sort),
+        /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort));
+  } else {
+    thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+        /*source_address=*/GetAllocationSlice(*sort->operand(0)),
+        /*destination_buffer=*/GetAllocationSlice(*sort),
+        /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort));
+  }
+
+  thunks.push_back(
+      BuildKernelThunk(sort, /*implements_whole_instruction=*/false));
+  thunk_sequence_->emplace_back(
+      MakeUnique<SequentialThunk>(std::move(thunks), sort));
+  return IrEmitter::HandleSort(sort);
+}
+
 Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
   thunk_sequence_->push_back(
       BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true));
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index 59547c16d7f..616d8a2206e 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -77,6 +77,7 @@ class IrEmitterUnnested : public IrEmitter {
   Status HandleOutfeed(HloInstruction* outfeed) override;
   Status HandleRng(HloInstruction* random) override;
   Status HandleSelect(HloInstruction* select) override;
+  Status HandleSort(HloInstruction* sort) override;
   Status HandleTupleSelect(HloInstruction* tuple_select) override;
   Status HandleCrossReplicaSum(HloInstruction* crs) override;
   Status HandleAfterAll(HloInstruction* gen_token) override;
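Aside (illustration only, not part of the patch): the thunk sequence built in IrEmitterUnnested::HandleSort first copies the operand into the output buffer and then runs the sort kernel on that buffer in place, so the operand itself is never modified. A host-side analogue of that execution order (buffer names are illustrative, not the Thunk API):

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int> operand = {3, 1, 4, 1, 5};   // source buffer, left intact
  std::vector<int> output(operand.size());

  // 1) Copy thunk: operand -> output.
  std::copy(operand.begin(), operand.end(), output.begin());
  // 2) Kernel thunk: sort the output buffer in place.
  std::sort(output.begin(), output.end());

  for (int v : output) std::printf("%d ", v);  // 1 1 3 4 5
  std::printf("\n");
  return 0;
}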