[XLA] Speed up. Make XLA faster by making the PW kernel use the right number of blocks and loops.
parent c257a5d210
commit 2cd7d60b98
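In outline, the change stops sizing pointwise (PW) kernel launches purely from the element count. The launch dimensions are instead derived from the device (core count and threads-per-core limit, with the block size capped at 128 threads), and the kernel gains an internal loop so a fixed number of threads can cover any number of elements. The sketch below is a minimal conceptual illustration of the resulting per-thread work, assuming unroll_factor = 1; it is plain C++ written for this note, not code from the commit.

    #include <cstdint>
    #include <cstdio>

    // Conceptual sketch: after the change, each GPU thread starts at its global
    // linear index and strides by the total number of launched threads, instead
    // of handling exactly one element.
    void PerThreadWork(int64_t global_thread_id, int64_t total_threads,
                       int64_t num_elements) {
      for (int64_t i = global_thread_id; i < num_elements; i += total_threads) {
        // The element-wise body (body_emitter_ in the diff) runs for element i.
      }
    }

    int main() {
      // Assumed toy launch: 4 threads covering 10 elements.
      for (int64_t tid = 0; tid < 4; ++tid) {
        PerThreadWork(tid, /*total_threads=*/4, /*num_elements=*/10);
      }
      std::printf("each thread visited at most ceil(10/4) = 3 elements\n");
    }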
@@ -31,9 +31,15 @@ ParallelLoopEmitter::ParallelLoopEmitter(

 std::vector<llvm_ir::IrArray::Index>
 ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
-                                                   llvm::Type* index_type) {
+                                                   llvm::Type* index_type,
+                                                   llvm::Value* base_index) {
   CHECK_NE(index_type, nullptr);
+  CHECK_EQ(base_index, nullptr)
+      << "XLA CPU implementation of"
+      << " ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock doesn't support"
+      << " base_index, but it was requested.";
+
   CHECK(!shape_.IsTuple());
   CHECK(!ShapeUtil::IsScalar(shape_));

@@ -61,7 +61,8 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
   ~ParallelLoopEmitter() override = default;

   std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock(
-      absl::string_view loop_name, llvm::Type* index_type) override;
+      absl::string_view loop_name, llvm::Type* index_type,
+      llvm::Value* base_index = nullptr) override;

  private:
   const DynamicLoopBounds* dynamic_loop_bounds_;

@@ -311,6 +311,7 @@ cc_library(
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+        "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library",
         "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
         "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
         "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",

@@ -610,6 +610,12 @@ static StatusOr<bool> DeviceCompare(se::Stream* stream,
       executor->GetDeviceDescription().threads_per_block_limit();
   gpu_device_info.threads_per_warp =
       executor->GetDeviceDescription().threads_per_warp();
   gpu_device_info.shared_memory_per_block =
       executor->GetDeviceDescription().shared_memory_per_block();
+  gpu_device_info.threads_per_core_limit =
+      executor->GetDeviceDescription().threads_per_core_limit();
+  gpu_device_info.core_count =
+      executor->GetDeviceDescription().core_count();
   LaunchDimensions dim =
       CalculateLaunchDimensions(buffer_shape, gpu_device_info);

@@ -611,6 +611,10 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
       stream_exec->GetDeviceDescription().threads_per_warp();
   gpu_device_info.shared_memory_per_block =
       stream_exec->GetDeviceDescription().shared_memory_per_block();
+  gpu_device_info.threads_per_core_limit =
+      stream_exec->GetDeviceDescription().threads_per_core_limit();
+  gpu_device_info.core_count =
+      stream_exec->GetDeviceDescription().core_count();

   absl::optional<CudaComputeCapability> cuda_compute_capability =
       [&]() -> absl::optional<CudaComputeCapability> {

@@ -32,6 +32,8 @@ struct GpuDeviceInfo {
   int threads_per_block_limit;
   int threads_per_warp;
   int shared_memory_per_block;
+  int threads_per_core_limit;
+  int core_count;
 };
 }  // namespace gpu
 }  // namespace xla

@@ -87,6 +87,9 @@ LaunchDimensions CalculateLaunchDimensions(const Shape& shape,
   }

   int64 block_count = CeilOfRatio(num_elements, threads_per_block);
+  threads_per_block = std::min(threads_per_block, 128LL);
+  block_count = gpu_device_info.core_count *
+                (gpu_device_info.threads_per_core_limit / threads_per_block);
   VLOG(2) << absl::StrFormat(
       "Initialized the block count to ceil(# of elements / threads per "
       "block) = ceil(%d/%d) = %d",

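A worked example of the new formula, with assumed (roughly V100-like) device limits rather than values from the commit: capping threads_per_block at 128 and taking core_count = 80 and threads_per_core_limit = 2048 gives block_count = 80 * (2048 / 128) = 1280 blocks, i.e. 163,840 threads in flight, independent of the element count. Whatever those threads cannot cover in one pass is picked up by the loop added inside the kernel in the parallel_loop_emitter hunks below.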
@@ -23,6 +23,7 @@ limitations under the License.
 #include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/Value.h"
 #include "tensorflow/compiler/xla/service/gpu/target_util.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"

@@ -58,7 +59,8 @@ ParallelLoopEmitter::ParallelLoopEmitter(

 std::vector<llvm_ir::IrArray::Index>
 ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
-                                                   llvm::Type* index_type) {
+                                                   llvm::Type* index_type,
+                                                   llvm::Value* base_index) {
   // Emit the following code in LLVM IR:
   //   linear_index = blockIdx.x * blockDim.x + threadIdx.x;
   //   if (linear_index < num_elements) {

@@ -121,6 +123,12 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
         "linear_index_base", /*HasNUW=*/true, /*HasNSW=*/true);
   }

+  if (base_index != nullptr) {
+    linear_index_base = b_->CreateAdd(linear_index_base, base_index,
+                                      "linear_index_plus_base",
+                                      /*HasNUW=*/true, /*HasNSW=*/true);
+  }
+
   array_indices.emplace_back(linear_index_base, shape_, b_);
   for (int i = 1; i < unroll_factor_; ++i) {
     llvm::Value* linear_index =

|
||||
return array_indices;
|
||||
}
|
||||
|
||||
Status ParallelLoopEmitter::EmitLoop(absl::string_view loop_name,
|
||||
llvm::Type* index_type) {
|
||||
if (index_type == nullptr) {
|
||||
index_type = b_->getInt64Ty();
|
||||
}
|
||||
int64 total_threads = launch_dimensions_.block_count() *
|
||||
launch_dimensions_.threads_per_block();
|
||||
int64 num_elements = ShapeUtil::ElementsIn(shape_);
|
||||
// If all the elements are handled by the current threads, no need
|
||||
// to add a loop inside the kernel.
|
||||
if (total_threads * unroll_factor_ >= num_elements) {
|
||||
VLOG(1) << "ParallelLoopEmitter::EmitLoop fallback";
|
||||
return LoopEmitter::EmitLoop(loop_name, index_type);
|
||||
}
|
||||
|
||||
KernelSupportLibrary ksl(b_, llvm_ir::UnrollMode::kDefaultUnroll);
|
||||
auto constant = [&](int64 val) {
|
||||
return llvm::ConstantInt::get(index_type, val);
|
||||
};
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
ksl.ForWithStatus("loop", constant(0), constant(num_elements),
|
||||
constant(total_threads * unroll_factor_),
|
||||
[&] (llvm::Value* base_indvar) {
|
||||
for (const llvm_ir::IrArray::Index& array_index :
|
||||
EmitIndexAndSetExitBasicBlock(loop_name, index_type, base_indvar)) {
|
||||
TF_RETURN_IF_ERROR(body_emitter_(array_index));
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
|
||||
// Set the insertion point of b_ to the loop exit, so that
|
||||
// code emitted for later instructions will be correctly placed.
|
||||
if (exit_bb_ != nullptr) {
|
||||
b_->SetInsertPoint(exit_bb_);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
||||
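Read together with the previous hunk, the control flow emitted for one thread now looks roughly like the model below (a hedged reconstruction in plain C++, not the generated LLVM IR): ksl.ForWithStatus adds an outer loop from 0 to num_elements with step total_threads * unroll_factor, and each iteration re-emits the indexing and bounds check with base_index bound to the outer induction variable. If the grid already covers every element, EmitLoop falls back to the old single-pass LoopEmitter::EmitLoop.

    #include <cstdint>

    // Hedged model of the control flow EmitLoop now produces for one GPU thread;
    // names mirror the diff, but this is illustrative C++, not generated IR.
    void EmittedKernelModel(int64_t block_idx, int64_t block_dim,
                            int64_t thread_idx, int64_t total_threads,
                            int64_t unroll_factor, int64_t num_elements) {
      // Outer loop added by EmitLoop via ksl.ForWithStatus("loop", ...).
      for (int64_t base = 0; base < num_elements;
           base += total_threads * unroll_factor) {
        // Indexing re-emitted by EmitIndexAndSetExitBasicBlock with base_index.
        int64_t linear_index_base =
            (block_idx * block_dim + thread_idx) * unroll_factor + base;
        for (int64_t i = 0; i < unroll_factor; ++i) {
          int64_t linear_index = linear_index_base + i;
          if (linear_index < num_elements) {
            // body_emitter_ runs for linear_index.
          }
        }
      }
    }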
@@ -57,8 +57,10 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
   ~ParallelLoopEmitter() override = default;

   std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock(
-      absl::string_view loop_name, llvm::Type* index_type) override;
+      absl::string_view loop_name, llvm::Type* index_type, llvm::Value* base_index);

+  Status EmitLoop(absl::string_view loop_name = "",
+                  llvm::Type* index_type = nullptr);
  private:
   // The thread and block dimension to parallelize the loop on.
   const LaunchDimensions launch_dimensions_;

@@ -130,8 +130,13 @@ IrArray::Index LoopEmitter::EmitDynamicIndex(ForLoopNest* loop_nest,
 }

 std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
-    absl::string_view loop_name, llvm::Type* index_type) {
+    absl::string_view loop_name, llvm::Type* index_type, llvm::Value* base_index) {
   CHECK_NE(index_type, nullptr);
+  CHECK_EQ(base_index, nullptr)
+      << "XLA CPU implementation of"
+      << " LoopEmitter::EmitIndexAndSetExitBasicBlock doesn't support"
+      << " base_index, but it was requested.";
+
   if (ShapeUtil::IsScalar(shape_)) {
     // No loop needed, so set exit_bb_ to nullptr.
     exit_bb_ = nullptr;

@@ -75,7 +75,8 @@ class LoopEmitter {
   }

   virtual std::vector<IrArray::Index> EmitIndexAndSetExitBasicBlock(
-      absl::string_view loop_name, llvm::Type* index_type);
+      absl::string_view loop_name, llvm::Type* index_type,
+      llvm::Value* base_index = nullptr);

   // Emits a complete loop nest for every element in the given shape.
   Status EmitLoop(absl::string_view loop_name = "",