From b21f969731f28fb1c65a5969251316e59f996127 Mon Sep 17 00:00:00 2001
From: jerryyin
Date: Mon, 9 Sep 2019 18:16:50 +0000
Subject: [PATCH] Enabling ROCm parallel logic for gpu_conv_algorithm_picker

---
 .../xla/service/gpu/amdgpu_compiler.cc        |   4 +-
 .../service/gpu/gpu_conv_algorithm_picker.cc  | 215 ++++++++++++++----
 .../service/gpu/gpu_conv_algorithm_picker.h   |  25 +-
 .../xla/service/gpu/gpu_conv_runner.cc        |  20 +-
 .../xla/service/gpu/nvptx_compiler.cc         |   4 +-
 5 files changed, 208 insertions(+), 60 deletions(-)

diff --git a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
index ed49fcd584f..10dd4542612 100644
--- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
@@ -16,9 +16,9 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h"
 
 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"
-// TODO(whchung@gmail.com): Add gpu_conv_algorithm_picker after its PR merged.
 #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
 #include "tensorflow/compiler/xla/service/gpu/target_constants.h"
@@ -97,7 +97,7 @@ Status AMDGPUCompiler::OptimizeHloPostLayoutAssignment(
   options.set_is_layout_sensitive(true);
   pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
 
-  // TODO(whchung@gmail.com): Add gpu_conv_algorithm_picker after its PR merged.
+  pipeline.AddPass<GpuConvAlgorithmPicker>(stream_exec, device_allocator);
 
   // Clean up new_tuple described above.
   pipeline.AddPass<TupleSimplifier>();
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
index 57a90283d75..762faea3418 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
@@ -48,6 +48,54 @@ using se::DeviceMemoryBase;
 using se::dnn::AlgorithmDesc;
 using tensorflow::AutotuneResult;
 
+class ScratchAllocator : public se::ScratchAllocator {
+ public:
+  ScratchAllocator(int device_ordinal,
+                   se::DeviceMemoryAllocator* memory_allocator)
+      : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
+
+  int64 GetMemoryLimitInBytes() override {
+    return 1LL << 32;  // 4GB. TODO(jlebar): Tune this?
+  }
+  int64 TotalAllocatedBytes() { return total_allocated_bytes_; }
+
+  StatusOr<se::DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override;
+
+  template <typename T>
+  StatusOr<se::DeviceMemory<T>> Allocate(int64 num_elements) {
+    TF_ASSIGN_OR_RETURN(se::DeviceMemory<uint8> bytes,
+                        AllocateBytes(num_elements * sizeof(T)));
+    return se::DeviceMemory<T>(bytes);
+  }
+
+ private:
+  const int device_ordinal_;
+  se::DeviceMemoryAllocator* memory_allocator_;
+  std::vector<se::OwningDeviceMemory> allocated_buffers_;
+  int64 total_allocated_bytes_ = 0;
+};
+
+StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
+    int64 byte_size) {
+  CHECK_GE(byte_size, 0) << "byte_size must be nonnegative.";
+  if (byte_size > GetMemoryLimitInBytes()) {
+    return se::port::Status(
+        se::port::error::RESOURCE_EXHAUSTED,
+        absl::StrFormat(
+            "Allocating %d bytes exceeds the memory limit of %d bytes.",
+            byte_size, GetMemoryLimitInBytes()));
+  }
+
+  TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer,
+                      memory_allocator_->Allocate(device_ordinal_, byte_size,
+                                                  /*retry_on_failure=*/false));
+  total_allocated_bytes_ += byte_size;
+
+  se::DeviceMemoryBase buffer_addr = *allocated_buffer;
+  allocated_buffers_.push_back(std::move(allocated_buffer));
+  return se::DeviceMemory<uint8>(buffer_addr);
+}
+
 std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
                                          se::StreamExecutor* stream_exec) {
   std::vector<AlgorithmDesc> algorithms;
@@ -198,7 +246,7 @@ auto& autotune_cache_stats GUARDED_BY(autotune_cache_lock) =
     *new ConvCacheStats();
 }  // anonymous namespace
 
-StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
+StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithm(
     const HloCustomCallInstruction* instr) {
   // Don't run this function concurrently on the same GPU.
   //
@@ -226,22 +274,6 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
     autotune_cache_stats.cache_misses++;
   }
 
