Implement BitonicSort for GPU.

This is a first version; several things are still missing:
- Support for key/value sorting.
- Support for types other than F32, S32 and U32.
- Parallelization of the inner loop.

PiperOrigin-RevId: 205052657
Adrian Kuegel 2018-07-18 03:10:36 -07:00 committed by TensorFlower Gardener
parent 3a576d3a28
commit b74f7b71fa
4 changed files with 206 additions and 3 deletions
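
For reference, here is a standalone host-side sketch of the sorting network this change emits. It is an illustration written for this note, not code from the commit; it mirrors the naive C++ loop documented inside HandleSort below, including the j-out-of-bounds check that makes non-power-of-two sizes work (equivalent to padding with +infinity):

// bitonic_sort_sketch.cc -- reference model of the emitted network.
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Smallest 'stages' with (1 << stages) >= n, i.e. ceil(log2(n)); stands in
// for tensorflow::Log2Ceiling64 used by the emitter.
int Log2Ceiling(int64_t n) {
  int stages = 0;
  while ((int64_t{1} << stages) < n) ++stages;
  return stages;
}

// In-place bitonic sort. Partners with j >= n are skipped, so elements past
// the end behave like +infinity keys that never move.
void BitonicSort(std::vector<int32_t>* keys) {
  const int64_t n = static_cast<int64_t>(keys->size());
  for (int64_t stage = 0; stage < Log2Ceiling(n); ++stage) {
    // First step of a stage: XOR mask 2^(stage + 1) - 1 compares each
    // 2^(stage + 1)-sized block with its own reversal.
    const int64_t first_xor_mask = (int64_t{1} << (stage + 1)) - 1;
    for (int64_t i = 0; i < n; ++i) {
      const int64_t j = i ^ first_xor_mask;
      if (i < j && j < n && (*keys)[j] < (*keys)[i]) {
        std::swap((*keys)[i], (*keys)[j]);
      }
    }
    // Later steps halve the compare distance: 2^(stage - 1), ..., 2, 1.
    for (int64_t mask = 0; mask < stage; ++mask) {
      const int64_t later_xor_mask = int64_t{1} << (stage - (mask + 1));
      for (int64_t i = 0; i < n; ++i) {
        const int64_t j = i ^ later_xor_mask;
        if (i < j && j < n && (*keys)[j] < (*keys)[i]) {
          std::swap((*keys)[i], (*keys)[j]);
        }
      }
    }
  }
}

int main() {
  std::vector<int32_t> keys = {5, 3, 9, 1, 7, 2, 8};  // length 7, not a power of two
  BitonicSort(&keys);
  for (int32_t k : keys) std::cout << k << ' ';  // prints: 1 2 3 5 7 8 9
  std::cout << '\n';
}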

tensorflow/compiler/xla/service/gpu/ir_emitter.cc

@@ -44,6 +44,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/core/lib/core/bits.h"
 #include "tensorflow/core/lib/core/errors.h"
 
 namespace xla {
@@ -123,9 +124,136 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
   return Status::OK();
 }
 
-Status IrEmitter::HandleSort(HloInstruction*) {
-  // TODO(b/26783907): Implement sort on GPU.
-  return Unimplemented("sort");
+Status IrEmitter::HandleSort(HloInstruction* sort) {
+  auto keys = sort->operand(0);
+  auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
+  if (values != nullptr) {
+    // TODO(b/26783907): Also sort the values by their corresponding key.
+    return Unimplemented("Key/Value Sort is not implemented on GPU");
+  }
+  int dimension_to_sort = sort->dimensions(0);
+  const llvm_ir::IrArray& keys_array = GetIrArray(*keys, *sort);
+  const llvm_ir::IrArray& target_array = GetIrArray(*sort, *sort);
+  const Shape& keys_shape = keys->shape();
+  // TODO(b/26783907): This case can probably be avoided with the Algebraic
+  // Simplifier.
+  if (ShapeUtil::IsScalar(keys_shape)) {
+    return Status::OK();
+  }
+  // Create loop nests which loop through the operand dimensions. The sort
+  // dimension is handled in three separate innermost loops which perform the
+  // sorting.
+  llvm_ir::ForLoopNest loop_nest(IrName(sort), &ir_builder_);
+  llvm_ir::IrArray::Index keys_index = EmitOperandArrayLoopNest(
+      keys_array, dimension_to_sort, "keys", &loop_nest);
+  // 'compare_keys_index' is the index of the element that 'keys_index' should
+  // be compared to.
+  llvm_ir::IrArray::Index compare_keys_index(keys_index.GetType());
+  for (size_t dimension = 0; dimension < keys_index.size(); ++dimension) {
+    if (dimension != dimension_to_sort) {
+      compare_keys_index.push_back(keys_index[dimension]);
+    } else {
+      compare_keys_index.push_back(nullptr);
+    }
+  }
+  // Create the sorting loops which do the sorting.
+  int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
+  std::unique_ptr<llvm_ir::ForLoop> stages_loop = loop_nest.AddLoop(
+      /*start_index=*/0,
+      /*end_index=*/
+      tensorflow::Log2Ceiling64(dimension_to_sort_bound),
+      /*suffix=*/"sort_stages");
+  std::unique_ptr<llvm_ir::ForLoop> mask_loop = loop_nest.AddLoop(
+      /*suffix=*/"mask",
+      /*start_index=*/keys_index.GetConstantWithIndexType(0),
+      /*end_index=*/stages_loop->GetIndVarValue());
+  std::unique_ptr<llvm_ir::ForLoop> compare_loop = loop_nest.AddLoop(
+      /*start_index=*/0,
+      /*end_index=*/dimension_to_sort_bound,
+      /*suffix=*/"compare");
+  // Naive C++ code for the inner loops (without parallelization):
+  //
+  // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
+  //      ++stage) {
+  //   int64 first_xor_mask = (1LL << (stage + 1)) - 1;
+  //   for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
+  //     int64 j = i ^ first_xor_mask;
+  //     if (i < j && j < dimension_to_sort_bound) {
+  //       int64 min_key = std::min(keys[i], keys[j]);
+  //       keys[j] = std::max(keys[i], keys[j]);
+  //       keys[i] = min_key;
+  //     }
+  //   }
+  //   for (int64 mask = 0; mask < stage; ++mask) {
+  //     int64 later_xor_mask = 1LL << (stage - (mask + 1));
+  //     for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
+  //       int64 j = i ^ later_xor_mask;
+  //       if (i < j && j < dimension_to_sort_bound) {
+  //         int64 min_key = std::min(keys[i], keys[j]);
+  //         keys[j] = std::max(keys[i], keys[j]);
+  //         keys[i] = min_key;
+  //       }
+  //     }
+  //   }
+  // }
+  //
+  // This follows the algorithm described on Wikipedia:
+  // https://en.wikipedia.org/wiki/Bitonic_sorter
+  SetToFirstInsertPoint(stages_loop->GetBodyBasicBlock(), &ir_builder_);
+  // The first xor mask of a stage is 2^(stage + 1) - 1.
+  auto first_xor_mask = ir_builder_.CreateSub(
+      ir_builder_.CreateShl(
+          keys_index.GetConstantWithIndexType(1),
+          ir_builder_.CreateAdd(stages_loop->GetIndVarValue(),
+                                keys_index.GetConstantWithIndexType(1))),
+      keys_index.GetConstantWithIndexType(1));
+  std::unique_ptr<llvm_ir::ForLoop> first_compare_loop =
+      llvm_ir::ForLoop::EmitForLoop(
+          /*prefix=*/"first_compare",
+          /*start_index=*/keys_index.GetConstantWithIndexType(0),
+          /*end_index=*/
+          keys_index.GetConstantWithIndexType(
+              keys_shape.dimensions(dimension_to_sort)),
+          /*step=*/keys_index.GetConstantWithIndexType(1),
+          /*ir_builder=*/&ir_builder_);
+  SetToFirstInsertPoint(first_compare_loop->GetBodyBasicBlock(), &ir_builder_);
+  // 'first_compare_loop' iterates through the 'dimension_to_sort'.
+  keys_index[dimension_to_sort] = first_compare_loop->GetIndVarValue();
+  compare_keys_index[dimension_to_sort] = ir_builder_.CreateXor(
+      first_compare_loop->GetIndVarValue(), first_xor_mask);
+  EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index,
+                  target_array);
+  SetToFirstInsertPoint(compare_loop->GetPreheaderBasicBlock(), &ir_builder_);
+  // The later masks of a stage are 2^(stage - (mask_loop_ind_var + 1)).
+  auto later_xor_mask = ir_builder_.CreateShl(
+      keys_index.GetConstantWithIndexType(1),
+      ir_builder_.CreateSub(
+          stages_loop->GetIndVarValue(),
+          ir_builder_.CreateAdd(mask_loop->GetIndVarValue(),
+                                keys_index.GetConstantWithIndexType(1))));
+  SetToFirstInsertPoint(compare_loop->GetBodyBasicBlock(), &ir_builder_);
+  // 'compare_loop' iterates through the 'dimension_to_sort'.
+  keys_index[dimension_to_sort] = compare_loop->GetIndVarValue();
+  compare_keys_index[dimension_to_sort] =
+      ir_builder_.CreateXor(compare_loop->GetIndVarValue(), later_xor_mask);
+  EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index,
+                  target_array);
+  // Set the IR builder insert point to the exit basic block of the outermost
+  // loop. This ensures later instructions are inserted after this loop nest.
+  ir_builder_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
+  return Status::OK();
 }
 
 Status IrEmitter::HandleSend(HloInstruction*) {
@@ -399,6 +527,44 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
   return Status::OK();
 }
+
+void IrEmitter::EmitCompareLoop(
+    int64 dimension_to_sort, const llvm_ir::IrArray::Index& keys_index,
+    const llvm_ir::IrArray::Index& compare_keys_index,
+    const llvm_ir::IrArray& keys_array) {
+  // TODO(b/26783907): parallelize this loop.
+  // if (is_smaller_index &&
+  //     compare_keys[dimension_to_sort] < dimension_to_sort_bound)
+  llvm::Value* is_smaller_index = ir_builder_.CreateICmpSLT(
+      keys_index[dimension_to_sort], compare_keys_index[dimension_to_sort]);
+  int64 dimension_to_sort_bound =
+      keys_array.GetShape().dimensions(dimension_to_sort);
+  auto if_data = llvm_ir::EmitIfThenElse(
+      ir_builder_.CreateAnd(
+          is_smaller_index,
+          ir_builder_.CreateICmpSLT(
+              compare_keys_index[dimension_to_sort],
+              keys_index.GetConstantWithIndexType(dimension_to_sort_bound))),
+      "smaller_comparison_index", &ir_builder_, /*emit_else=*/false);
+  SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
+  auto key1 = keys_array.EmitReadArrayElement(keys_index, &ir_builder_);
+  auto key2 = keys_array.EmitReadArrayElement(compare_keys_index, &ir_builder_);
+  auto key_type = keys_array.GetShape().element_type();
+  auto comparison =
+      primitive_util::IsFloatingPointType(key_type)
+          // TODO(b/26783907): Figure out how to handle NaNs.
+          ? ir_builder_.CreateFCmp(llvm::FCmpInst::FCMP_ULT, key1, key2)
+          : ir_builder_.CreateICmp(
+                primitive_util::IsSignedIntegralType(key_type)
+                    ? llvm::ICmpInst::ICMP_SLT
+                    : llvm::ICmpInst::ICMP_ULT,
+                key1, key2);
+  auto min_key = ir_builder_.CreateSelect(comparison, key1, key2);
+  auto max_key = ir_builder_.CreateSelect(comparison, key2, key1);
+  keys_array.EmitWriteArrayElement(keys_index, min_key, &ir_builder_);
+  keys_array.EmitWriteArrayElement(compare_keys_index, max_key, &ir_builder_);
+}
 
 Status IrEmitter::EmitAtomicOperationForNestedComputation(
     const HloComputation& computation, llvm::Value* output_address,
     llvm::Value* source_address) {
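
As a worked example of the mask arithmetic emitted above (the CreateShl/CreateAdd/CreateSub sequences): for a sort dimension of size 8 the sort_stages loop runs stages 0 through 2, each stage starting with the mask 2^(stage + 1) - 1 followed by the later masks 2^(stage - (mask + 1)). A hypothetical snippet, written for this note rather than taken from the commit, that prints the schedule:

#include <cstdint>
#include <iostream>

int main() {
  const int64_t bound = 8;  // dimension_to_sort_bound
  for (int64_t stage = 0; (int64_t{1} << stage) < bound; ++stage) {
    // First mask of the stage: 2^(stage + 1) - 1.
    std::cout << "stage " << stage << ": masks "
              << ((int64_t{1} << (stage + 1)) - 1);
    // Later masks: 2^(stage - (mask + 1)) for mask = 0 .. stage - 1.
    for (int64_t mask = 0; mask < stage; ++mask) {
      std::cout << ", " << (int64_t{1} << (stage - (mask + 1)));
    }
    std::cout << "\n";
  }
  // Prints:
  //   stage 0: masks 1
  //   stage 1: masks 3, 1
  //   stage 2: masks 7, 2, 1
}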

tensorflow/compiler/xla/service/gpu/ir_emitter.h

@@ -198,6 +198,13 @@ class IrEmitter : public DfsHloVisitorWithDefault {
                                      llvm::Value* output_address,
                                      llvm::Value* source_address);
 
+  // A helper method for HandleSort(). It adds the inner comparison loop where
+  // we compare elements pointed to by 'keys_index' and 'compare_keys_index'.
+  void EmitCompareLoop(int64 dimension_to_sort,
+                       const llvm_ir::IrArray::Index& keys_index,
+                       const llvm_ir::IrArray::Index& compare_keys_index,
+                       const llvm_ir::IrArray& keys_array);
+
   StatusOr<llvm::Value*> ComputeNestedElement(
       const HloComputation& computation,
       tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements);

tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc

@@ -2046,6 +2046,35 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
   return IrEmitter::HandleSelect(select);
 }
 
+Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
+  std::vector<std::unique_ptr<Thunk>> thunks;
+  auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
+  if (values != nullptr) {
+    // TODO(b/26783907): Also sort the values by their corresponding key.
+    return Unimplemented("Key/Value Sort is not implemented on GPU");
+  }
+  // First copy the operand to the output, so that we can sort in-place.
+  // TODO(b/26783907): Share buffer of output and operand when it is possible.
+  if (sort->operand(0)->IsConstant()) {
+    thunks.push_back(MakeUnique<HostToDeviceCopyThunk>(
+        /*source_address=*/sort->operand(0)->literal().untyped_data(),
+        /*destination_buffer=*/GetAllocationSlice(*sort),
+        /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort));
+  } else {
+    thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+        /*source_address=*/GetAllocationSlice(*sort->operand(0)),
+        /*destination_buffer=*/GetAllocationSlice(*sort),
+        /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort));
+  }
+  thunks.push_back(
+      BuildKernelThunk(sort, /*implements_whole_instruction=*/false));
+  thunk_sequence_->emplace_back(
+      MakeUnique<SequentialThunk>(std::move(thunks), sort));
+  return IrEmitter::HandleSort(sort);
+}
+
 Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
   thunk_sequence_->push_back(
       BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true));
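
The copy-before-kernel pattern in HandleSort above turns an out-of-place HLO op into an in-place algorithm: the operand is first copied into the output buffer, and the sorting kernel then mutates only that buffer. A host-side analogy (illustrative sketch only, reusing the hypothetical BitonicSort helper from the first snippet):

#include <cstdint>
#include <vector>

// Analogy to IrEmitterUnnested::HandleSort: the copy thunk materializes the
// operand into the destination buffer; the kernel thunk then sorts in place,
// leaving the original operand untouched.
std::vector<int32_t> SortOutOfPlace(const std::vector<int32_t>& operand) {
  std::vector<int32_t> output = operand;  // copy thunk equivalent
  BitonicSort(&output);                   // kernel thunk equivalent
  return output;
}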

tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h

@@ -77,6 +77,7 @@ class IrEmitterUnnested : public IrEmitter {
   Status HandleOutfeed(HloInstruction* outfeed) override;
   Status HandleRng(HloInstruction* random) override;
   Status HandleSelect(HloInstruction* select) override;
+  Status HandleSort(HloInstruction* sort) override;
   Status HandleTupleSelect(HloInstruction* tuple_select) override;
   Status HandleCrossReplicaSum(HloInstruction* crs) override;
   Status HandleAfterAll(HloInstruction* gen_token) override;