From 2cd7d60b98764e41cbd3014291329e52c6c7966f Mon Sep 17 00:00:00 2001
From: Frederic Bastien
Date: Mon, 27 Jul 2020 14:17:25 -0700
Subject: [PATCH] [XLA] Speed up pointwise kernels.

Make XLA faster by making pointwise (PW) kernels use the right number of
blocks and loop over the remaining elements inside the kernel.
---
 .../xla/service/cpu/parallel_loop_emitter.cc  |  8 ++-
 .../xla/service/cpu/parallel_loop_emitter.h   |  3 +-
 tensorflow/compiler/xla/service/gpu/BUILD     |  1 +
 .../xla/service/gpu/buffer_comparator.cc      |  6 +++
 .../compiler/xla/service/gpu/gpu_compiler.cc  |  4 ++
 .../xla/service/gpu/gpu_device_info.h         |  2 +
 .../xla/service/gpu/launch_dimensions.cc      |  3 ++
 .../xla/service/gpu/parallel_loop_emitter.cc  | 49 ++++++++++++++++++-
 .../xla/service/gpu/parallel_loop_emitter.h   |  4 +-
 .../xla/service/llvm_ir/loop_emitter.cc       |  7 ++-
 .../xla/service/llvm_ir/loop_emitter.h        |  3 +-
 11 files changed, 84 insertions(+), 6 deletions(-)

diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
index ffbd0d68ce9..23f5a5c434f 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
@@ -31,9 +31,15 @@ ParallelLoopEmitter::ParallelLoopEmitter(
 
 std::vector<llvm_ir::IrArray::Index>
 ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
-                                                   llvm::Type* index_type) {
+                                                   llvm::Type* index_type,
+                                                   llvm::Value* base_index) {
   CHECK_NE(index_type, nullptr);
+  CHECK_EQ(base_index, nullptr)
+      << "XLA CPU implementation of"
+      << " ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock doesn't support"
+      << " base_index, but it was requested.";
+
   CHECK(!shape_.IsTuple());
   CHECK(!ShapeUtil::IsScalar(shape_));
 
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
index a604e1db222..931a49ea3c8 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
@@ -61,7 +61,8 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
   ~ParallelLoopEmitter() override = default;
 
   std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock(
-      absl::string_view loop_name, llvm::Type* index_type) override;
+      absl::string_view loop_name, llvm::Type* index_type,
+      llvm::Value* base_index = nullptr) override;
 
  private:
   const DynamicLoopBounds* dynamic_loop_bounds_;
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index d1d0827981e..d4bb4b7ebb7 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -311,6 +311,7 @@ cc_library(
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+        "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library",
         "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
         "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
         "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
index 9b192aaa8e1..294d2c66565 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
@@ -610,6 +610,12 @@ static StatusOr<bool> DeviceCompare(se::Stream* stream,
       executor->GetDeviceDescription().threads_per_block_limit();
   gpu_device_info.threads_per_warp =
       executor->GetDeviceDescription().threads_per_warp();
+  gpu_device_info.shared_memory_per_block =
+      executor->GetDeviceDescription().shared_memory_per_block();
+  gpu_device_info.threads_per_core_limit =
+      executor->GetDeviceDescription().threads_per_core_limit();
+  gpu_device_info.core_count =
+      executor->GetDeviceDescription().core_count();
 
   LaunchDimensions dim =
       CalculateLaunchDimensions(buffer_shape, gpu_device_info);
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index b2caa2ddcf4..5476419c54b 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -611,6 +611,10 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
       stream_exec->GetDeviceDescription().threads_per_warp();
   gpu_device_info.shared_memory_per_block =
       stream_exec->GetDeviceDescription().shared_memory_per_block();
+  gpu_device_info.threads_per_core_limit =
+      stream_exec->GetDeviceDescription().threads_per_core_limit();
+  gpu_device_info.core_count =
+      stream_exec->GetDeviceDescription().core_count();
 
   absl::optional<CudaComputeCapability> cuda_compute_capability =
       [&]() -> absl::optional<CudaComputeCapability> {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_device_info.h b/tensorflow/compiler/xla/service/gpu/gpu_device_info.h
index 7352bad1a66..afb773c4527 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_device_info.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_device_info.h
@@ -32,6 +32,8 @@ struct GpuDeviceInfo {
   int threads_per_block_limit;
   int threads_per_warp;
   int shared_memory_per_block;
+  int threads_per_core_limit;
+  int core_count;
 };
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc
index 3668a521ec7..a0dac9e6f70 100644
--- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc
+++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc
@@ -87,6 +87,9 @@ LaunchDimensions CalculateLaunchDimensions(const Shape& shape,
   }
 
   int64 block_count = CeilOfRatio(num_elements, threads_per_block);
+  threads_per_block = std::min(threads_per_block, 128LL);
+  block_count = gpu_device_info.core_count * (gpu_device_info.threads_per_core_limit /
+                                              threads_per_block);
   VLOG(2) << absl::StrFormat(
       "Initialized the block count to ceil(# of elements / threads per "
       "block) = ceil(%d/%d) = %d",
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
index f9937ba77de..04ac260e55d 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/gpu/target_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -58,7 +59,8 @@ ParallelLoopEmitter::ParallelLoopEmitter( std::vector ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, - llvm::Type* index_type) { + llvm::Type* index_type, + llvm::Value* base_index) { // Emit the following code in LLVM IR: // linear_index = blockIdx.x * blockDim.x + threadIdx.x; // if (linear_index < num_elements) { @@ -121,6 +123,12 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, "linear_index_base", /*HasNUW=*/true, /*HasNSW=*/true); } + if (base_index != nullptr) { + linear_index_base = b_->CreateAdd(linear_index_base, base_index, + "linear_index_plus_base", + /*HasNUW=*/true, /*HasNSW=*/true); + } + array_indices.emplace_back(linear_index_base, shape_, b_); for (int i = 1; i < unroll_factor_; ++i) { llvm::Value* linear_index = @@ -146,5 +154,44 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, return array_indices; } +Status ParallelLoopEmitter::EmitLoop(absl::string_view loop_name, + llvm::Type* index_type) { + if (index_type == nullptr) { + index_type = b_->getInt64Ty(); + } + int64 total_threads = launch_dimensions_.block_count() * + launch_dimensions_.threads_per_block(); + int64 num_elements = ShapeUtil::ElementsIn(shape_); + // If all the elements are handled by the current threads, no need + // to add a loop inside the kernel. + if (total_threads * unroll_factor_ >= num_elements) { + VLOG(1) << "ParallelLoopEmitter::EmitLoop fallback"; + return LoopEmitter::EmitLoop(loop_name, index_type); + } + + KernelSupportLibrary ksl(b_, llvm_ir::UnrollMode::kDefaultUnroll); + auto constant = [&](int64 val) { + return llvm::ConstantInt::get(index_type, val); + }; + + TF_RETURN_IF_ERROR( + ksl.ForWithStatus("loop", constant(0), constant(num_elements), + constant(total_threads * unroll_factor_), + [&] (llvm::Value* base_indvar) { + for (const llvm_ir::IrArray::Index& array_index : + EmitIndexAndSetExitBasicBlock(loop_name, index_type, base_indvar)) { + TF_RETURN_IF_ERROR(body_emitter_(array_index)); + } + return Status::OK(); + })); + + // Set the insertion point of b_ to the loop exit, so that + // code emitted for later instructions will be correctly placed. + if (exit_bb_ != nullptr) { + b_->SetInsertPoint(exit_bb_); + } + return Status::OK(); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index 0a6b5430b23..78b05a8189c 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -57,8 +57,10 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector EmitIndexAndSetExitBasicBlock( - absl::string_view loop_name, llvm::Type* index_type) override; + absl::string_view loop_name, llvm::Type* index_type, llvm::Value* base_index); + Status EmitLoop(absl::string_view loop_name = "", + llvm::Type* index_type = nullptr); private: // The thread and block dimension to parallelize the loop on. 
   const LaunchDimensions launch_dimensions_;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
index b6b3b2dd8b3..98260209bd4 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
@@ -130,8 +130,13 @@ IrArray::Index LoopEmitter::EmitDynamicIndex(ForLoopNest* loop_nest,
 }
 
 std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
-    absl::string_view loop_name, llvm::Type* index_type) {
+    absl::string_view loop_name, llvm::Type* index_type, llvm::Value* base_index) {
   CHECK_NE(index_type, nullptr);
+  CHECK_EQ(base_index, nullptr)
+      << "The default implementation of"
+      << " LoopEmitter::EmitIndexAndSetExitBasicBlock doesn't support"
+      << " base_index, but it was requested.";
+
   if (ShapeUtil::IsScalar(shape_)) {
     // No loop needed, so set exit_bb_ to nullptr.
     exit_bb_ = nullptr;
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
index 008205a642a..84392b1812e 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
@@ -75,7 +75,8 @@ class LoopEmitter {
   }
 
   virtual std::vector<IrArray::Index> EmitIndexAndSetExitBasicBlock(
-      absl::string_view loop_name, llvm::Type* index_type);
+      absl::string_view loop_name, llvm::Type* index_type,
+      llvm::Value* base_index = nullptr);
 
   // Emits a complete loop nest for every element in the given shape.
   Status EmitLoop(absl::string_view loop_name = "",
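
Note: the loop that the new ParallelLoopEmitter::EmitLoop wraps around the
kernel body is a standard grid-stride loop. A minimal CUDA sketch of what
the generated kernel now does, assuming unroll_factor_ == 1 (the kernel
name, element type, and body below are illustrative, not code from this
patch):

  // Sketch only: the emitted IR behaves like this grid-stride kernel.
  // The launch is sized once, to
  //   block_count = core_count * (threads_per_core_limit / threads_per_block),
  // and the in-kernel loop covers any num_elements.
  __global__ void pointwise_kernel(const float* in, float* out,
                                   long long num_elements) {
    long long total_threads = (long long)gridDim.x * blockDim.x;
    long long linear_index =
        (long long)blockIdx.x * blockDim.x + threadIdx.x;
    // base_index mirrors the KernelSupportLibrary loop that EmitLoop
    // emits around EmitIndexAndSetExitBasicBlock.
    for (long long base_index = 0; base_index < num_elements;
         base_index += total_threads) {
      long long i = base_index + linear_index;
      if (i < num_elements) {   // bounds check kept from the emitter
        out[i] = in[i] + 1.0f;  // placeholder element-wise body
      }
    }
  }

Capping threads_per_block at 128 in CalculateLaunchDimensions fills each
multiprocessor up to its thread limit with a fixed number of blocks and
lets this in-kernel loop absorb the remaining elements, instead of
launching one thread per element.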