-  StatusOr<AutotuneResult> result_or = PickBestAlgorithmNoCache(instr);
-  if (result_or.ok()) {
-    tensorflow::mutex_lock lock(autotune_cache_lock);
-    CHECK(autotune_cache.insert({key, result_or.ValueOrDie()}).second);
-  }
-  return result_or;
-}
-
-StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
-    const HloCustomCallInstruction* instr) {
-  XLA_SCOPED_LOGGING_TIMER(
-      absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithmImpl for ",
-                   instr->ToString()));
-
-  const Shape& result_shape = instr->shape().tuple_shapes(0);
-
   // Make sure any previous activity on this executor is done. We don't want to
   // interfere with programs that are still running on the GPU.
   if (!stream_exec_->SynchronizeAllActivity()) {
@@ -269,6 +301,34 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
     return &stream_opt.value();
   }();
 
+  StatusOr<AutotuneResult> result_or(InternalError("Unknown platform."));
+  // Check which platform the StreamExecutor is on. The ROCm and CUDA
+  // implementations have diverged; specifically, we must make sure that the
+  // redzone-allocator utilities are not used in the ROCm path.
+  if (stream_exec_->platform_kind() == se::PlatformKind::kROCm) {
+    result_or = PickBestAlgorithmNoCacheRocm(*instr, allocator, stream);
+  } else if (stream_exec_->platform_kind() == se::PlatformKind::kCuda) {
+    result_or = PickBestAlgorithmNoCacheCuda(*instr, allocator, stream);
+  }
+
+  if (result_or.ok()) {
+    tensorflow::mutex_lock lock(autotune_cache_lock);
+    CHECK(autotune_cache.insert({key, result_or.ValueOrDie()}).second);
+  }
+  return result_or;
+}
+
+StatusOr<tensorflow::AutotuneResult>
+GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
+    const HloCustomCallInstruction& instr, se::DeviceMemoryAllocator* allocator,
+    se::Stream* stream) {
+  // Right now the redzone allocator is available for the CUDA target only.
+  XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
+      "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr.ToString()));
+
+  const Shape& result_shape = instr.shape().tuple_shapes(0);
+  const auto device_ordinal = stream_exec_->device_ordinal();
+
   int64 rng_state = 0;
 
   const auto initialize_buffer = [&stream, &rng_state](
@@ -277,13 +337,13 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
     InitializeBuffer(stream, buffer_shape.element_type(), &rng_state, buffer);
   };
 
-  const HloModuleConfig& hlo_module_config = instr->GetModule()->config();
+  const HloModuleConfig& hlo_module_config = instr.GetModule()->config();
 
   // Allocate space for the input, filter, and output of the convolution.
   se::RedzoneAllocator input_output_allocator(
       stream, allocator, PtxOptsFromConfig(hlo_module_config));
   std::vector<se::DeviceMemoryBase> operand_buffers;
-  for (const auto* operand : instr->operands()) {
+  for (const auto* operand : instr.operands()) {
     TF_ASSIGN_OR_RETURN(auto buffer,
                         input_output_allocator.AllocateBytes(
                             ShapeUtil::ByteSizeOf(operand->shape())));
@@ -296,7 +356,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
   initialize_buffer(result_buffer, result_shape);
 
   TF_ASSIGN_OR_RETURN(auto backend_config,
-                      instr->backend_config<CudnnConvBackendConfig>());
+                      instr.backend_config<CudnnConvBackendConfig>());
 
   optional<BufferComparator> comparator;
   // Use the first algorithm that's supported as reference. There isn't a
@@ -305,17 +365,17 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
   se::DeviceMemoryBase reference_result_buffer;
   AlgorithmDesc first_algorithm;
 
-  TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr));
+  TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(&instr));
 
   std::vector<AutotuneResult> profile_results;
 
   const DebugOptions& debug_options =
-      instr->GetModule()->config().debug_options();
+      instr.GetModule()->config().debug_options();
   const bool crash_on_checking_failure =
       debug_options.xla_gpu_crash_on_verification_failures();
 
   const auto canonical_hlo =
-      std::get<1>(AutotuneCacheKeyfromInstruction(instr, stream_exec_));
+      std::get<1>(AutotuneCacheKeyfromInstruction(&instr, stream_exec_));
 
   string blas_version;
   if (auto* blas = stream_exec_->AsBlas()) {
@@ -335,7 +395,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
     if (absl::c_linear_search(blacklisted_algos, alg)) {
       LOG(INFO) << "Omitted potentially buggy algorithm "
-                << AlgorithmToString(alg) << " for conv " << instr->ToString();
+                << AlgorithmToString(alg) << " for conv " << instr.ToString();
       continue;
     }
@@ -343,7 +403,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
         stream, allocator, PtxOptsFromConfig(hlo_module_config));
     se::dnn::ProfileResult profile_result;
     VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
