diff --git a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
index 21b4ef40d97..0755093631c 100644
--- a/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
+++ b/tensorflow/compiler/xla/service/gpu/buffer_comparator.cc
@@ -615,8 +615,15 @@ static StatusOr<bool> DeviceCompare(se::Stream* stream,
   gpu_device_info.threads_per_core_limit =
       executor->GetDeviceDescription().threads_per_core_limit();
   gpu_device_info.core_count = executor->GetDeviceDescription().core_count();
-  LaunchDimensions dim =
-      CalculateLaunchDimensions(buffer_shape, gpu_device_info);
+  gpu_device_info.block_dim_limit_x =
+      executor->GetDeviceDescription().block_dim_limit().x;
+  gpu_device_info.block_dim_limit_y =
+      executor->GetDeviceDescription().block_dim_limit().y;
+  gpu_device_info.block_dim_limit_z =
+      executor->GetDeviceDescription().block_dim_limit().z;
+
+  TF_ASSIGN_OR_RETURN(LaunchDimensions dim,
+                      CalculateLaunchDimensions(buffer_shape, gpu_device_info));

   LaunchDimensions::Dim3D thread_counts = dim.thread_counts_per_block();
   LaunchDimensions::Dim3D block_counts = dim.block_counts();
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
index d104c26964b..c4a180424ea 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
@@ -971,6 +971,12 @@ GpuDeviceInfo GetGpuDeviceInfo(se::StreamExecutor* stream_exec) {
   gpu_device_info.threads_per_core_limit =
       stream_exec->GetDeviceDescription().threads_per_core_limit();
   gpu_device_info.core_count = stream_exec->GetDeviceDescription().core_count();
+  gpu_device_info.block_dim_limit_x =
+      stream_exec->GetDeviceDescription().block_dim_limit().x;
+  gpu_device_info.block_dim_limit_y =
+      stream_exec->GetDeviceDescription().block_dim_limit().y;
+  gpu_device_info.block_dim_limit_z =
+      stream_exec->GetDeviceDescription().block_dim_limit().z;
   return gpu_device_info;
 }

diff --git a/tensorflow/compiler/xla/service/gpu/gpu_device_info.h b/tensorflow/compiler/xla/service/gpu/gpu_device_info.h
index afb773c4527..a953ff80fc9 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_device_info.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_device_info.h
@@ -34,6 +34,9 @@ struct GpuDeviceInfo {
   int shared_memory_per_block;
   int threads_per_core_limit;
   int core_count;
+  int block_dim_limit_x;
+  int block_dim_limit_y;
+  int block_dim_limit_z;
 };
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index feb72494238..6a5dd1ecba2 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -913,8 +913,10 @@ Status IrEmitterUnnested::EmitPadToStaticFromMlir(MlirEmitterInput mlir_input) {
     return Status::OK();
   };

-  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
-      input_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
+  TF_ASSIGN_OR_RETURN(
+      LaunchDimensions launch_dimensions,
+      CalculateLaunchDimensions(
+          input_shape, ir_emitter_context_->gpu_device_info(), unroll_factor));
   UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
                          ir_emitter_context_->llvm_module());
   TF_RETURN_IF_ERROR(
@@ -1036,8 +1038,10 @@ Status IrEmitterUnnested::EmitSliceToDynamicFromMlir(
     return Status::OK();
   };

-  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
-      input_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
+  TF_ASSIGN_OR_RETURN(
+      LaunchDimensions launch_dimensions,
+      CalculateLaunchDimensions(
+          input_shape, ir_emitter_context_->gpu_device_info(), unroll_factor));

   UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
                          ir_emitter_context_->llvm_module());
@@ -1830,9 +1834,10 @@ Status IrEmitterUnnested::EmitLoopFusionFromMlir(
   }();

   Shape element_shape = context.output_shapes[0];
-  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
-      element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor,
-      few_waves);
+  TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
+                      CalculateLaunchDimensions(
+                          element_shape, ir_emitter_context_->gpu_device_info(),
+                          unroll_factor, few_waves));
   UpdateLaunchDimensions(launch_dimensions, kernel_thunk,
                          ir_emitter_context_->llvm_module());
   llvm::Type* index_type = GetIndexTypeForKernelFromMlir(
@@ -1907,9 +1912,11 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
   auto unroll_factor = ComputeMaxUnrollFactor(fusion_op, hlo_module_config_);

   const Shape& element_shape = root->shape();
-  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
-      element_shape, ir_emitter_context_->gpu_device_info(),
-      unroll_factor, /*few_waves=*/false);
+  TF_ASSIGN_OR_RETURN(
+      LaunchDimensions launch_dimensions,
+      CalculateLaunchDimensions(element_shape,
+                                ir_emitter_context_->gpu_device_info(),
+                                unroll_factor, /*few_waves=*/false));
   UpdateLaunchDimensions(launch_dimensions, thunks.back().get(),
                          ir_emitter_context_->llvm_module());
   TF_RETURN_IF_ERROR(
@@ -2033,8 +2040,10 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
   // same as operand 0's array.
   const IrArray& output_array = ir_arrays.back();

-  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
-      update_shape, ir_emitter_context_->gpu_device_info());
+  TF_ASSIGN_OR_RETURN(
+      LaunchDimensions launch_dimensions,
+      CalculateLaunchDimensions(update_shape,
+                                ir_emitter_context_->gpu_device_info()));
   UpdateLaunchDimensions(launch_dimensions, fusion_thunk.get(),
                          ir_emitter_context_->llvm_module());
   AddThunkToThunkSequence(std::move(fusion_thunk));
@@ -2229,8 +2238,10 @@ Status IrEmitterUnnested::EmitSelectAndScatterFromMlir(
       TypeToShape(select_and_scatter_op.operand().getType());
   const int64 rank = operand_shape.rank();

-  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
-      source_shape, ir_emitter_context_->gpu_device_info());
+  TF_ASSIGN_OR_RETURN(
+      LaunchDimensions launch_dimensions,
+      CalculateLaunchDimensions(source_shape,
+                                ir_emitter_context_->gpu_device_info()));
   llvm::Type* index_type = GetIndexTypeForKernelFromMlir(
       select_and_scatter_op, launch_dimensions.launch_bound(), &b_);
   auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
@@ -2713,8 +2724,10 @@ Status IrEmitterUnnested::EmitScatter(const ScatterDescriptor& desc,
   // Launch a kernel that reads every element in the updates tensor. We could
   // also do one kernel per window instead if bounds checks turn out to be a
   // bottleneck.
-  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
-      desc.updates_shape, ir_emitter_context_->gpu_device_info());
+  TF_ASSIGN_OR_RETURN(
+      LaunchDimensions launch_dimensions,
+      CalculateLaunchDimensions(desc.updates_shape,
+                                ir_emitter_context_->gpu_device_info()));
   UpdateLaunchDimensions(launch_dimensions, thunk,
                          ir_emitter_context_->llvm_module());

@@ -2922,8 +2935,10 @@ Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput mlir_input) {
   uint64 standard_num_iterations_in_sort_dim = 1ULL << (num_stages - 1);
   standard_iteration_shape.set_dimensions(dimension_to_sort,
                                           standard_num_iterations_in_sort_dim);
-  LaunchDimensions standard_launch_dimensions = CalculateLaunchDimensions(
-      standard_iteration_shape, ir_emitter_context_->gpu_device_info());
+  TF_ASSIGN_OR_RETURN(
+      LaunchDimensions standard_launch_dimensions,
+      CalculateLaunchDimensions(standard_iteration_shape,
+                                ir_emitter_context_->gpu_device_info()));

   // Calculate the launch dimensions for the case where we use tiling. We split
   // the dimension that should be sorted into tiles of size 'kTileSize'. This
@@ -3664,8 +3679,9 @@ IrEmitterUnnested::BuildInitializerThunkForMlir(mlir::Operation* op,
   const llvm_ir::IrArray dest_array = ir_arrays[1];

   const Shape dest_shape = TypeToShape(dest.getType());
-  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
-      dest_shape, ir_emitter_context_->gpu_device_info());
+  TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
+                      CalculateLaunchDimensions(
+                          dest_shape, ir_emitter_context_->gpu_device_info()));
   UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
                          ir_emitter_context_->llvm_module());

@@ -3708,8 +3724,9 @@ IrEmitterUnnested::BuildFusedInitializerThunkForMlir(
       ir_arrays[input_buffers.size() + output_index];

   const Shape dest_shape = TypeToShape(dest.getType());
-  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
-      dest_shape, ir_emitter_context_->gpu_device_info());
+  TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
+                      CalculateLaunchDimensions(
+                          dest_shape, ir_emitter_context_->gpu_device_info()));
   UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
                          ir_emitter_context_->llvm_module());

@@ -5802,9 +5819,10 @@ Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(

   TF_ASSIGN_OR_RETURN(Shape element_shape,
                       GetConsistentInputShapeForRootSlices(fused_computation));
-
-  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
-      element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
+  TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
+                      CalculateLaunchDimensions(
+                          element_shape, ir_emitter_context_->gpu_device_info(),
+                          unroll_factor));
   UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
                          ir_emitter_context_->llvm_module());

diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc
index e2346224ee0..a9061f1fb10 100644
--- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc
+++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.cc
@@ -54,9 +54,9 @@ static int64 ThreadsPerBlockLimit(GpuDeviceInfo gpu_device_info) {
 }

 // Calculates the launch dimensions used to invoke `hlo`.
-LaunchDimensions CalculateLaunchDimensions(const Shape& shape,
-                                           GpuDeviceInfo gpu_device_info,
-                                           int unroll_factor, bool few_waves) {
+StatusOr<LaunchDimensions> CalculateLaunchDimensions(
+    const Shape& shape, GpuDeviceInfo gpu_device_info, int unroll_factor,
+    bool few_waves) {
   int64 num_elements = ShapeUtil::ElementsIn(shape);
   if (num_elements <= 1) {
     return LaunchDimensions();
@@ -102,6 +102,15 @@ LaunchDimensions CalculateLaunchDimensions(const Shape& shape,
       block_count = capped_block_count;
     }
   }
+
+  if (gpu_device_info.block_dim_limit_x > 0 &&
+      block_count >= gpu_device_info.block_dim_limit_x) {
+    return tensorflow::errors::Unimplemented(
+        "Kernel launch needs more blocks (", block_count,
+        ") than allowed by hardware (", gpu_device_info.block_dim_limit_x,
+        ").");
+  }
+
   VLOG(2) << absl::StrFormat(
       "Initialized the block count to ceil(# of elements / threads per "
       "block) = ceil(%d/%d) = %d",
diff --git a/tensorflow/compiler/xla/service/gpu/launch_dimensions.h b/tensorflow/compiler/xla/service/gpu/launch_dimensions.h
index 1472141a80e..7281f796409 100644
--- a/tensorflow/compiler/xla/service/gpu/launch_dimensions.h
+++ b/tensorflow/compiler/xla/service/gpu/launch_dimensions.h
@@ -65,10 +65,9 @@ class LaunchDimensions {
 std::ostream& operator<<(std::ostream& out,
                          const LaunchDimensions& launch_dims);

-LaunchDimensions CalculateLaunchDimensions(const Shape& shape,
-                                           GpuDeviceInfo gpu_device_info,
-                                           int unroll_factor = 1,
-                                           bool few_waves = false);
+StatusOr<LaunchDimensions> CalculateLaunchDimensions(
+    const Shape& shape, GpuDeviceInfo gpu_device_info, int unroll_factor = 1,
+    bool few_waves = false);

 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/tests/BUILD b/tensorflow/compiler/xla/service/gpu/tests/BUILD
index ef299911153..20129a38152 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/tests/BUILD
@@ -129,6 +129,21 @@ tf_cc_test(
     ],
 )

+tf_cc_test(
+    name = "gpu_too_many_blocks_test",
+    srcs = [
+        "gpu_too_many_blocks_test.cc",
+    ],
+    tags = tf_cuda_tests_tags(),
+    deps = [
+        ":gpu_codegen_test",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla/tests:hlo_test_base",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+    ],
+)
+
 tf_cc_test(
     name = "reduction_degenerate_dim_remover_test",
     srcs = [
diff --git a/tensorflow/compiler/xla/service/gpu/tests/gpu_too_many_blocks_test.cc b/tensorflow/compiler/xla/service/gpu/tests/gpu_too_many_blocks_test.cc
new file mode 100644
index 00000000000..e5cfe9670ef
--- /dev/null
+++ b/tensorflow/compiler/xla/service/gpu/tests/gpu_too_many_blocks_test.cc
@@ -0,0 +1,60 @@
+/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace xla {
+namespace gpu {
+
+namespace {
+
+class TooManyBlocksTest : public GpuCodegenTest {};
+
+TEST_F(TooManyBlocksTest, FailsWithInvalidStatus) {
+  const char* hlo_text = R"(
+HloModule primitive_computation_mul.8
+
+ENTRY primitive_computation_mul.8 {
+  parameter.1 = f32[4,1048576,1,1]{3,2,1,0} parameter(0)
+  reshape.3 = f32[4,1048576,1]{2,1,0} reshape(parameter.1)
+  broadcast.4 = f32[4,1048576,1048576,1]{3,2,1,0} broadcast(reshape.3), dimensions={0,1,3}
+  parameter.2 = f32[4,1,1048576,1]{3,2,1,0} parameter(1)
+  reshape.5 = f32[4,1048576,1]{2,1,0} reshape(parameter.2)
+  broadcast.6 = f32[4,1048576,1048576,1]{3,2,1,0} broadcast(reshape.5), dimensions={0,2,3}
+  ROOT multiply.7 = f32[4,1048576,1048576,1]{3,2,1,0} multiply(broadcast.4, broadcast.6)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> optimized_module,
+                          GetOptimizedModule(hlo_text));
+
+  StatusOr<std::unique_ptr<Executable>> failed_executable =
+      backend().compiler()->RunBackend(
+          std::move(optimized_module), backend().default_stream_executor(),
+          backend().default_stream_executor()->GetAllocator());
+
+  EXPECT_FALSE(failed_executable.ok());
+  EXPECT_THAT(failed_executable.status().ToString(),
+              ::testing::HasSubstr("Kernel launch needs more blocks"));
+}
+
+}  // namespace
+}  // namespace gpu
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/tests/hlo_to_llvm_ir.cc b/tensorflow/compiler/xla/service/gpu/tests/hlo_to_llvm_ir.cc
index d9d0555b425..0da2fbb1957 100644
--- a/tensorflow/compiler/xla/service/gpu/tests/hlo_to_llvm_ir.cc
+++ b/tensorflow/compiler/xla/service/gpu/tests/hlo_to_llvm_ir.cc
@@ -49,6 +49,9 @@ xla::Status CompileAndPrintLlvmIr(const std::string& hlo_text) {
   gpu_device_info.shared_memory_per_block = 49152;
   gpu_device_info.core_count = 80;
   gpu_device_info.threads_per_core_limit = 2048;
+  gpu_device_info.block_dim_limit_x = 2147483647;
+  gpu_device_info.block_dim_limit_y = 65535;
+  gpu_device_info.block_dim_limit_z = 65535;
   xla::gpu::CudaComputeCapability cuda_compute_capability;
   cuda_compute_capability.cc_major = 7;
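
For reference, the core behavioral change in launch_dimensions.cc can be summarized with a small standalone sketch (not part of the patch; all names here are illustrative, not the XLA API): derive the block count from the element count and threads per block, then return an error instead of launching when the count exceeds the hardware grid limit.

// Standalone C++ sketch of the block-limit check introduced above.
#include <cstdint>
#include <iostream>
#include <string>

struct Result {
  bool ok;
  std::string error;
  int64_t block_count;
};

// Mirrors the patched CalculateLaunchDimensions: compute the block count by
// ceiling division, then fail (as tensorflow::errors::Unimplemented does in
// the real code) if it exceeds the device's x-dimension grid limit.
Result CalculateBlockCount(int64_t num_elements, int64_t threads_per_block,
                           int64_t block_dim_limit_x) {
  int64_t block_count =
      (num_elements + threads_per_block - 1) / threads_per_block;
  if (block_dim_limit_x > 0 && block_count >= block_dim_limit_x) {
    return {false,
            "Kernel launch needs more blocks (" + std::to_string(block_count) +
                ") than allowed by hardware (" +
                std::to_string(block_dim_limit_x) + ").",
            0};
  }
  return {true, "", block_count};
}

int main() {
  // 4 * 2^20 * 2^20 elements, as in the test's broadcast output, needs 2^32
  // blocks at 1024 threads per block, which exceeds the 2^31 - 1 limit that
  // hlo_to_llvm_ir.cc configures for block_dim_limit_x.
  Result r = CalculateBlockCount(int64_t{4} << 40, 1024, 2147483647);
  std::cout << (r.ok ? "ok" : r.error) << "\n";
}

Callers that previously bound the plain LaunchDimensions return value now unwrap the StatusOr via TF_ASSIGN_OR_RETURN, so the error propagates out of compilation, which is exactly what the new gpu_too_many_blocks_test asserts.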