Implement BitonicSort for GPU.
This is a first version, several things are still missing: - Support for key/value sorting. - Support for other types than F32, S32 and U32. - Parallelization of the inner loop. PiperOrigin-RevId: 205052657
This commit is contained in:
parent
3a576d3a28
commit
b74f7b71fa
@ -44,6 +44,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/compiler/xla/util.h"
|
||||
#include "tensorflow/compiler/xla/window_util.h"
|
||||
#include "tensorflow/core/lib/core/bits.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
|
||||
namespace xla {
|
||||
@ -123,9 +124,136 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IrEmitter::HandleSort(HloInstruction*) {
|
||||
// TODO(b/26783907): Implement sort on GPU.
|
||||
return Unimplemented("sort");
|
||||
Status IrEmitter::HandleSort(HloInstruction* sort) {
|
||||
auto keys = sort->operand(0);
|
||||
auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
|
||||
if (values != nullptr) {
|
||||
// TODO(b/26783907): Also sort the values by their corresponding key.
|
||||
return Unimplemented("Key/Value Sort is not implemented on GPU");
|
||||
}
|
||||
int dimension_to_sort = sort->dimensions(0);
|
||||
const llvm_ir::IrArray& keys_array = GetIrArray(*keys, *sort);
|
||||
const llvm_ir::IrArray& target_array = GetIrArray(*sort, *sort);
|
||||
|
||||
const Shape& keys_shape = keys->shape();
|
||||
|
||||
// TODO(b/26783907): This case can probably be avoided with the Algebraic
|
||||
// Simplifier.
|
||||
if (ShapeUtil::IsScalar(keys_shape)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Create loop nests which loop through the operand dimensions. The sort
|
||||
// dimension is handled in three separate innermost loops which perform the
|
||||
// sorting.
|
||||
llvm_ir::ForLoopNest loop_nest(IrName(sort), &ir_builder_);
|
||||
llvm_ir::IrArray::Index keys_index = EmitOperandArrayLoopNest(
|
||||
keys_array, dimension_to_sort, "keys", &loop_nest);
|
||||
|
||||
// 'compare_keys_index' is the index of the element that 'keys_index' should
|
||||
// be compared to.
|
||||
llvm_ir::IrArray::Index compare_keys_index(keys_index.GetType());
|
||||
for (size_t dimension = 0; dimension < keys_index.size(); ++dimension) {
|
||||
if (dimension != dimension_to_sort) {
|
||||
compare_keys_index.push_back(keys_index[dimension]);
|
||||
} else {
|
||||
compare_keys_index.push_back(nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
// Create the sorting loops which do the sorting.
|
||||
int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
|
||||
std::unique_ptr<llvm_ir::ForLoop> stages_loop = loop_nest.AddLoop(
|
||||
/*start_index=*/0,
|
||||
/*end_index=*/
|
||||
tensorflow::Log2Ceiling64(dimension_to_sort_bound),
|
||||
/*suffix=*/"sort_stages");
|
||||
std::unique_ptr<llvm_ir::ForLoop> mask_loop = loop_nest.AddLoop(
|
||||
/*suffix=*/"mask",
|
||||
/*start_index=*/keys_index.GetConstantWithIndexType(0),
|
||||
/*end_index=*/stages_loop->GetIndVarValue());
|
||||
std::unique_ptr<llvm_ir::ForLoop> compare_loop = loop_nest.AddLoop(
|
||||
/*start_index=*/0,
|
||||
/*end_index=*/dimension_to_sort_bound,
|
||||
/*suffix=*/"compare");
|
||||
|
||||
// Naive C++ code for the inner loops (without parallelization):
|
||||
//
|
||||
// for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
|
||||
// ++stage) {
|
||||
// int64 first_xor_mask = (1LL << (stage + 1)) - 1;
|
||||
// for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
|
||||
// int64 j = i ^ first_xor_mask;
|
||||
// if (i < j && j < dimension_to_sort_bound) {
|
||||
// int64 min_key = std::min(keys[i], keys[j]);
|
||||
// keys[j] = std::max(keys[i], keys[j]);
|
||||
// keys[i] = min_key;
|
||||
// }
|
||||
// }
|
||||
// for (int64 mask = 0; mask < stage; ++mask) {
|
||||
// int64 later_xor_mask = (1LL << (stage - (mask + 1));
|
||||
// for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
|
||||
// int64 j = i ^ later_xor_mask;
|
||||
// if (i < j && j < dimension_to_sort_bound) {
|
||||
// int64 min_key = std::min(keys[i], keys[j]);
|
||||
// keys[j] = std::max(keys[i], keys[j]);
|
||||
// keys[i] = min_key;
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// This follows the algorithm described on Wikipedia:
|
||||
// https://en.wikipedia.org/wiki/Bitonic_sorter
|
||||
|
||||
SetToFirstInsertPoint(stages_loop->GetBodyBasicBlock(), &ir_builder_);
|
||||
// The first xor mask of a stage is 2^(stage + 1) - 1.
|
||||
auto first_xor_mask = ir_builder_.CreateSub(
|
||||
ir_builder_.CreateShl(
|
||||
keys_index.GetConstantWithIndexType(1),
|
||||
ir_builder_.CreateAdd(stages_loop->GetIndVarValue(),
|
||||
keys_index.GetConstantWithIndexType(1))),
|
||||
keys_index.GetConstantWithIndexType(1));
|
||||
std::unique_ptr<llvm_ir::ForLoop> first_compare_loop =
|
||||
llvm_ir::ForLoop::EmitForLoop(
|
||||
/*prefix=*/"first_compare",
|
||||
/*start_index=*/keys_index.GetConstantWithIndexType(0),
|
||||
/*end_index=*/
|
||||
keys_index.GetConstantWithIndexType(
|
||||
keys_shape.dimensions(dimension_to_sort)),
|
||||
/*step=*/keys_index.GetConstantWithIndexType(1),
|
||||
/*ir_builder=*/&ir_builder_);
|
||||
|
||||
SetToFirstInsertPoint(first_compare_loop->GetBodyBasicBlock(), &ir_builder_);
|
||||
// 'first_compare_loop' iterates through the 'dimension_to_sort'.
|
||||
keys_index[dimension_to_sort] = first_compare_loop->GetIndVarValue();
|
||||
compare_keys_index[dimension_to_sort] = ir_builder_.CreateXor(
|
||||
first_compare_loop->GetIndVarValue(), first_xor_mask);
|
||||
EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index,
|
||||
target_array);
|
||||
|
||||
SetToFirstInsertPoint(compare_loop->GetPreheaderBasicBlock(), &ir_builder_);
|
||||
// The later masks of a stage are 2^(stage - (mask_loop_ind_var + 1)).
|
||||
auto later_xor_mask = ir_builder_.CreateShl(
|
||||
keys_index.GetConstantWithIndexType(1),
|
||||
ir_builder_.CreateSub(
|
||||
stages_loop->GetIndVarValue(),
|
||||
ir_builder_.CreateAdd(mask_loop->GetIndVarValue(),
|
||||
keys_index.GetConstantWithIndexType(1))));
|
||||
|
||||
SetToFirstInsertPoint(compare_loop->GetBodyBasicBlock(), &ir_builder_);
|
||||
// 'compare_loop' iterates through the 'dimension_to_sort'.
|
||||
keys_index[dimension_to_sort] = compare_loop->GetIndVarValue();
|
||||
compare_keys_index[dimension_to_sort] =
|
||||
ir_builder_.CreateXor(compare_loop->GetIndVarValue(), later_xor_mask);
|
||||
EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index,
|
||||
target_array);
|
||||
|
||||
// Set the IR builder insert point to the exit basic block of the outer most
|
||||
// loop. This ensures later instructions are inserted after this loop nest.
|
||||
ir_builder_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status IrEmitter::HandleSend(HloInstruction*) {
|
||||
@ -399,6 +527,44 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void IrEmitter::EmitCompareLoop(
|
||||
int64 dimension_to_sort, const llvm_ir::IrArray::Index& keys_index,
|
||||
const llvm_ir::IrArray::Index& compare_keys_index,
|
||||
const llvm_ir::IrArray& keys_array) {
|
||||
// TODO(b/26783907): parallelize this loop.
|
||||
|
||||
// if (is_smaller_index &&
|
||||
// compare_keys[dimension_to_sort] < dimension_to_sort_bound)
|
||||
llvm::Value* is_smaller_index = ir_builder_.CreateICmpSLT(
|
||||
keys_index[dimension_to_sort], compare_keys_index[dimension_to_sort]);
|
||||
int64 dimension_to_sort_bound =
|
||||
keys_array.GetShape().dimensions(dimension_to_sort);
|
||||
auto if_data = llvm_ir::EmitIfThenElse(
|
||||
ir_builder_.CreateAnd(
|
||||
is_smaller_index,
|
||||
ir_builder_.CreateICmpSLT(
|
||||
compare_keys_index[dimension_to_sort],
|
||||
keys_index.GetConstantWithIndexType(dimension_to_sort_bound))),
|
||||
"smaller_comparison_index", &ir_builder_, /*emit_else=*/false);
|
||||
SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
|
||||
auto key1 = keys_array.EmitReadArrayElement(keys_index, &ir_builder_);
|
||||
auto key2 = keys_array.EmitReadArrayElement(compare_keys_index, &ir_builder_);
|
||||
auto key_type = keys_array.GetShape().element_type();
|
||||
auto comparison =
|
||||
primitive_util::IsFloatingPointType(key_type)
|
||||
// TODO(b/26783907): Figure out how to handle NaNs.
|
||||
? ir_builder_.CreateFCmp(llvm::FCmpInst::FCMP_ULT, key1, key2)
|
||||
: ir_builder_.CreateICmp(
|
||||
primitive_util::IsSignedIntegralType(key_type)
|
||||
? llvm::ICmpInst::ICMP_SLT
|
||||
: llvm::ICmpInst::ICMP_ULT,
|
||||
key1, key2);
|
||||
auto min_key = ir_builder_.CreateSelect(comparison, key1, key2);
|
||||
auto max_key = ir_builder_.CreateSelect(comparison, key2, key1);
|
||||
keys_array.EmitWriteArrayElement(keys_index, min_key, &ir_builder_);
|
||||
keys_array.EmitWriteArrayElement(compare_keys_index, max_key, &ir_builder_);
|
||||
}
|
||||
|
||||
Status IrEmitter::EmitAtomicOperationForNestedComputation(
|
||||
const HloComputation& computation, llvm::Value* output_address,
|
||||
llvm::Value* source_address) {
|
||||
|
@ -198,6 +198,13 @@ class IrEmitter : public DfsHloVisitorWithDefault {
|
||||
llvm::Value* output_address,
|
||||
llvm::Value* source_address);
|
||||
|
||||
// A helper method for HandleSort(). It adds the inner comparison loop where
|
||||
// we compare elements pointed to by 'keys_index' and 'compare_keys_index'.
|
||||
void EmitCompareLoop(int64 dimension_to_sort,
|
||||
const llvm_ir::IrArray::Index& keys_index,
|
||||
const llvm_ir::IrArray::Index& compare_keys_index,
|
||||
const llvm_ir::IrArray& keys_array);
|
||||
|
||||
StatusOr<llvm::Value*> ComputeNestedElement(
|
||||
const HloComputation& computation,
|
||||
tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements);
|
||||
|
@ -2046,6 +2046,35 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
|
||||
return IrEmitter::HandleSelect(select);
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
|
||||
std::vector<std::unique_ptr<Thunk>> thunks;
|
||||
auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
|
||||
if (values != nullptr) {
|
||||
// TODO(b/26783907): Also sort the values by their corresponding key.
|
||||
return Unimplemented("Key/Value Sort is not implemented on GPU");
|
||||
}
|
||||
|
||||
// First copy the operand to the output, so that we can sort in-place.
|
||||
// TODO(b/26783907): Share buffer of output and operand when it is possible.
|
||||
if (sort->operand(0)->IsConstant()) {
|
||||
thunks.push_back(MakeUnique<HostToDeviceCopyThunk>(
|
||||
/*source_address=*/sort->operand(0)->literal().untyped_data(),
|
||||
/*destination_buffer=*/GetAllocationSlice(*sort),
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort));
|
||||
} else {
|
||||
thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
|
||||
/*source_address=*/GetAllocationSlice(*sort->operand(0)),
|
||||
/*destination_buffer=*/GetAllocationSlice(*sort),
|
||||
/*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort));
|
||||
}
|
||||
|
||||
thunks.push_back(
|
||||
BuildKernelThunk(sort, /*implements_whole_instruction=*/false));
|
||||
thunk_sequence_->emplace_back(
|
||||
MakeUnique<SequentialThunk>(std::move(thunks), sort));
|
||||
return IrEmitter::HandleSort(sort);
|
||||
}
|
||||
|
||||
Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
|
||||
thunk_sequence_->push_back(
|
||||
BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true));
|
||||
|
@ -77,6 +77,7 @@ class IrEmitterUnnested : public IrEmitter {
|
||||
Status HandleOutfeed(HloInstruction* outfeed) override;
|
||||
Status HandleRng(HloInstruction* random) override;
|
||||
Status HandleSelect(HloInstruction* select) override;
|
||||
Status HandleSort(HloInstruction* sort) override;
|
||||
Status HandleTupleSelect(HloInstruction* tuple_select) override;
|
||||
Status HandleCrossReplicaSum(HloInstruction* crs) override;
|
||||
Status HandleAfterAll(HloInstruction* gen_token) override;
|
||||
|
Loading…
Reference in New Issue
Block a user