-            << instr->ToString();
+            << instr.ToString();
 
     // Use assignment instead of brace-list to make GCC 4.9 happy.
     RunConvOptions options;
@@ -375,11 +435,11 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
     // Check for writes to redzones.
     TF_ASSIGN_OR_RETURN(bool input_output_allocator_redzone_clear,
                         CheckRedzones(input_output_allocator, stream,
-                                      "input/output", instr, &result));
+                                      "input/output", &instr, &result));
 
     TF_ASSIGN_OR_RETURN(
         bool scratch_allocator_redzone_clear,
-        CheckRedzones(scratch_allocator, stream, "scratch", instr, &result));
+        CheckRedzones(scratch_allocator, stream, "scratch", &instr, &result));
 
     if (!input_output_allocator_redzone_clear ||
         !scratch_allocator_redzone_clear) {
@@ -410,7 +470,7 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
       if (!compare_result.ok()) {
         LOG(ERROR) << "Unable to compare " << AlgorithmToString(first_algorithm)
                    << " against " << AlgorithmToString(alg) << " for "
-                   << instr->ToString() << ": " << compare_result.status();
+                   << instr.ToString() << ": " << compare_result.status();
         if (compare_result.status().code() ==
             tensorflow::error::RESOURCE_EXHAUSTED) {
           // Possibly OOM. Propatate the error.
@@ -421,12 +481,11 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
         LOG(ERROR) << "Results mismatch between different convolution algorithms. "
                       "This is likely a bug/unexpected loss of precision in cudnn.\n"
-                   << instr->ToString() << " for "
-                   << AlgorithmToString(first_algorithm) << " vs "
-                   << AlgorithmToString(alg);
+                   << instr.ToString() << " for " << AlgorithmToString(first_algorithm)
+                   << " vs " << AlgorithmToString(alg);
         PrintPlatformInfo(stream);
         VLOG(1) << "Full module on failure: \n"
-                << instr->GetModule()->ToString();
+                << instr.GetModule()->ToString();
         auto* fail = result.mutable_failure();
         fail->set_kind(AutotuneResult::WRONG_RESULT);
         fail->set_buffer_address(
@@ -453,11 +512,11 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
   tensorflow::AutotuningLog log;
   {
     ConvInstructionLog instr_log;
-    *instr_log.mutable_instruction() = instr->ToProto();
-    for (int i = 0; i < instr->operand_count(); i++) {
-      *instr_log.add_operand_shapes() = instr->operand(i)->shape().ToProto();
+    *instr_log.mutable_instruction() = instr.ToProto();
+    for (int i = 0; i < instr.operand_count(); i++) {
+      *instr_log.add_operand_shapes() = instr.operand(i)->shape().ToProto();
       instr_log.add_operand_addresses(
-          reinterpret_cast<uint64>(operand_buffers[i].opaque()));
+          reinterpret_cast<uint64>((operand_buffers)[i].opaque()));
     }
     instr_log.set_result_address(
         reinterpret_cast<uint64>(result_buffer.opaque()));
@@ -523,11 +582,81 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
   return InternalError(
       "All algorithms tried for convolution %s failed. Falling back to "
       "default algorithm.",
-      instr->ToString());
+      instr.ToString());
 }
 
-StatusOr<bool> CudnnConvAlgorithmPicker::RunOnInstruction(
-    HloInstruction* instr) {
+StatusOr<tensorflow::AutotuneResult>
+GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm(
+    const HloCustomCallInstruction& instr, se::DeviceMemoryAllocator* allocator,
+    se::Stream* stream) {
+  XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
+      "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr.ToString()));
+
+  const auto device_ordinal = stream_exec_->device_ordinal();
+  std::vector<se::DeviceMemoryBase> operand_buffers;
+
+  ScratchAllocator input_output_allocator(device_ordinal, allocator);
+  const auto initialize_buffer = [stream](DeviceMemoryBase buffer) {
+    // Although we don't have evidence this matters, zero out the buffers
+    // before autotuning. It's conceivable that using uninitialized memory as
+    // the inputs might affect performance if e.g. the inputs contain
+    // denormals, and this is easy enough.
+    stream->ThenMemZero(&buffer, buffer.size());
+  };
+
+  // Allocate space for the input, filter, and output of the convolution. We
+  // use a ScratchAllocator for this instead of calling allocator_ directly so
+  // that our allocations don't leak.
+  for (const auto* operand : instr.operands()) {
+    TF_ASSIGN_OR_RETURN(auto buffer,
+                        input_output_allocator.AllocateBytes(
+                            ShapeUtil::ByteSizeOf(operand->shape())));
+    initialize_buffer(buffer);
+    operand_buffers.push_back(buffer);
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      auto result_buffer,
+      input_output_allocator.AllocateBytes(
+          ShapeUtil::ByteSizeOf(instr.shape().tuple_shapes(0))));
+  initialize_buffer(result_buffer);
+
+  ScratchAllocator scratch_allocator(device_ordinal, allocator);
+  se::dnn::ProfileResult profile_result;
+  VLOG(3) << "Auto-tuning for " << instr.ToString();
+  RunConvOptions options;
+  options.profile_result = &profile_result;
+
+  // ROCm: Set the overriding algorithm to empty to remind gpu_conv_runner
+  // that the AlgorithmConfig used to run the convolution needs to be empty.
+  options.algo_override = se::dnn::AlgorithmDesc();
+
+  bool launch_ok =
+      RunCudnnConv(&instr, absl::MakeSpan(operand_buffers), result_buffer,
+                   &scratch_allocator, stream, options)
+          .ok();
+
+  AutotuneResult best_result;
+  if (launch_ok && profile_result.is_valid()) {
+    best_result.mutable_conv()->set_algorithm(
+        profile_result.algorithm().algo_id());
+    best_result.mutable_conv()->set_tensor_ops_enabled(
+        profile_result.algorithm().tensor_ops_enabled());
+    int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
+    best_result.set_scratch_bytes(scratch_bytes_used);
+    *best_result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
+        absl::Milliseconds(profile_result.elapsed_time_in_ms()));
+
+    return best_result;
+  }
+
+  return InternalError(
+      "All algorithms tried for convolution %s failed. Falling back to "
+      "default algorithm.",
+      instr.ToString());
+}
+
+StatusOr<bool> GpuConvAlgorithmPicker::RunOnInstruction(HloInstruction* instr) {
   CHECK(IsCustomCallToDnnConvolution(*instr));
 
   StatusOr<AutotuneResult> best_algo_or =
@@ -577,7 +706,7 @@ StatusOr<bool> CudnnConvAlgorithmPicker::RunOnInstruction(
   return true;
 }
 
-StatusOr<bool> CudnnConvAlgorithmPicker::RunOnComputation(
+StatusOr<bool> GpuConvAlgorithmPicker::RunOnComputation(
     HloComputation* computation) {
   std::vector<HloInstruction*> convs;
   for (auto* instr : computation->instructions()) {
@@ -594,11 +723,11 @@ StatusOr<bool> CudnnConvAlgorithmPicker::RunOnComputation(
   return changed;
 }
 
-StatusOr<bool> CudnnConvAlgorithmPicker::Run(HloModule* module) {
-  XLA_SCOPED_LOGGING_TIMER("CudnnConvAlgorithmPicker");
+StatusOr<bool> GpuConvAlgorithmPicker::Run(HloModule* module) {
+  XLA_SCOPED_LOGGING_TIMER("GpuConvAlgorithmPicker");
 
   if (module->config().debug_options().xla_gpu_disable_autotune()) {
-    VLOG(2) << "Convolution auto-tuning disabled, CudnnConvAlgorithmPicker "
+    VLOG(2) << "Convolution auto-tuning disabled, GpuConvAlgorithmPicker "
               "returning early.";
     return false;
   }
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h
index 18d62a0c025..f02d7f2a80b 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_
 
 #include "absl/time/time.h"
 #include "absl/types/optional.h"
@@ -32,17 +32,17 @@ namespace gpu {
 // Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for
 // each and adding explicit scratch space to the CustomCalls.
-class CudnnConvAlgorithmPicker : public HloModulePass {
+class GpuConvAlgorithmPicker : public HloModulePass {
  public:
   // If the `allocator` parameter is not null, we will use it to allocate temp
   // memory while timing the various convolution algorithms. If it's null,
   // we'll use the default allocator on the StreamExecutor.
-  CudnnConvAlgorithmPicker(se::StreamExecutor* stream_exec,
-                           se::DeviceMemoryAllocator* allocator)
+  GpuConvAlgorithmPicker(se::StreamExecutor* stream_exec,
+                         se::DeviceMemoryAllocator* allocator)
       : stream_exec_(stream_exec), allocator_(allocator) {}
 
   absl::string_view name() const override {
-    return "cudnn-conv-algorithm-picker";
+    return "gpu-conv-algorithm-picker";
   }
 
   StatusOr<bool> Run(HloModule* module) override;
@@ -52,8 +52,14 @@ class CudnnConvAlgorithmPicker : public HloModulePass {
   StatusOr<bool> RunOnInstruction(HloInstruction* instr);
   StatusOr<tensorflow::AutotuneResult> PickBestAlgorithm(
      const HloCustomCallInstruction* instr);
-  StatusOr<tensorflow::AutotuneResult> PickBestAlgorithmNoCache(
-      const HloCustomCallInstruction* instr);
+
+  StatusOr<tensorflow::AutotuneResult> PickBestAlgorithmNoCacheCuda(
+      const HloCustomCallInstruction& instr,
+      se::DeviceMemoryAllocator* allocator, se::Stream* stream);
+
+  StatusOr<tensorflow::AutotuneResult> PickBestAlgorithmNoCacheRocm(
+      const HloCustomCallInstruction& instr,
+      se::DeviceMemoryAllocator* allocator, se::Stream* stream);
 
   se::StreamExecutor* stream_exec_;       // never null
   se::DeviceMemoryAllocator* allocator_;  // may be null
@@ -61,5 +67,4 @@ class CudnnConvAlgorithmPicker : public HloModulePass {
 }  // namespace gpu
 }  // namespace xla
-
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONV_ALGORITHM_PICKER_H_
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc
index c1db7916b3c..261d43d5938 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc
@@ -223,7 +223,16 @@ Status RunGpuConvImpl(const GpuConvParams& params,
   auto output_buf = se::DeviceMemory<OutputType>(params.output_buf);
 
   AlgorithmConfig algorithm = params.algorithm;
-  if (options.algo_override) {
+  // In ROCm mode, the first call to run the convolution needs to trigger the
+  // code that calls the miopenFind* API. That trigger is implicit; it is based
+  // on whether or not the AlgorithmConfig::algorithm is empty! So for the
+  // first call we need to ensure that the AlgorithmConfig::algorithm is empty.
+  // For all subsequent calls, we should use the value retrieved from the
+  // backend_config.
+  if ((options.algo_override.has_value()) &&
+      (*options.algo_override == se::dnn::AlgorithmDesc())) {
+    algorithm = AlgorithmConfig();
+  } else if (options.algo_override.has_value()) {
     algorithm = AlgorithmConfig(*options.algo_override);
   }
 
@@ -261,8 +270,13 @@ StatusOr<GpuConvParams> GetGpuConvParams(
   const Shape* filter_shape;
   const Shape* output_shape;
 
-  params.algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc(
-      backend_config.algorithm(), backend_config.tensor_ops_enabled()));
+  // The third field is the scratch size stored by conv_algorithm_picker.
+  // It is carried in the shape of the conv instruction, where it is added by
+  // the GpuConvAlgorithmPicker::RunOnInstruction() call.
+  params.algorithm = se::dnn::AlgorithmConfig(
+      se::dnn::AlgorithmDesc(backend_config.algorithm(),
+                             backend_config.tensor_ops_enabled()),
+      conv->shape().tuple_shapes(1).dimensions(0));
   params.conv_result_scale = backend_config.conv_result_scale();
 
   switch (params.kind) {
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 36e34e1ffce..489cbd101e2 100755
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -188,11 +188,11 @@ Status NVPTXCompiler::OptimizeHloPostLayoutAssignment(
   // The new tuple and gte instructions then be simplified away, because
   // nobody is expected to use the scratch value.
   //
-  // However, if we were to run CudnnConvAlgorithmPicker after fusion
+  // However, if we were to run GpuConvAlgorithmPicker after fusion
   // the gte(customcall, 0) would probably already be into a fusion node. We
   // can't simplify across HloComputation boundaries, so in this case we
   // wouldn't be able to simplify away the new_tuple bits.
-  pipeline.AddPass<CudnnConvAlgorithmPicker>(stream_exec, device_allocator);
+  pipeline.AddPass<GpuConvAlgorithmPicker>(stream_exec, device_allocator);
 
   // Find the fastest algorithm for GEMMs.
   pipeline.AddPass<GemmAlgorithmPicker>(stream_exec, device_allocator);