Parallelize BitonicSort on GPU.
We now emit O(log^2 n) kernel thunks. Each thunk is responsible for looping over the other dimensions, and then doing a comparison loop through the dimension that should be sorted.

PiperOrigin-RevId: 205791397
This commit is contained in:
parent fca1561b9d
commit 33035bb79b
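For illustration only (not part of this change; BitonicPass and BitonicSortPasses are hypothetical names), here is a host-side C++ sketch of what the emitted pass sequence computes, mirroring the naive C++ in the comments in the diff below. Each call to BitonicPass corresponds to one emitted kernel thunk, which does this comparison loop through the sort dimension (and additionally loops over the other dimensions).

#include <cstdint>
#include <utility>
#include <vector>

// One compare/swap pass over a 1-D key array: keys[i] is compared with
// keys[i ^ xor_mask] and the smaller key ends up at the lower index.
void BitonicPass(std::vector<int64_t>& keys, int64_t xor_mask) {
  const int64_t bound = static_cast<int64_t>(keys.size());
  for (int64_t i = 0; i < bound; ++i) {
    const int64_t j = i ^ xor_mask;
    if (i < j && j < bound && keys[j] < keys[i]) {
      std::swap(keys[i], keys[j]);
    }
  }
}

// Driver matching the stage/mask loops in HandleSort below: stage s issues
// one "first" pass with mask 2^(s+1) - 1 and s "later" passes with masks
// 2^(s-1), ..., 2^0, i.e. O(log^2 n) passes (= kernel thunks) in total.
void BitonicSortPasses(std::vector<int64_t>& keys) {
  int64_t num_stages = 0;  // Log2Ceiling(keys.size())
  while ((int64_t{1} << num_stages) < static_cast<int64_t>(keys.size())) {
    ++num_stages;
  }
  for (int64_t stage = 0; stage < num_stages; ++stage) {
    BitonicPass(keys, (int64_t{1} << (stage + 1)) - 1);
    for (int64_t mask = stage - 1; mask >= 0; --mask) {
      BitonicPass(keys, int64_t{1} << mask);
    }
  }
}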
@@ -37,7 +37,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -123,17 +122,6 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) {
  return Status::OK();
}

Status IrEmitter::HandleSort(HloInstruction* sort) {
  auto values = sort->operand_count() > 1 ? sort->operand(1) : nullptr;
  if (values != nullptr) {
    // TODO(b/26783907): Also sort the values by their corresponding key.
    return Unimplemented("Key/Value Sort is not implemented on GPU");
  }
  int dimension_to_sort = sort->dimensions(0);
  return llvm_ir::EmitSortInPlace(dimension_to_sort, GetIrArray(*sort, *sort),
                                  IrName(sort), &b_);
}

Status IrEmitter::HandleSend(HloInstruction*) {
  return Unimplemented("Send is not implemented on GPU");
}
@@ -79,7 +79,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
  Status HandleCrossReplicaSum(HloInstruction* crs) override;
  Status HandleInfeed(HloInstruction* infeed) override;
  Status HandleOutfeed(HloInstruction* outfeed) override;
  Status HandleSort(HloInstruction* sort) override;
  Status HandleSend(HloInstruction* send) override;
  Status HandleSendDone(HloInstruction* send_done) override;
  Status HandleRecv(HloInstruction* recv) override;
@@ -63,6 +63,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -71,6 +72,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h"
@@ -2036,11 +2038,51 @@ Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
        /*mem_size=*/ShapeUtil::ByteSizeOf(sort->shape()), sort));
  }

  thunks.push_back(
      BuildKernelThunk(sort, /*implements_whole_instruction=*/false));
  int64 dimension_to_sort = sort->dimensions(0);
  int64 dimension_to_sort_bound = sort->shape().dimensions(dimension_to_sort);
  int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound);
  auto index_type = b_.getInt64Ty();

  // Naive C++ code for the outer loops:
  //
  // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
  //     ++stage) {
  //   int64 first_xor_mask = (1LL << (stage + 1)) - 1;
  //   SortInPlace(first_xor_mask);
  //   for (int64 mask = stage - 1; mask >= 0; --mask) {
  //     int64 later_xor_mask = 1LL << mask;
  //     SortInPlace(later_xor_mask);
  //   }
  // }
  //
  // This follows the algorithm described on Wikipedia:
  // https://en.wikipedia.org/wiki/Bitonic_sorter

  for (int64 stage = 0; stage < num_stages; ++stage) {
    for (int64 mask = stage; mask >= 0; --mask) {
      thunks.push_back(
          BuildKernelThunk(sort, /*implements_whole_instruction=*/false));
      LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
          sort->shape(), ir_emitter_context_->device_description());
      UpdateLaunchDimensions(launch_dimensions, thunks.back().get(),
                             ir_emitter_context_->llvm_module());

      llvm::Value* xor_mask;
      if (mask == stage) {
        xor_mask = llvm::ConstantInt::get(index_type, (1LL << (stage + 1)) - 1);
      } else {
        xor_mask = llvm::ConstantInt::get(index_type, 1LL << mask);
      }

      TF_RETURN_IF_ERROR(llvm_ir::EmitSortInPlace(
          dimension_to_sort, GetIrArray(*sort, *sort), IrName(sort), xor_mask,
          &b_, &launch_dimensions));
    }
  }

  thunk_sequence_->emplace_back(
      MakeUnique<SequentialThunk>(std::move(thunks), sort));
  return IrEmitter::HandleSort(sort);
  return Status::OK();
}

Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
@@ -188,7 +188,10 @@ cc_library(
        ":ir_array",
        ":llvm_loop",
        ":llvm_util",
        ":loop_emitter",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter",
        "//tensorflow/compiler/xla/service/gpu:partition_assignment",
        "//tensorflow/core:lib",
        "@llvm//:core",
    ],
@@ -19,12 +19,15 @@ limitations under the License.
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
@@ -73,7 +76,9 @@ void EmitCompareLoop(int64 dimension_to_sort,
}  // namespace

Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
                       tensorflow::StringPiece name, llvm::IRBuilder<>* b) {
                       tensorflow::StringPiece name, llvm::Value* xor_mask,
                       llvm::IRBuilder<>* b,
                       const gpu::LaunchDimensions* launch_dimensions) {
  const Shape& keys_shape = keys_array.GetShape();

  // TODO(b/26783907): This case can probably be avoided with the Algebraic
@@ -83,11 +88,13 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
  }

  // Create loop nests which loop through the operand dimensions. The sort
  // dimension is handled in three separate innermost loops which perform the
  // sorting.
  // dimension is handled in the innermost loop which performs the sorting.
  ForLoopNest loop_nest(name, b);
  IrArray::Index keys_index =
      loop_nest.EmitOperandArrayLoopNest(keys_array, dimension_to_sort, "keys");
  if (loop_nest.GetInnerLoopBodyBasicBlock() != nullptr) {
    SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), b);
  }

  // 'compare_keys_index' is the index of the element that 'keys_index' should
  // be compared to.
@@ -100,89 +107,42 @@ Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
    }
  }

  // Create the sorting loops which do the sorting.
  int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
  std::unique_ptr<ForLoop> stages_loop = loop_nest.AddLoop(
      /*start_index=*/0,
      /*end_index=*/
      tensorflow::Log2Ceiling64(dimension_to_sort_bound),
      /*suffix=*/"sort_stages");
  std::unique_ptr<ForLoop> mask_loop = loop_nest.AddLoop(
      /*suffix=*/"mask",
      /*start_index=*/keys_index.GetConstantWithIndexType(0),
      /*end_index=*/stages_loop->GetIndVarValue());
  std::unique_ptr<ForLoop> compare_loop = loop_nest.AddLoop(
      /*start_index=*/0,
      /*end_index=*/dimension_to_sort_bound,
      /*suffix=*/"compare");

  // Naive C++ code for the inner loops (without parallelization):
  // Naive C++ code for the inner compare loop:
  //
  // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
  //     ++stage) {
  //   int64 first_xor_mask = (1LL << (stage + 1)) - 1;
  //   for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
  //     int64 j = i ^ first_xor_mask;
  //     if (i < j && j < dimension_to_sort_bound) {
  //       int64 min_key = std::min(keys[i], keys[j]);
  //       keys[j] = std::max(keys[i], keys[j]);
  //       keys[i] = min_key;
  //     }
  //   }
  //   for (int64 mask = 0; mask < stage; ++mask) {
  //     int64 later_xor_mask = (1LL << (stage - (mask + 1)));
  //     for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
  //       int64 j = i ^ later_xor_mask;
  //       if (i < j && j < dimension_to_sort_bound) {
  //         int64 min_key = std::min(keys[i], keys[j]);
  //         keys[j] = std::max(keys[i], keys[j]);
  //         keys[i] = min_key;
  //       }
  //     }
  // for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
  //   int64 j = i ^ xor_mask;
  //   if (i < j && j < dimension_to_sort_bound) {
  //     int64 min_key = std::min(keys[i], keys[j]);
  //     keys[j] = std::max(keys[i], keys[j]);
  //     keys[i] = min_key;
  //   }
  // }
  //
  // This follows the algorithm described on Wikipedia:
  // https://en.wikipedia.org/wiki/Bitonic_sorter

  SetToFirstInsertPoint(stages_loop->GetBodyBasicBlock(), b);
  // The first xor mask of a stage is 2^(stage + 1) - 1.
  auto first_xor_mask = b->CreateSub(
      b->CreateShl(keys_index.GetConstantWithIndexType(1),
                   b->CreateAdd(stages_loop->GetIndVarValue(),
                                keys_index.GetConstantWithIndexType(1))),
      keys_index.GetConstantWithIndexType(1));
  std::unique_ptr<ForLoop> first_compare_loop = ForLoop::EmitForLoop(
      /*prefix=*/"first_compare",
      /*start_index=*/keys_index.GetConstantWithIndexType(0),
      /*end_index=*/
      keys_index.GetConstantWithIndexType(dimension_to_sort_bound),
      /*step=*/keys_index.GetConstantWithIndexType(1),
      /*b=*/b);

  SetToFirstInsertPoint(first_compare_loop->GetBodyBasicBlock(), b);
  // 'first_compare_loop' iterates through the 'dimension_to_sort'.
  keys_index[dimension_to_sort] = first_compare_loop->GetIndVarValue();
  compare_keys_index[dimension_to_sort] =
      b->CreateXor(first_compare_loop->GetIndVarValue(), first_xor_mask);
  EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index, keys_array,
                  b);

  SetToFirstInsertPoint(compare_loop->GetPreheaderBasicBlock(), b);
  // The later masks of a stage are 2^(stage - (mask_loop_ind_var + 1)).
  auto later_xor_mask = b->CreateShl(
      keys_index.GetConstantWithIndexType(1),
      b->CreateSub(stages_loop->GetIndVarValue(),
                   b->CreateAdd(mask_loop->GetIndVarValue(),
                                keys_index.GetConstantWithIndexType(1))));

  SetToFirstInsertPoint(compare_loop->GetBodyBasicBlock(), b);
  // 'compare_loop' iterates through the 'dimension_to_sort'.
  keys_index[dimension_to_sort] = compare_loop->GetIndVarValue();
  compare_keys_index[dimension_to_sort] =
      b->CreateXor(compare_loop->GetIndVarValue(), later_xor_mask);
  EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index, keys_array,
                  b);
  int64 dimension_to_sort_bound =
      keys_array.GetShape().dimensions(dimension_to_sort);
  Shape compare_shape = ShapeUtil::MakeShape(keys_shape.element_type(),
                                             {dimension_to_sort_bound});
  auto compare_loop_body_emitter =
      [&](const IrArray::Index& compare_index) -> Status {
    keys_index[dimension_to_sort] = compare_index[0];
    compare_keys_index[dimension_to_sort] =
        b->CreateXor(compare_index[0], xor_mask);
    EmitCompareLoop(dimension_to_sort, keys_index, compare_keys_index,
                    keys_array, b);
    return Status::OK();
  };
  if (launch_dimensions != nullptr) {
    TF_RETURN_IF_ERROR(gpu::ParallelLoopEmitter(compare_loop_body_emitter,
                                                compare_shape,
                                                *launch_dimensions, b)
                           .EmitLoop(name));
  } else {
    TF_RETURN_IF_ERROR(LoopEmitter(compare_loop_body_emitter, compare_shape, b)
                           .EmitLoop(name));
  }

  // Set the IR builder insert point to the exit basic block of the outer most
  // loop. This ensures later instructions are inserted after this loop nest.
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_SORT_UTIL_H_

#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -23,10 +25,14 @@ limitations under the License.

namespace xla {
namespace llvm_ir {
// Emits llvm IR to sort the 'dimension_to_sort' dimension of 'keys_array' into
// ascending order.
// Emits llvm IR to do pairwise comparisons/swaps in the 'dimension_to_sort'
// dimension of 'keys_array'. All other dimensions are kept as-is. This
// implements the inner loop of BitonicSort. If 'launch_dimensions' is nullptr,
// the inner compare loop will not be parallelized.
Status EmitSortInPlace(int64 dimension_to_sort, const IrArray& keys_array,
                       tensorflow::StringPiece name, llvm::IRBuilder<>* b);
                       tensorflow::StringPiece name, llvm::Value* xor_mask,
                       llvm::IRBuilder<>* b,
                       const gpu::LaunchDimensions* launch_dimensions);
}  // namespace llvm_ir
}  // namespace xla
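For intuition about the new contract described in the header comment above: a single EmitSortInPlace call now emits one compare/swap pass for one fixed xor_mask, and the caller (HandleSort) drives the full sequence of masks, one kernel thunk per pass. The sketch below is an illustration only, a hypothetical plain-C++ helper on a row-major 2-D array rather than anything in the XLA API; it shows what one such pass does and how every dimension other than the sort dimension is left untouched.

#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical stand-in for one emitted pass, modeled on a row-major
// [num_rows, row_length] array with dimension_to_sort == 1: the row index
// (i.e. every other dimension) is kept as-is, and within the sort dimension
// element i is compare/swapped with element i ^ xor_mask.
void CompareSwapPass(std::vector<int64_t>& data, int64_t num_rows,
                     int64_t row_length, int64_t xor_mask) {
  for (int64_t row = 0; row < num_rows; ++row) {  // loop over other dimensions
    int64_t* keys = data.data() + row * row_length;
    for (int64_t i = 0; i < row_length; ++i) {    // compare loop; this is the
      const int64_t j = i ^ xor_mask;             // part parallelized on GPU
      if (i < j && j < row_length && keys[j] < keys[i]) {
        std::swap(keys[i], keys[j]);
      }
    }
  }
}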