diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
index ffbd0d68ce9..23f5a5c434f 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
@@ -31,9 +31,15 @@ ParallelLoopEmitter::ParallelLoopEmitter(
 std::vector<llvm_ir::IrArray::Index>
 ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,
-                                                   llvm::Type* index_type) {
+                                                   llvm::Type* index_type,
+                                                   llvm::Value* base_index) {
   CHECK_NE(index_type, nullptr);
+  CHECK_EQ(base_index, nullptr)
+      << "XLA CPU implementation of"
+      << " ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock doesn't support"
+      << " base_index, but it was requested.";
+
   CHECK(!shape_.IsTuple());
   CHECK(!ShapeUtil::IsScalar(shape_));
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
index a604e1db222..a11fd44f1ce 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
@@ -61,7 +61,8 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
   ~ParallelLoopEmitter() override = default;
 
   std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock(
-      absl::string_view loop_name, llvm::Type* index_type) override;
+      absl::string_view loop_name, llvm::Type* index_type,
+      llvm::Value* base_index) override;
 
  private:
   const DynamicLoopBounds* dynamic_loop_bounds_;
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 913247a4299..fdd36affc2b 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -323,6 +323,7 @@ cc_library(
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:xla_data_proto_cc",
         "//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+        "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library",
        "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
         "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
         "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
index 10a565308de..21b4ef40d97 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
@@ -610,6 +610,11 @@ static StatusOr<bool> DeviceCompare(se::Stream* stream,
       executor->GetDeviceDescription().threads_per_block_limit();
   gpu_device_info.threads_per_warp =
       executor->GetDeviceDescription().threads_per_warp();
+  gpu_device_info.shared_memory_per_block =
+      executor->GetDeviceDescription().shared_memory_per_block();
+  gpu_device_info.threads_per_core_limit =
+      executor->GetDeviceDescription().threads_per_core_limit();
+  gpu_device_info.core_count = executor->GetDeviceDescription().core_count();
   LaunchDimensions dim =
       CalculateLaunchDimensions(buffer_shape, gpu_device_info);
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index b74dbd6100a..f14018f9982 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -617,6 +617,9 @@ StatusOr<std::unique_ptr<Executable>> GpuCompiler::RunBackend(
       stream_exec->GetDeviceDescription().threads_per_warp();
   gpu_device_info.shared_memory_per_block =
       stream_exec->GetDeviceDescription().shared_memory_per_block();
+  gpu_device_info.threads_per_core_limit =
+      stream_exec->GetDeviceDescription().threads_per_core_limit();
+  gpu_device_info.core_count = stream_exec->GetDeviceDescription().core_count();
 
   absl::optional cuda_compute_capability =
       [&]() -> absl::optional {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_device_info.h b/tensorflow/compiler/xla/service/gpu/gpu_device_info.h
index 7352bad1a66..afb773c4527 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_device_info.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_device_info.h
@@ -32,6 +32,8 @@ struct GpuDeviceInfo {
   int threads_per_block_limit;
   int threads_per_warp;
   int shared_memory_per_block;
+  int threads_per_core_limit;
+  int core_count;
 };
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index e44cfd45dfc..05985b55823 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2303,7 +2303,7 @@ StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildConditionalThunk(
 Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
     const HloInstruction& hlo,
     const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk,
-    int unroll_factor) {
+    int unroll_factor, bool few_waves) {
   VLOG(3) << bindings_.ToString();
 
   bool multi_output = hlo.shape().IsTuple();
@@ -2314,7 +2314,8 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
           << ShapeUtil::HumanStringWithLayout(hlo.shape())
           << " for unroll_factor " << unroll_factor;
   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
-      element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
+      element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor,
+      few_waves);
   UpdateLaunchDimensions(launch_dimensions, thunk,
                          ir_emitter_context_->llvm_module());
   if (!multi_output) {
@@ -2400,8 +2401,27 @@ Status IrEmitterUnnested::EmitTargetElementLoop(
 
   std::unique_ptr<KernelThunk> kernel_thunk =
       BuildKernelThunk(&hlo, /*implements_whole_instruction=*/true);
+
+  // Check if we want to schedule a grid size that has fewer SM waves.
+  // This speeds up computation in some cases.
+  bool few_waves = false;
+  auto few_waves_allow_instr = [](const HloInstruction* instr) {
+    return instr->IsElementwise() || instr->opcode() == HloOpcode::kParameter ||
+           // We need to make the codegen broadcast aware before enabling
+           // more broadcast patterns.
+           (instr->opcode() == HloOpcode::kBroadcast &&
+            instr->dimensions().empty());
+  };
+  if (hlo.opcode() == HloOpcode::kFusion) {
+    few_waves =
+        absl::c_all_of(hlo.fused_instructions_computation()->instructions(),
+                       few_waves_allow_instr);
+  } else {
+    few_waves = few_waves_allow_instr(&hlo);
+  }
+
   Status emit_status = EmitTargetElementLoopInThunk(
-      hlo, body_emitter, kernel_thunk.get(), unroll_factor);
+      hlo, body_emitter, kernel_thunk.get(), unroll_factor, few_waves);
 
   thunk_sequence_.emplace_back(std::move(kernel_thunk));
   return emit_status;
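
The few_waves flag above is turned on only when the emitted kernel is a single allowed instruction or a fusion built entirely from elementwise ops, parameters, and dimension-less (scalar) broadcasts. Below is a minimal standalone sketch of that allowlist decision; Opcode and InstrInfo are hypothetical stand-ins introduced for this note, not the real HloInstruction API.

// Sketch of the few_waves allowlist; Opcode/InstrInfo are hypothetical
// stand-ins, not XLA types.
#include <algorithm>
#include <vector>

enum class Opcode { kParameter, kBroadcast, kAdd, kTranspose };

struct InstrInfo {
  Opcode opcode;
  bool is_elementwise;
  bool has_dimensions;  // For kBroadcast: false models a scalar broadcast.
};

// Mirrors few_waves_allow_instr: elementwise ops, parameters, and
// dimension-less broadcasts qualify.
bool AllowsFewWaves(const InstrInfo& instr) {
  return instr.is_elementwise || instr.opcode == Opcode::kParameter ||
         (instr.opcode == Opcode::kBroadcast && !instr.has_dimensions);
}

// A fusion qualifies only if every fused instruction qualifies, mirroring
// the absl::c_all_of call in the hunk above.
bool FusionAllowsFewWaves(const std::vector<InstrInfo>& fused) {
  return std::all_of(fused.begin(), fused.end(),
                     [](const InstrInfo& i) { return AllowsFewWaves(i); });
}

Any other instruction mix leaves few_waves false, and launch dimensions are computed exactly as before.
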
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index c36f0b7840d..b83af8799d3 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -178,7 +178,7 @@ class IrEmitterUnnested : public IrEmitter,
   // `unroll_factor` is greater than one.
   Status EmitTargetElementLoopInThunk(
       const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter,
-      KernelThunk* thunk, int unroll_factor);
+      KernelThunk* thunk, int unroll_factor, bool few_waves = false);
 
   Status Postprocess(HloInstruction* hlo) override;
diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc
index c23e8112cb0..5dbbb2d65da 100644
--- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc
+++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc
@@ -56,7 +56,7 @@ static int64 ThreadsPerBlockLimit(GpuDeviceInfo gpu_device_info) {
 // Calculates the launch dimensions used to invoke `hlo`.
 LaunchDimensions CalculateLaunchDimensions(const Shape& shape,
                                            GpuDeviceInfo gpu_device_info,
-                                           int unroll_factor) {
+                                           int unroll_factor, bool few_waves) {
   int64 num_elements = ShapeUtil::ElementsIn(shape);
   if (num_elements <= 1) {
     return LaunchDimensions();
@@ -90,6 +90,11 @@ LaunchDimensions CalculateLaunchDimensions(const Shape& shape,
   }
 
   int64 block_count = CeilOfRatio(num_elements, threads_per_block);
+  if (few_waves) {
+    threads_per_block = std::min(threads_per_block, int64{128});
+    block_count = gpu_device_info.core_count *
+                  (gpu_device_info.threads_per_core_limit / threads_per_block);
+  }
   VLOG(2) << absl::StrFormat(
       "Initialized the block count to ceil(# of elements / threads per "
       "block) = ceil(%d/%d) = %d",
diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.h b/tensorflow/compiler/xla/service/gpu/launch_dimensions.h
index dbe5a037e43..1472141a80e 100644
--- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.h
+++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.h
@@ -67,7 +67,8 @@ std::ostream& operator<<(std::ostream& out,
 
 LaunchDimensions CalculateLaunchDimensions(const Shape& shape,
                                            GpuDeviceInfo gpu_device_info,
-                                           int unroll_factor = 1);
+                                           int unroll_factor = 1,
+                                           bool few_waves = false);
 
 }  // namespace gpu
 }  // namespace xla
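
When few_waves is set, CalculateLaunchDimensions stops launching one thread per element: the block size is capped at 128 threads and the grid is sized to roughly one wave of resident threads, core_count * (threads_per_core_limit / threads_per_block). A small sketch of that arithmetic follows; the device numbers in main (80 cores, 2048 resident threads per core) are illustrative V100-like values, not figures taken from this change.

// Sketch of the few_waves grid-size computation from the hunk above.
#include <algorithm>
#include <cstdint>
#include <iostream>

int64_t FewWavesBlockCount(int64_t core_count, int64_t threads_per_core_limit,
                           int64_t* threads_per_block) {
  // Cap the block size at 128 threads, then launch just enough blocks to
  // fill every core to its resident-thread limit (one "wave").
  *threads_per_block = std::min<int64_t>(*threads_per_block, 128);
  return core_count * (threads_per_core_limit / *threads_per_block);
}

int main() {
  int64_t threads_per_block = 256;
  int64_t blocks = FewWavesBlockCount(/*core_count=*/80,
                                      /*threads_per_core_limit=*/2048,
                                      &threads_per_block);
  // Prints "128 threads/block, 1280 blocks": 80 * (2048 / 128) = 1280.
  std::cout << threads_per_block << " threads/block, " << blocks
            << " blocks\n";
  return 0;
}

Because such a grid generally no longer covers every element, each thread has to iterate over several elements; that in-kernel loop is what the new ParallelLoopEmitter::EmitLoop in the next file provides.
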
#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/gpu/target_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -58,7 +59,8 @@ ParallelLoopEmitter::ParallelLoopEmitter( std::vector ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, - llvm::Type* index_type) { + llvm::Type* index_type, + llvm::Value* base_index) { // Emit the following code in LLVM IR: // linear_index = blockIdx.x * blockDim.x + threadIdx.x; // if (linear_index < num_elements) { @@ -122,6 +124,12 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, "linear_index_base", /*HasNUW=*/true, /*HasNSW=*/true); } + if (base_index != nullptr) { + linear_index_base = + b_->CreateAdd(linear_index_base, base_index, "linear_index_plus_base", + /*HasNUW=*/true, /*HasNSW=*/true); + } + array_indices.emplace_back(linear_index_base, shape_, b_); for (int i = 1; i < unroll_factor_; ++i) { llvm::Value* linear_index = @@ -147,5 +155,43 @@ ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(absl::string_view loop_name, return array_indices; } +Status ParallelLoopEmitter::EmitLoop(absl::string_view loop_name, + llvm::Type* index_type) { + if (index_type == nullptr) { + index_type = b_->getInt64Ty(); + } + int64 total_threads = launch_dimensions_.launch_bound(); + int64 num_elements = ShapeUtil::ElementsIn(shape_); + // If all the elements are handled by the current threads, no need + // to add a loop inside the kernel. + if (total_threads * unroll_factor_ >= num_elements) { + VLOG(1) << "ParallelLoopEmitter::EmitLoop fallback"; + return LoopEmitter::EmitLoop(loop_name, index_type); + } + + KernelSupportLibrary ksl(b_, llvm_ir::UnrollMode::kDefaultUnroll); + auto constant = [&](int64 val) { + return llvm::ConstantInt::get(index_type, val); + }; + + TF_RETURN_IF_ERROR(ksl.ForWithStatus( + "loop", constant(0), constant(num_elements), + constant(total_threads * unroll_factor_), [&](llvm::Value* base_indvar) { + for (const llvm_ir::IrArray::Index& array_index : + EmitIndexAndSetExitBasicBlock(loop_name, index_type, + base_indvar)) { + TF_RETURN_IF_ERROR(body_emitter_(array_index)); + } + return Status::OK(); + })); + + // Set the insertion point of b_ to the loop exit, so that + // code emitted for later instructions will be correctly placed. + if (exit_bb_ != nullptr) { + b_->SetInsertPoint(exit_bb_); + } + return Status::OK(); +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index 0a6b5430b23..5e142ec3832 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -57,7 +57,11 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { ~ParallelLoopEmitter() override = default; std::vector EmitIndexAndSetExitBasicBlock( - absl::string_view loop_name, llvm::Type* index_type) override; + absl::string_view loop_name, llvm::Type* index_type, + llvm::Value* base_index) override; + + Status EmitLoop(absl::string_view loop_name = "", + llvm::Type* index_type = nullptr); private: // The thread and block dimension to parallelize the loop on. 
diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
index 0a6b5430b23..5e142ec3832 100644
--- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
+++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h
@@ -57,7 +57,11 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
   ~ParallelLoopEmitter() override = default;
 
   std::vector<llvm_ir::IrArray::Index> EmitIndexAndSetExitBasicBlock(
-      absl::string_view loop_name, llvm::Type* index_type) override;
+      absl::string_view loop_name, llvm::Type* index_type,
+      llvm::Value* base_index) override;
+
+  Status EmitLoop(absl::string_view loop_name = "",
+                  llvm::Type* index_type = nullptr);
 
  private:
   // The thread and block dimension to parallelize the loop on.
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc
index 2f139563b4a..e6037f4cac6 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_unrolling_test.cc
@@ -148,8 +148,8 @@ TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedSine) {
     HloModule test_module
 
     ENTRY SineFunc {
-      p0 = f32[160000]{0} parameter(0)
-      ROOT s = f32[160000]{0} sine(p0)
+      p0 = f32[1600000]{0} parameter(0)
+      ROOT s = f32[1600000]{0} sine(p0)
     })";
   auto hlo_module =
       ParseAndReturnVerifiedModule(kUnfusedAddModule, config).ValueOrDie();
@@ -182,8 +182,8 @@ TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedCosine) {
     HloModule test_module
 
    ENTRY SineFunc {
-      p0 = f32[160000]{0} parameter(0)
-      ROOT s = f32[160000]{0} cosine(p0)
+      p0 = f32[1600000]{0} parameter(0)
+      ROOT s = f32[1600000]{0} cosine(p0)
     })";
   auto hlo_module =
       ParseAndReturnVerifiedModule(kUnfusedAddModule, config).ValueOrDie();
@@ -216,8 +216,8 @@ TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedPower) {
     HloModule test_module
 
     ENTRY SineFunc {
-      p0 = f32[160000]{0} parameter(0)
-      ROOT s = f32[160000]{0} power(p0, p0)
+      p0 = f32[1600000]{0} parameter(0)
+      ROOT s = f32[1600000]{0} power(p0, p0)
     })";
   auto hlo_module =
       ParseAndReturnVerifiedModule(kUnfusedAddModule, config).ValueOrDie();
@@ -241,8 +241,8 @@ TEST_F(GpuUnrollingTest, DisabledUnrollUnfusedAtan2) {
     HloModule test_module
 
    ENTRY SineFunc {
-      p0 = f32[160000]{0} parameter(0)
-      ROOT s = f32[160000]{0} atan2(p0, p0)
+      p0 = f32[1600000]{0} parameter(0)
+      ROOT s = f32[1600000]{0} atan2(p0, p0)
     })";
   auto hlo_module =
       ParseAndReturnVerifiedModule(kUnfusedAddModule, config).ValueOrDie();
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
index b6b3b2dd8b3..9d7f06f4f68 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc
@@ -130,8 +130,14 @@ IrArray::Index LoopEmitter::EmitDynamicIndex(ForLoopNest* loop_nest,
 }
 
 std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
-    absl::string_view loop_name, llvm::Type* index_type) {
+    absl::string_view loop_name, llvm::Type* index_type,
+    llvm::Value* base_index) {
   CHECK_NE(index_type, nullptr);
+  CHECK_EQ(base_index, nullptr)
+      << "XLA CPU implementation of"
+      << " LoopEmitter::EmitIndexAndSetExitBasicBlock doesn't support"
+      << " base_index, but it was requested.";
+
   if (ShapeUtil::IsScalar(shape_)) {
     // No loop needed, so set exit_bb_ to nullptr.
     exit_bb_ = nullptr;
@@ -164,7 +170,8 @@ Status LoopEmitter::EmitLoop(absl::string_view loop_name,
   }
 
   for (const IrArray::Index& array_index :
-       EmitIndexAndSetExitBasicBlock(loop_name, index_type)) {
+       EmitIndexAndSetExitBasicBlock(loop_name, index_type,
+                                     /*base_index*/ nullptr)) {
     TF_RETURN_IF_ERROR(body_emitter_(array_index));
   }
diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
index 008205a642a..a356741f74b 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h
@@ -71,11 +71,13 @@ class LoopEmitter {
   // specifies the element, will return multiple indices if the loop is
   // unrolled.
   std::vector<IrArray::Index> EmitIndexAndSetExitBasicBlock() {
-    return EmitIndexAndSetExitBasicBlock(/*loop_name=*/"", b_->getInt64Ty());
+    return EmitIndexAndSetExitBasicBlock(/*loop_name=*/"", b_->getInt64Ty(),
+                                         /*base_index*/ nullptr);
   }
 
   virtual std::vector<IrArray::Index> EmitIndexAndSetExitBasicBlock(
-      absl::string_view loop_name, llvm::Type* index_type);
+      absl::string_view loop_name, llvm::Type* index_type,
+      llvm::Value* base_index);
 
   // Emits a complete loop nest for every element in the given shape.
   Status EmitLoop(absl::string_view loop_name = "",