From 98e4579b39edcdca373960c8d1379bde473f90cd Mon Sep 17 00:00:00 2001 From: jerryyin Date: Tue, 10 Sep 2019 16:37:16 +0000 Subject: [PATCH] Addressing review feedbacks --- .../service/gpu/gpu_conv_algorithm_picker.cc | 66 +++++++++---------- .../service/gpu/gpu_conv_algorithm_picker.h | 4 +- 2 files changed, 35 insertions(+), 35 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc index 762faea3418..7da9862e95b 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc @@ -20,7 +20,6 @@ limitations under the License. #include "absl/strings/str_format.h" #include "absl/time/time.h" #include "absl/types/optional.h" -#include "google/protobuf/any.pb.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" @@ -306,9 +305,9 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithm( // have diverged. Secifically, we need to make sure redzone allocator related // utilities are not used in ROCm routine if (stream_exec_->platform_kind() == se::PlatformKind::kROCm) { - result_or = PickBestAlgorithmNoCacheRocm(*instr, allocator, stream); + result_or = PickBestAlgorithmNoCacheRocm(instr, allocator, stream); } else if (stream_exec_->platform_kind() == se::PlatformKind::kCuda) { - result_or = PickBestAlgorithmNoCacheCuda(*instr, allocator, stream); + result_or = PickBestAlgorithmNoCacheCuda(instr, allocator, stream); } if (result_or.ok()) { @@ -320,13 +319,13 @@ StatusOr GpuConvAlgorithmPicker::PickBestAlgorithm( StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( - const HloCustomCallInstruction& instr, se::DeviceMemoryAllocator* allocator, + const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator, se::Stream* stream) { // Right now Redzone allocator is available in Cuda target only XLA_SCOPED_LOGGING_TIMER(absl::StrCat( - "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr.ToString())); + "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr->ToString())); - const Shape& result_shape = instr.shape().tuple_shapes(0); + const Shape& result_shape = instr->shape().tuple_shapes(0); const auto device_ordinal = stream_exec_->device_ordinal(); int64 rng_state = 0; @@ -337,13 +336,13 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( InitializeBuffer(stream, buffer_shape.element_type(), &rng_state, buffer); }; - const HloModuleConfig& hlo_module_config = instr.GetModule()->config(); + const HloModuleConfig& hlo_module_config = instr->GetModule()->config(); // Allocate space for the input, filter, and output of the convolution. se::RedzoneAllocator input_output_allocator( stream, allocator, PtxOptsFromConfig(hlo_module_config)); std::vector operand_buffers; - for (const auto* operand : instr.operands()) { + for (const auto* operand : instr->operands()) { TF_ASSIGN_OR_RETURN(auto buffer, input_output_allocator.AllocateBytes( ShapeUtil::ByteSizeOf(operand->shape()))); @@ -356,7 +355,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( initialize_buffer(result_buffer, result_shape); TF_ASSIGN_OR_RETURN(auto backend_config, - instr.backend_config()); + instr->backend_config()); optional comparator; // Use the first algorithm that's supported as reference. There isn't a @@ -365,17 +364,17 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( se::DeviceMemoryBase reference_result_buffer; AlgorithmDesc first_algorithm; - TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(&instr)); + TF_ASSIGN_OR_RETURN(CudnnConvKind kind, GetCudnnConvKind(instr)); std::vector profile_results; const DebugOptions& debug_options = - instr.GetModule()->config().debug_options(); + instr->GetModule()->config().debug_options(); const bool crash_on_checking_failure = debug_options.xla_gpu_crash_on_verification_failures(); const auto canonical_hlo = - std::get<1>(AutotuneCacheKeyfromInstruction(&instr, stream_exec_)); + std::get<1>(AutotuneCacheKeyfromInstruction(instr, stream_exec_)); string blas_version; if (auto* blas = stream_exec_->AsBlas()) { @@ -395,7 +394,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( if (absl::c_linear_search(blacklisted_algos, alg)) { LOG(INFO) << "Omitted potentially buggy algorithm " - << AlgorithmToString(alg) << " for conv " << instr.ToString(); + << AlgorithmToString(alg) << " for conv " << instr->ToString(); continue; } @@ -403,7 +402,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( stream, allocator, PtxOptsFromConfig(hlo_module_config)); se::dnn::ProfileResult profile_result; VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " - << instr.ToString(); + << instr->ToString(); // Use assignment instead of brace-list to make GCC 4.9 happy. RunConvOptions options; @@ -435,11 +434,11 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( // Check for writes to redzones. TF_ASSIGN_OR_RETURN(bool input_output_allocator_redzone_clear, CheckRedzones(input_output_allocator, stream, - "input/output", &instr, &result)); + "input/output", instr, &result)); TF_ASSIGN_OR_RETURN( bool scratch_allocator_redzone_clear, - CheckRedzones(scratch_allocator, stream, "scratch", &instr, &result)); + CheckRedzones(scratch_allocator, stream, "scratch", instr, &result)); if (!input_output_allocator_redzone_clear || !scratch_allocator_redzone_clear) { @@ -470,7 +469,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( if (!compare_result.ok()) { LOG(ERROR) << "Unable to compare " << AlgorithmToString(first_algorithm) << " against " << AlgorithmToString(alg) << " for " - << instr.ToString() << ": " << compare_result.status(); + << instr->ToString() << ": " << compare_result.status(); if (compare_result.status().code() == tensorflow::error::RESOURCE_EXHAUSTED) { // Possibly OOM. Propatate the error. @@ -481,11 +480,12 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( LOG(ERROR) << "Results mismatch between different convolution algorithms. " "This is likely a bug/unexpected loss of precision in cudnn.\n" - << instr.ToString() << " for " << AlgorithmToString(first_algorithm) - << " vs " << AlgorithmToString(alg); + << instr->ToString() << " for " + << AlgorithmToString(first_algorithm) << " vs " + << AlgorithmToString(alg); PrintPlatformInfo(stream); VLOG(1) << "Full module on failure: \n" - << instr.GetModule()->ToString(); + << instr->GetModule()->ToString(); auto* fail = result.mutable_failure(); fail->set_kind(AutotuneResult::WRONG_RESULT); fail->set_buffer_address( @@ -512,11 +512,11 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( tensorflow::AutotuningLog log; { ConvInstructionLog instr_log; - *instr_log.mutable_instruction() = instr.ToProto(); - for (int i = 0; i < instr.operand_count(); i++) { - *instr_log.add_operand_shapes() = instr.operand(i)->shape().ToProto(); + *instr_log.mutable_instruction() = instr->ToProto(); + for (int i = 0; i < instr->operand_count(); i++) { + *instr_log.add_operand_shapes() = instr->operand(i)->shape().ToProto(); instr_log.add_operand_addresses( - reinterpret_cast((operand_buffers)[i].opaque())); + reinterpret_cast(operand_buffers[i].opaque())); } instr_log.set_result_address( reinterpret_cast(result_buffer.opaque())); @@ -582,15 +582,15 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda( return InternalError( "All algorithms tried for convolution %s failed. Falling back to " "default algorithm.", - instr.ToString()); + instr->ToString()); } StatusOr GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( - const HloCustomCallInstruction& instr, se::DeviceMemoryAllocator* allocator, + const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator, se::Stream* stream) { XLA_SCOPED_LOGGING_TIMER(absl::StrCat( - "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr.ToString())); + "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr->ToString())); const auto device_ordinal = stream_exec_->device_ordinal(); std::vector operand_buffers; @@ -607,7 +607,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( // Allocate space for the input, filter, and output of the convolution. We // use a ScratchAllocator for this instead of calling allocator_ directly so // that our allocations don't leak. - for (const auto* operand : instr.operands()) { + for (const auto* operand : instr->operands()) { TF_ASSIGN_OR_RETURN(auto buffer, input_output_allocator.AllocateBytes( ShapeUtil::ByteSizeOf(operand->shape()))); @@ -618,12 +618,12 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( TF_ASSIGN_OR_RETURN( auto result_buffer, input_output_allocator.AllocateBytes( - ShapeUtil::ByteSizeOf(instr.shape().tuple_shapes(0)))); + ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0)))); initialize_buffer(result_buffer); ScratchAllocator scratch_allocator(device_ordinal, allocator); se::dnn::ProfileResult profile_result; - VLOG(3) << "Auto-tuning for " << instr.ToString(); + VLOG(3) << "Auto-tuning for " << instr->ToString(); RunConvOptions options; options.profile_result = &profile_result; @@ -632,8 +632,8 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( options.algo_override = se::dnn::AlgorithmDesc(); bool launch_ok = - RunCudnnConv(&instr, absl::MakeSpan(operand_buffers), result_buffer, - &scratch_allocator, stream, options) + RunGpuConv(instr, absl::MakeSpan(operand_buffers), result_buffer, + &scratch_allocator, stream, options) .ok(); AutotuneResult best_result; @@ -653,7 +653,7 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm( return InternalError( "All algorithms tried for convolution %s failed. Falling back to " "default algorithm.", - instr.ToString()); + instr->ToString()); } StatusOr GpuConvAlgorithmPicker::RunOnInstruction(HloInstruction* instr) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h index f02d7f2a80b..7b6ca6a8e2c 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h @@ -54,11 +54,11 @@ class GpuConvAlgorithmPicker : public HloModulePass { const HloCustomCallInstruction* instr); StatusOr PickBestAlgorithmNoCacheCuda( - const HloCustomCallInstruction& instr, + const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator, se::Stream* stream); StatusOr PickBestAlgorithmNoCacheRocm( - const HloCustomCallInstruction& instr, + const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator, se::Stream* stream); se::StreamExecutor* stream_exec_; // never null