diff --git a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc index 939b41382ac..8316cb7d12d 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.cc @@ -58,15 +58,48 @@ static int64 cache_misses TF_GUARDED_BY(autotune_cache_mu) = 0; // than sm_50 -- in both cases, cublas doesn't support gemm-with-algorithm at // all. static StatusOr> DoUncachedGemmAutotune( - const HloInstruction* gemm, se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, - se::DeviceMemoryBase reference_result_buffer, se::Stream* stream, - const se::RedzoneAllocator& allocator, const BufferComparator& comparator, - bool crash_on_checking_failure) { + const HloInstruction* gemm, se::Stream* stream, + se::DeviceMemoryAllocator* allocator) { if (!stream->parent()->SynchronizeAllActivity()) { return InternalError("Failed to synchronize GPU for autotuning."); } + const HloModuleConfig& hlo_module_config = gemm->GetModule()->config(); + const bool init_cublas_data = + hlo_module_config.debug_options().xla_gpu_autotune_level() > 1; + se::RedzoneAllocator input_output_allocator( + stream, allocator, PtxOptsFromConfig(hlo_module_config), + /*memory_limit=*/std::numeric_limits::max()); + + BufferComparator comparator(gemm->shape(), hlo_module_config); + + int64 rng_state = 0; + auto get_initialized_buffer = + [&](const HloInstruction* op) -> StatusOr { + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer, + input_output_allocator.AllocateBytes( + ShapeUtil::ByteSizeOf(op->shape()))); + if (init_cublas_data) { + InitializeBuffer(stream, op->shape().element_type(), &rng_state, buffer); + } + return buffer; + }; + + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs_buffer, + get_initialized_buffer(gemm->operand(0))); + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs_buffer, + get_initialized_buffer(gemm->operand(1))); + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase output_buffer, + get_initialized_buffer(gemm)); + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase reference_result_buffer, + get_initialized_buffer(gemm)); + + const DebugOptions& debug_options = + gemm->GetModule()->config().debug_options(); + + const bool crash_on_checking_failure = + debug_options.xla_gpu_crash_on_verification_failures(); + GemmBackendConfig backend_config = gemm->backend_config().ValueOrDie(); const int32 cublas_autotune_level = @@ -124,7 +157,7 @@ static StatusOr> DoUncachedGemmAutotune( TF_ASSIGN_OR_RETURN( se::RedzoneAllocator::RedzoneCheckStatus rz_check_status, - allocator.CheckRedzones()); + input_output_allocator.CheckRedzones()); if (!rz_check_status.ok()) { result.mutable_failure()->set_kind(AutotuneResult::REDZONE_MODIFIED); *result.mutable_failure()->mutable_msg() = @@ -194,17 +227,14 @@ static StatusOr> DoUncachedGemmAutotune( } static StatusOr> DoGemmAutotune( - const HloInstruction* instr, const HloInstruction* lhs, - const HloInstruction* rhs, se::DeviceMemoryBase lhs_buffer, - se::DeviceMemoryBase rhs_buffer, se::DeviceMemoryBase output_buffer, - se::DeviceMemoryBase reference_result_buffer, se::Stream* stream, - bool crash_on_checking_failure, const se::RedzoneAllocator& allocator, - const BufferComparator& comparator) { + const HloInstruction* instr, const GemmBackendConfig& gemm_config, + se::DeviceMemoryAllocator* allocator, se::Stream* stream) { + const HloInstruction* lhs = instr->operand(0); + const HloInstruction* rhs = instr->operand(1); + // Don't run autotuning concurrently on the same GPU. tensorflow::mutex_lock gpu_lock = LockGpu(stream->parent()); - GemmBackendConfig gemm_config = - instr->backend_config().ValueOrDie(); GemmCacheKey key = std::make_tuple(stream->parent(), lhs->shape(), rhs->shape(), @@ -235,11 +265,8 @@ static StatusOr> DoGemmAutotune( VLOG(2) << "Batch size is non-singular, using generic algorithm"; result = absl::nullopt; } else { - TF_ASSIGN_OR_RETURN( - result, - DoUncachedGemmAutotune(instr, lhs_buffer, rhs_buffer, output_buffer, - reference_result_buffer, stream, allocator, - comparator, crash_on_checking_failure)); + TF_ASSIGN_OR_RETURN(result, + DoUncachedGemmAutotune(instr, stream, allocator)); } CHECK(autotune_cache.emplace(key, result).second); @@ -255,52 +282,11 @@ static StatusOr RunOnInstruction(HloInstruction* instr, TF_ASSIGN_OR_RETURN(se::Stream* const stream, allocator->GetStream(executor->device_ordinal())); - const HloModuleConfig& hlo_module_config = instr->GetModule()->config(); - const bool init_cublas_data = - hlo_module_config.debug_options().xla_gpu_autotune_level() > 1; - se::RedzoneAllocator input_output_allocator( - stream, allocator, PtxOptsFromConfig(hlo_module_config), - /*memory_limit=*/std::numeric_limits::max()); - - BufferComparator comparator(instr->shape(), hlo_module_config); - - int64 rng_state = 0; - auto get_initialized_buffer = - [&](const HloInstruction* op) -> StatusOr { - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase buffer, - input_output_allocator.AllocateBytes( - ShapeUtil::ByteSizeOf(op->shape()))); - if (init_cublas_data) { - InitializeBuffer(stream, op->shape().element_type(), &rng_state, buffer); - } - return buffer; - }; - GemmBackendConfig gemm_config = instr->backend_config().ValueOrDie(); - const HloInstruction* lhs = instr->operand(0); - const HloInstruction* rhs = instr->operand(1); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase lhs_buffer, - get_initialized_buffer(lhs)); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase rhs_buffer, - get_initialized_buffer(rhs)); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase output_buffer, - get_initialized_buffer(instr)); - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase reference_result_buffer, - get_initialized_buffer(instr)); - - const DebugOptions& debug_options = - instr->GetModule()->config().debug_options(); - - const bool crash_on_checking_failure = - debug_options.xla_gpu_crash_on_verification_failures(); - - TF_ASSIGN_OR_RETURN( - absl::optional gemm_algorithm, - DoGemmAutotune(instr, lhs, rhs, lhs_buffer, rhs_buffer, output_buffer, - reference_result_buffer, stream, crash_on_checking_failure, - input_output_allocator, comparator)); + TF_ASSIGN_OR_RETURN(absl::optional gemm_algorithm, + DoGemmAutotune(instr, gemm_config, allocator, stream)); // We update instruction->backend_config(); if no algorithms are supported, // a different API is used, which does not require specifying an algorithm.