Implement BitonicSort for GPU.
This is a first version; several things are still missing:
- Support for key/value sorting.
- Support for types other than F32, S32 and U32.
- Parallelization of the inner loop.

PiperOrigin-RevId: 205052657
parent 3a576d3a28
commit b74f7b71fa
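For readers new to the algorithm, the following standalone C++ sketch mirrors the "naive C++ code" documented in the comments of IrEmitter::HandleSort below; it is an illustration, not code from this commit. The `j < bound` guard is what lets the network handle non-power-of-two sizes: out-of-range partners behave like virtual +infinity padding at the tail, which every compare-exchange leaves in place.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

namespace {

// Smallest log such that 2^log >= n (stand-in for tensorflow::Log2Ceiling64).
int64_t Log2Ceiling(int64_t n) {
  int64_t log = 0;
  while ((int64_t{1} << log) < n) ++log;
  return log;
}

// Naive (unparallelized) form of the bitonic sorting network the emitter
// builds. Each compare-exchange puts the minimum at the lower index.
void BitonicSort(std::vector<float>& keys) {
  const int64_t bound = keys.size();
  for (int64_t stage = 0; stage < Log2Ceiling(bound); ++stage) {
    // First pass of a stage: xor with 2^(stage + 1) - 1 pairs each index
    // with its mirror inside a block of size 2^(stage + 1).
    const int64_t first_xor_mask = (int64_t{1} << (stage + 1)) - 1;
    for (int64_t i = 0; i < bound; ++i) {
      const int64_t j = i ^ first_xor_mask;
      if (i < j && j < bound && keys[j] < keys[i]) std::swap(keys[i], keys[j]);
    }
    // Later passes: compare-exchange at shrinking power-of-two distances.
    for (int64_t mask = 0; mask < stage; ++mask) {
      const int64_t later_xor_mask = int64_t{1} << (stage - (mask + 1));
      for (int64_t i = 0; i < bound; ++i) {
        const int64_t j = i ^ later_xor_mask;
        if (i < j && j < bound && keys[j] < keys[i]) {
          std::swap(keys[i], keys[j]);
        }
      }
    }
  }
}

}  // namespace

int main() {
  std::vector<float> keys = {5.0f, 1.0f, 4.0f, 2.0f, 3.0f};  // odd length
  BitonicSort(keys);
  for (float k : keys) std::cout << k << " ";  // prints: 1 2 3 4 5
  std::cout << "\n";
}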
@@ -44,6 +44,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/util.h"
 #include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/core/lib/core/bits.h"
 #include "tensorflow/core/lib/core/errors.h"
 
 namespace xla {
@@ -123,9 +124,136 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
   return Status::OK();
 }
 
-Status IrEmitter::HandleSort(HloInstruction*) {
-  // TODO(b/26783907): Implement sort on GPU.
-  return Unimplemented("sort");
+Status IrEmitter::HandleSort(HloInstruction* sort) {
+  auto keys = sort->operand(0);
+  auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
+  if (values != nullptr) {
+    // TODO(b/26783907): Also sort the values by their corresponding key.
+    return Unimplemented("Key/Value Sort is not implemented on GPU");
+  }
+  int dimension_to_sort = sort->dimensions(0);
+  const llvm_ir::IrArray& keys_array = GetIrArray(*keys, *sort);
+  const llvm_ir::IrArray& target_array = GetIrArray(*sort, *sort);
+
+  const Shape& keys_shape = keys->shape();
+
+  // TODO(b/26783907): This case can probably be avoided with the Algebraic
+  // Simplifier.
+  if (ShapeUtil::IsScalar(keys_shape)) {
+    return Status::OK();
+  }
+
+  // Create loop nests which loop through the operand dimensions. The sort
+  // dimension is handled in three separate innermost loops which perform the
+  // sorting.
+  llvm_ir::ForLoopNest loop_nest(IrName(sort), &ir_builder_);
+  llvm_ir::IrArray::Index keys_index = EmitOperandArrayLoopNest(
+      keys_array, dimension_to_sort, "keys", &loop_nest);
+
+  // 'compare_keys_index' is the index of the element that 'keys_index' should
+  // be compared to.
+  llvm_ir::IrArray::Index compare_keys_index(keys_index.GetType());
+  for (size_t dimension = 0; dimension < keys_index.size(); ++dimension) {
+    if (dimension != dimension_to_sort) {
+      compare_keys_index.push_back(keys_index[dimension]);
+    } else {
+      compare_keys_index.push_back(nullptr);
+    }
+  }
+
+  // Create the sorting loops which do the sorting.
+  int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
+  std::unique_ptr<llvm_ir::ForLoop> stages_loop = loop_nest.AddLoop(
+      /*start_index=*/0,
+      /*end_index=*/
+      tensorflow::Log2Ceiling64(dimension_to_sort_bound),
+      /*suffix=*/"sort_stages");
+  std::unique_ptr<llvm_ir::ForLoop> mask_loop = loop_nest.AddLoop(
+      /*suffix=*/"mask",
+      /*start_index=*/keys_index.GetConstantWithIndexType(0),
+      /*end_index=*/stages_loop->GetIndVarValue());
+  std::unique_ptr<llvm_ir::ForLoop> compare_loop = loop_nest.AddLoop(
+      /*start_index=*/0,
+      /*end_index=*/dimension_to_sort_bound,
+      /*suffix=*/"compare");
+
+  // Naive C++ code for the inner loops (without parallelization):
+  //
+  // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
+  //      ++stage) {
+  //   int64 first_xor_mask = (1LL << (stage + 1)) - 1;
+  //   for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
+  //     int64 j = i ^ first_xor_mask;
+  //     if (i < j && j < dimension_to_sort_bound) {
+  //       int64 min_key = std::min(keys[i], keys[j]);
+  //       keys[j] = std::max(keys[i], keys[j]);
+  //       keys[i] = min_key;
+  //     }
+  //   }
+  //   for (int64 mask = 0; mask < stage; ++mask) {
+  //     int64 later_xor_mask = 1LL << (stage - (mask + 1));
+  //     for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
+  //       int64 j = i ^ later_xor_mask;
+  //       if (i < j && j < dimension_to_sort_bound) {
+  //         int64 min_key = std::min(keys[i], keys[j]);
+  //         keys[j] = std::max(keys[i], keys[j]);
+  //         keys[i] = min_key;
+  //       }
+  //     }
+  //   }
+  // }
+  //
+  // This follows the algorithm described on Wikipedia:
+  // https://en.wikipedia.org/wiki/Bitonic_sorter
+
+  SetToFirstInsertPoint(stages_loop->GetBodyBasicBlock(), &ir_builder_);
+  // The first xor mask of a stage is 2^(stage + 1) - 1.
+  auto first_xor_mask = ir_builder_.CreateSub(
+      ir_builder_.CreateShl(
+          keys_index.GetConstantWithIndexType(1),
+          ir_builder_.CreateAdd(stages_loop->GetIndVarValue(),
+                                keys_index.GetConstantWithIndexType(1))),
+      keys_index.GetConstantWithIndexType(1));
+  std::unique_ptr<llvm_ir::ForLoop> first_compare_loop =
+      llvm_ir::ForLoop::EmitForLoop(
+          /*prefix=*/"first_compare",
+          /*start_index=*/keys_index.GetConstantWithIndexType(0),
+          /*end_index=*/
+          keys_index.GetConstantWithIndexType(
+              keys_shape.dimensions(dimension_to_sort)),
+          /*step=*/keys_index.GetConstantWithIndexType(1),
+          /*ir_builder=*/&ir_builder_);
+
+  SetToFirstInsertPoint(first_compare_loop->GetBodyBasicBlock(), &ir_builder_);
+  // 'first_compare_loop' iterates through the 'dimension_to_sort'.
+  keys_index[dimension_to_sort] = first_compare_loop->GetIndVarValue();
+  compare_keys_index[dimension_to_sort] = ir_builder_.CreateXor(
+      first_compare_loop->GetIndVarValue(), first_xor_mask);
+  EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index,
+                  target_array);
+
+  SetToFirstInsertPoint(compare_loop->GetPreheaderBasicBlock(), &ir_builder_);
+  // The later masks of a stage are 2^(stage - (mask_loop_ind_var + 1)).
+  auto later_xor_mask = ir_builder_.CreateShl(
+      keys_index.GetConstantWithIndexType(1),
+      ir_builder_.CreateSub(
+          stages_loop->GetIndVarValue(),
+          ir_builder_.CreateAdd(mask_loop->GetIndVarValue(),
+                                keys_index.GetConstantWithIndexType(1))));
+
+  SetToFirstInsertPoint(compare_loop->GetBodyBasicBlock(), &ir_builder_);
+  // 'compare_loop' iterates through the 'dimension_to_sort'.
+  keys_index[dimension_to_sort] = compare_loop->GetIndVarValue();
+  compare_keys_index[dimension_to_sort] =
+      ir_builder_.CreateXor(compare_loop->GetIndVarValue(), later_xor_mask);
+  EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index,
+                  target_array);
+
+  // Set the IR builder insert point to the exit basic block of the outer most
+  // loop. This ensures later instructions are inserted after this loop nest.
+  ir_builder_.SetInsertPoint(loop_nest.GetOuterLoopExitBasicBlock());
+
+  return Status::OK();
 }
 
 Status IrEmitter::HandleSend(HloInstruction*) {
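The IR above just evaluates the two mask formulas from the naive code (`2^(stage + 1) - 1` for the first pass, `2^(stage - (mask + 1))` for later passes) with shifts. As a worked example, the sketch below traces the resulting compare-exchange schedule for a sort dimension of size 8; it is illustrative only, not code from this commit.

#include <cstdint>
#include <cstdio>

int main() {
  const int64_t bound = 8;
  int64_t stages = 0;
  while ((int64_t{1} << stages) < bound) ++stages;  // Log2Ceiling64(8) == 3
  for (int64_t stage = 0; stage < stages; ++stage) {
    // First pass: xor with 2^(stage + 1) - 1 reverses blocks of 2^(stage + 1).
    std::printf("stage %lld: first_xor_mask = %lld, later masks:",
                (long long)stage,
                (long long)((int64_t{1} << (stage + 1)) - 1));
    // Later passes: single-bit masks 2^(stage - (mask + 1)), halving each time.
    for (int64_t mask = 0; mask < stage; ++mask) {
      std::printf(" %lld", (long long)(int64_t{1} << (stage - (mask + 1))));
    }
    std::printf("\n");
  }
  // Output:
  // stage 0: first_xor_mask = 1, later masks:
  // stage 1: first_xor_mask = 3, later masks: 1
  // stage 2: first_xor_mask = 7, later masks: 2 1
}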
@@ -399,6 +527,44 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
   return Status::OK();
 }
 
+void IrEmitter::EmitCompareLoop(
+    int64 dimension_to_sort, const llvm_ir::IrArray::Index& keys_index,
+    const llvm_ir::IrArray::Index& compare_keys_index,
+    const llvm_ir::IrArray& keys_array) {
+  // TODO(b/26783907): parallelize this loop.
+
+  // if (is_smaller_index &&
+  //     compare_keys[dimension_to_sort] < dimension_to_sort_bound)
+  llvm::Value* is_smaller_index = ir_builder_.CreateICmpSLT(
+      keys_index[dimension_to_sort], compare_keys_index[dimension_to_sort]);
+  int64 dimension_to_sort_bound =
+      keys_array.GetShape().dimensions(dimension_to_sort);
+  auto if_data = llvm_ir::EmitIfThenElse(
+      ir_builder_.CreateAnd(
+          is_smaller_index,
+          ir_builder_.CreateICmpSLT(
+              compare_keys_index[dimension_to_sort],
+              keys_index.GetConstantWithIndexType(dimension_to_sort_bound))),
+      "smaller_comparison_index", &ir_builder_, /*emit_else=*/false);
+  SetToFirstInsertPoint(if_data.true_block, &ir_builder_);
+  auto key1 = keys_array.EmitReadArrayElement(keys_index, &ir_builder_);
+  auto key2 = keys_array.EmitReadArrayElement(compare_keys_index, &ir_builder_);
+  auto key_type = keys_array.GetShape().element_type();
+  auto comparison =
+      primitive_util::IsFloatingPointType(key_type)
+          // TODO(b/26783907): Figure out how to handle NaNs.
+          ? ir_builder_.CreateFCmp(llvm::FCmpInst::FCMP_ULT, key1, key2)
+          : ir_builder_.CreateICmp(
+                primitive_util::IsSignedIntegralType(key_type)
+                    ? llvm::ICmpInst::ICMP_SLT
+                    : llvm::ICmpInst::ICMP_ULT,
+                key1, key2);
+  auto min_key = ir_builder_.CreateSelect(comparison, key1, key2);
+  auto max_key = ir_builder_.CreateSelect(comparison, key2, key1);
+  keys_array.EmitWriteArrayElement(keys_index, min_key, &ir_builder_);
+  keys_array.EmitWriteArrayElement(compare_keys_index, max_key, &ir_builder_);
+}
+
 Status IrEmitter::EmitAtomicOperationForNestedComputation(
     const HloComputation& computation, llvm::Value* output_address,
     llvm::Value* source_address) {
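A note on the NaN TODO above: FCMP_ULT is LLVM's *unordered* less-than, which is true whenever either operand is NaN. The compare-exchange therefore leaves the pair as (key1, key2) whenever a NaN is involved, regardless of which operand the NaN is, so NaNs are not placed consistently and the network does not totally order inputs containing them. The C++ sketch below models this select pair; it is an illustration of the emitted behavior, not code from this commit.

#include <cmath>
#include <cstdio>
#include <utility>

// Model of the emitted compare-exchange for floating-point keys.
std::pair<float, float> CompareExchange(float key1, float key2) {
  // FCMP_ULT: true if either operand is NaN (unordered) or key1 < key2.
  bool comparison = std::isnan(key1) || std::isnan(key2) || key1 < key2;
  float min_key = comparison ? key1 : key2;
  float max_key = comparison ? key2 : key1;
  return {min_key, max_key};
}

int main() {
  // NaN as key1 ends up in the "min" slot...
  auto [lo1, hi1] = CompareExchange(std::nanf(""), 1.0f);
  std::printf("%f %f\n", lo1, hi1);  // nan 1.0
  // ...but NaN as key2 ends up in the "max" slot: no consistent order.
  auto [lo2, hi2] = CompareExchange(1.0f, std::nanf(""));
  std::printf("%f %f\n", lo2, hi2);  // 1.0 nan
}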
@@ -198,6 +198,13 @@ class IrEmitter : public DfsHloVisitorWithDefault {
                                                  llvm::Value* output_address,
                                                  llvm::Value* source_address);
 
+  // A helper method for HandleSort(). It adds the inner comparison loop where
+  // we compare elements pointed to by 'keys_index' and 'compare_keys_index'.
+  void EmitCompareLoop(int64 dimension_to_sort,
+                       const llvm_ir::IrArray::Index& keys_index,
+                       const llvm_ir::IrArray::Index& compare_keys_index,
+                       const llvm_ir::IrArray& keys_array);
+
   StatusOr<llvm::Value*> ComputeNestedElement(
       const HloComputation& computation,
       tensorflow::gtl::ArraySlice<llvm::Value*> parameter_elements);
@@ -2046,6 +2046,35 @@ Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
   return IrEmitter::HandleSelect(select);
 }
 
+Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
+  std::vector<std::unique_ptr<Thunk>> thunks;
+  auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
+  if (values != nullptr) {
+    // TODO(b/26783907): Also sort the values by their corresponding key.
+    return Unimplemented("Key/Value Sort is not implemented on GPU");
+  }
+
+  // First copy the operand to the output, so that we can sort in-place.
+  // TODO(b/26783907): Share buffer of output and operand when it is possible.
+  if (sort->operand(0)->IsConstant()) {
+    thunks.push_back(MakeUnique<HostToDeviceCopyThunk>(
+        /*source_address=*/sort->operand(0)->literal().untyped_data(),
+        /*destination_buffer=*/GetAllocationSlice(*sort),
+        /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort));
+  } else {
+    thunks.push_back(MakeUnique<DeviceToDeviceCopyThunk>(
+        /*source_address=*/GetAllocationSlice(*sort->operand(0)),
+        /*destination_buffer=*/GetAllocationSlice(*sort),
+        /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort));
+  }
+
+  thunks.push_back(
+      BuildKernelThunk(sort, /*implements_whole_instruction=*/false));
+  thunk_sequence_->emplace_back(
+      MakeUnique<SequentialThunk>(std::move(thunks), sort));
+  return IrEmitter::HandleSort(sort);
+}
+
 Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
   thunk_sequence_->push_back(
       BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true));
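The SequentialThunk here runs the copy thunk before the sort kernel, which is what makes the in-place loop nest emitted by IrEmitter::HandleSort safe: the kernel only ever reads and writes the output buffer, never the operand. A minimal runnable sketch of that ordering follows; CopyToOutput and SortKernel are hypothetical stand-ins for the thunks, not APIs from this commit.

#include <algorithm>
#include <cstdio>
#include <vector>

// Stand-in for the copy thunk (HostToDevice or DeviceToDevice above).
void CopyToOutput(const std::vector<float>& in, std::vector<float>& out) {
  out = in;
}

// Stand-in for the emitted bitonic-sort kernel; sorts its buffer in place.
void SortKernel(std::vector<float>& buffer) {
  std::sort(buffer.begin(), buffer.end());
}

int main() {
  std::vector<float> operand = {3.0f, 1.0f, 2.0f};
  std::vector<float> output;
  // SequentialThunk guarantees this ordering: copy first, then the kernel,
  // so the operand is never written and the sort can work in place.
  CopyToOutput(operand, output);
  SortKernel(output);
  std::printf("%f %f %f\n", output[0], output[1], output[2]);  // 1 2 3
}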
@@ -77,6 +77,7 @@ class IrEmitterUnnested : public IrEmitter {
   Status HandleOutfeed(HloInstruction* outfeed) override;
   Status HandleRng(HloInstruction* random) override;
   Status HandleSelect(HloInstruction* select) override;
+  Status HandleSort(HloInstruction* sort) override;
   Status HandleTupleSelect(HloInstruction* tuple_select) override;
   Status HandleCrossReplicaSum(HloInstruction* crs) override;
   Status HandleAfterAll(HloInstruction* gen_token) override;