Merge pull request #32364 from ROCmSoftwarePlatform:google-upstream-pr-conv_algorithm_picker
PiperOrigin-RevId: 272994459
Commit: c084d2860d
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -236,9 +236,9 @@ cc_library(
     deps = [
         ":backend_configs",
         ":buffer_allocations",
-        ":cudnn_conv_runner",
         ":elemental_ir_emitter",
         ":gpu_constants",
+        ":gpu_conv_runner",
         ":gpu_executable",
         ":hlo_to_ir_bindings",
         ":ir_emission_utils",
@@ -505,7 +505,7 @@ cc_library(
         ":backend_configs",
         ":buffer_allocations",
         ":cusolver_context",
-        ":cudnn_conv_runner",
+        ":gpu_conv_runner",
         ":gpu_debug_info_manager",
         ":gpu_types",
         ":hlo_execution_profiler",
@@ -615,7 +615,7 @@ cc_library(
     deps = [
         ":backend_configs",
         ":buffer_comparator",
-        ":cudnn_conv_runner",
+        ":gpu_conv_runner",
        ":gpu_executable",
        ":ir_emission_utils",
        ":stream_executor_util",
@@ -637,14 +637,14 @@ cc_library(
 )
 
 cc_library(
-    name = "cudnn_conv_algorithm_picker",
-    srcs = ["cudnn_conv_algorithm_picker.cc"],
-    hdrs = ["cudnn_conv_algorithm_picker.h"],
+    name = "gpu_conv_algorithm_picker",
+    srcs = ["gpu_conv_algorithm_picker.cc"],
+    hdrs = ["gpu_conv_algorithm_picker.h"],
     deps = [
         ":backend_configs",
         ":buffer_comparator",
-        ":cudnn_conv_runner",
         ":gpu_autotuning_proto",
+        ":gpu_conv_runner",
         ":gpu_executable",
         ":hlo_algorithm_blacklist",
         ":ir_emission_utils",
@@ -672,7 +672,7 @@ cc_library(
 )
 
 cc_library(
-    name = "cudnn_conv_runner",
+    name = "gpu_conv_runner",
     srcs = ["gpu_conv_runner.cc"],
     hdrs = ["gpu_conv_runner.h"],
     deps = [
@@ -1057,6 +1057,7 @@ cc_library(
         ":cudnn_pad_for_convolutions",
         ":fusion_merger",
         ":gpu_constants",
+        ":gpu_conv_algorithm_picker",
         ":gpu_conv_padding_legalization",
         ":gpu_conv_rewriter",
         ":gpu_copy_insertion",
@@ -1156,13 +1157,13 @@ cc_library(
     ],
     deps = [
         ":cublas_gemm_pad_for_tensor_cores",
-        ":cudnn_conv_algorithm_picker",
         ":cudnn_fused_conv_rewriter",
         ":cudnn_pad_for_convolutions",
         ":cusolver_rewriter",
         ":gemm_algorithm_picker",
         ":gemm_rewriter",
         ":gpu_compiler",
+        ":gpu_conv_algorithm_picker",
         ":gpu_conv_padding_legalization",
         ":gpu_conv_rewriter",
         ":gpu_layout_assignment",
--- a/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/amdgpu_compiler.cc
@@ -16,9 +16,9 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/amdgpu_compiler.h"
 
 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"
-// TODO(whchung@gmail.com): Add gpu_conv_algorithm_picker after its PR merged.
 #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
 #include "tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h"
 #include "tensorflow/compiler/xla/service/gpu/target_constants.h"
@@ -97,7 +97,7 @@ Status AMDGPUCompiler::OptimizeHloPostLayoutAssignment(
   options.set_is_layout_sensitive(true);
   pipeline.AddPass<HloPassFix<AlgebraicSimplifier>>(options);
 
-  // TODO(whchung@gmail.com): Add gpu_conv_algorithm_picker after its PR merged.
+  pipeline.AddPass<GpuConvAlgorithmPicker>(stream_exec, device_allocator);
 
   // Clean up new_tuple described above.
   pipeline.AddPass<TupleSimplifier>();
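
With the TODO resolved, ROCm now runs the same convolution autotuning pass as CUDA. For orientation, a minimal sketch of what the enabled pass amounts to inside OptimizeHloPostLayoutAssignment (assuming the surrounding HloPassPipeline and the stream_exec/device_allocator parameters visible in the hunk; `hlo_module` is an illustrative name, not a verbatim excerpt):

    // Sketch: the pass registration the AMDGPUCompiler change enables.
    HloPassPipeline pipeline("post-layout-assignment");
    pipeline.AddPass<GpuConvAlgorithmPicker>(stream_exec, device_allocator);
    TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status());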
--- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc
@@ -13,9 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h"
 
-#include "google/protobuf/any.pb.h"
 #include "absl/algorithm/container.h"
 #include "absl/strings/str_cat.h"
 #include "absl/strings/str_format.h"
@@ -48,6 +47,54 @@ using se::DeviceMemoryBase;
 using se::dnn::AlgorithmDesc;
 using tensorflow::AutotuneResult;
 
+class ScratchAllocator : public se::ScratchAllocator {
+ public:
+  ScratchAllocator(int device_ordinal,
+                   se::DeviceMemoryAllocator* memory_allocator)
+      : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {}
+
+  int64 GetMemoryLimitInBytes() override {
+    return 1LL << 32;  // 4GB. TODO(jlebar): Tune this?
+  }
+  int64 TotalAllocatedBytes() { return total_allocated_bytes_; }
+
+  StatusOr<se::DeviceMemory<uint8>> AllocateBytes(int64 byte_size) override;
+
+  template <typename T>
+  StatusOr<se::DeviceMemory<T>> Allocate(int64 num_elements) {
+    TF_ASSIGN_OR_RETURN(se::DeviceMemory<uint8> bytes,
+                        AllocateBytes(num_elements * sizeof(T)));
+    return se::DeviceMemory<T>(bytes);
+  }
+
+ private:
+  const int device_ordinal_;
+  se::DeviceMemoryAllocator* memory_allocator_;
+  std::vector<se::OwningDeviceMemory> allocated_buffers_;
+  int64 total_allocated_bytes_ = 0;
+};
+
+StatusOr<se::DeviceMemory<uint8>> ScratchAllocator::AllocateBytes(
+    int64 byte_size) {
+  CHECK_GE(byte_size, 0) << "byte_size must be positive.";
+  if (byte_size > GetMemoryLimitInBytes()) {
+    return se::port::Status(
+        se::port::error::RESOURCE_EXHAUSTED,
+        absl::StrFormat(
+            "Allocating %d bytes exceeds the memory limit of %d bytes.",
+            byte_size, GetMemoryLimitInBytes()));
+  }
+
+  TF_ASSIGN_OR_RETURN(se::OwningDeviceMemory allocated_buffer,
+                      memory_allocator_->Allocate(device_ordinal_, byte_size,
+                                                  /*retry_on_failure=*/false));
+  total_allocated_bytes_ += byte_size;
+
+  se::DeviceMemoryBase buffer_addr = *allocated_buffer;
+  allocated_buffers_.push_back(std::move(allocated_buffer));
+  return se::DeviceMemory<uint8>(buffer_addr);
+}
+
 std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
                                          se::StreamExecutor* stream_exec) {
   std::vector<AlgorithmDesc> algorithms;
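
The ScratchAllocator added above owns every buffer it hands out (via allocated_buffers_), so all autotuning allocations are released together when it goes out of scope. A hypothetical usage sketch (the variable names and the 1024-element size are illustrative only):

    // Sketch: typed scratch allocation through the ScratchAllocator above.
    ScratchAllocator scratch(/*device_ordinal=*/0, memory_allocator);
    TF_ASSIGN_OR_RETURN(se::DeviceMemory<float> buf,
                        scratch.Allocate<float>(/*num_elements=*/1024));
    // All buffers, including `buf`, are freed when `scratch` is destroyed.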
@@ -198,7 +245,7 @@ auto& autotune_cache_stats GUARDED_BY(autotune_cache_lock) =
     *new ConvCacheStats();
 }  // anonymous namespace
 
-StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
+StatusOr<AutotuneResult> GpuConvAlgorithmPicker::PickBestAlgorithm(
     const HloCustomCallInstruction* instr) {
   // Don't run this function concurrently on the same GPU.
   //
@@ -226,22 +273,6 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithm(
     autotune_cache_stats.cache_misses++;
   }
 
-  StatusOr<AutotuneResult> result_or = PickBestAlgorithmNoCache(instr);
-  if (result_or.ok()) {
-    tensorflow::mutex_lock lock(autotune_cache_lock);
-    CHECK(autotune_cache.insert({key, result_or.ValueOrDie()}).second);
-  }
-  return result_or;
-}
-
-StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
-    const HloCustomCallInstruction* instr) {
-  XLA_SCOPED_LOGGING_TIMER(
-      absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithmImpl for ",
-                   instr->ToString()));
-
-  const Shape& result_shape = instr->shape().tuple_shapes(0);
-
   // Make sure any previous activity on this executor is done. We don't want to
   // interfere with programs that are still running on the GPU.
   if (!stream_exec_->SynchronizeAllActivity()) {
@@ -269,6 +300,32 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
     return &stream_opt.value();
   }();
 
+  StatusOr<AutotuneResult> result_or(InternalError("Unknown platform."));
+  // Check which platform the StreamExecutor is on; the ROCm and CUDA
+  // implementations have diverged. Specifically, we need to make sure the
+  // redzone-allocator utilities are not used in the ROCm routine.
+  if (stream_exec_->platform_kind() == se::PlatformKind::kROCm) {
+    result_or = PickBestAlgorithmNoCacheRocm(instr, allocator, stream);
+  } else if (stream_exec_->platform_kind() == se::PlatformKind::kCuda) {
+    result_or = PickBestAlgorithmNoCacheCuda(instr, allocator, stream);
+  }
+
+  if (result_or.ok()) {
+    tensorflow::mutex_lock lock(autotune_cache_lock);
+    CHECK(autotune_cache.insert({key, result_or.ValueOrDie()}).second);
+  }
+  return result_or;
+}
+
+StatusOr<tensorflow::AutotuneResult>
+GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheCuda(
+    const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator,
+    se::Stream* stream) {
+  // Right now the redzone allocator is available for the CUDA target only.
+  XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
+      "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr->ToString()));
+
+  const Shape& result_shape = instr->shape().tuple_shapes(0);
   int64 rng_state = 0;
 
   const auto initialize_buffer = [&stream, &rng_state](
@@ -526,8 +583,78 @@ StatusOr<AutotuneResult> CudnnConvAlgorithmPicker::PickBestAlgorithmNoCache(
       instr->ToString());
 }
 
-StatusOr<bool> CudnnConvAlgorithmPicker::RunOnInstruction(
-    HloInstruction* instr) {
+StatusOr<tensorflow::AutotuneResult>
+GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm(
+    const HloCustomCallInstruction* instr, se::DeviceMemoryAllocator* allocator,
+    se::Stream* stream) {
+  XLA_SCOPED_LOGGING_TIMER(absl::StrCat(
+      "GpuConvAlgorithmPicker::PickBestAlgorithmImpl for ", instr->ToString()));
+
+  const auto device_ordinal = stream_exec_->device_ordinal();
+  std::vector<se::DeviceMemoryBase> operand_buffers;
+
+  ScratchAllocator input_output_allocator(device_ordinal, allocator);
+  const auto initialize_buffer = [stream](DeviceMemoryBase buffer) {
+    // Although we don't have evidence this matters, zero out the buffers
+    // before autotuning. It's conceivable that using uninitialized memory as
+    // the inputs might affect performance if e.g. the inputs contain
+    // denormals, and this is easy enough.
+    stream->ThenMemZero(&buffer, buffer.size());
+  };
+
+  // Allocate space for the input, filter, and output of the convolution. We
+  // use a ScratchAllocator for this instead of calling allocator_ directly so
+  // that our allocations don't leak.
+  for (const auto* operand : instr->operands()) {
+    TF_ASSIGN_OR_RETURN(auto buffer,
+                        input_output_allocator.AllocateBytes(
+                            ShapeUtil::ByteSizeOf(operand->shape())));
+    initialize_buffer(buffer);
+    operand_buffers.push_back(buffer);
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      auto result_buffer,
+      input_output_allocator.AllocateBytes(
+          ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0))));
+  initialize_buffer(result_buffer);
+
+  ScratchAllocator scratch_allocator(device_ordinal, allocator);
+  se::dnn::ProfileResult profile_result;
+  VLOG(3) << "Auto-tuning for " << instr->ToString();
+  RunConvOptions options;
+  options.profile_result = &profile_result;
+
+  // ROCm: Set the overriding algorithm to empty to remind gpu_conv_runner
+  // that the AlgorithmConfig used to run the convolution needs to be empty.
+  options.algo_override = se::dnn::AlgorithmDesc();
+
+  bool launch_ok =
+      RunGpuConv(instr, absl::MakeSpan(operand_buffers), result_buffer,
+                 &scratch_allocator, stream, options)
+          .ok();
+
+  AutotuneResult best_result;
+  if (launch_ok && profile_result.is_valid()) {
+    best_result.mutable_conv()->set_algorithm(
+        profile_result.algorithm().algo_id());
+    best_result.mutable_conv()->set_tensor_ops_enabled(
+        profile_result.algorithm().tensor_ops_enabled());
+    int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
+    best_result.set_scratch_bytes(scratch_bytes_used);
+    *best_result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
+        absl::Milliseconds(profile_result.elapsed_time_in_ms()));
+
+    return best_result;
+  }
+
+  return InternalError(
+      "All algorithms tried for convolution %s failed. Falling back to "
+      "default algorithm.",
+      instr->ToString());
+}
+
+StatusOr<bool> GpuConvAlgorithmPicker::RunOnInstruction(HloInstruction* instr) {
   CHECK(IsCustomCallToDnnConvolution(*instr));
 
   StatusOr<AutotuneResult> best_algo_or =
@@ -577,7 +704,7 @@ StatusOr<bool> CudnnConvAlgorithmPicker::RunOnInstruction(
     return true;
   }
 
-StatusOr<bool> CudnnConvAlgorithmPicker::RunOnComputation(
+StatusOr<bool> GpuConvAlgorithmPicker::RunOnComputation(
     HloComputation* computation) {
   std::vector<HloInstruction*> convs;
   for (auto* instr : computation->instructions()) {
@@ -594,11 +721,11 @@ StatusOr<bool> CudnnConvAlgorithmPicker::RunOnComputation(
   return changed;
 }
 
-StatusOr<bool> CudnnConvAlgorithmPicker::Run(HloModule* module) {
-  XLA_SCOPED_LOGGING_TIMER("CudnnConvAlgorithmPicker");
+StatusOr<bool> GpuConvAlgorithmPicker::Run(HloModule* module) {
+  XLA_SCOPED_LOGGING_TIMER("GpuConvAlgorithmPicker");
 
   if (module->config().debug_options().xla_gpu_disable_autotune()) {
-    VLOG(2) << "Convolution auto-tuning disabled, CudnnConvAlgorithmPicker "
+    VLOG(2) << "Convolution auto-tuning disabled, GpuConvAlgorithmPicker "
                "returning early.";
     return false;
   }
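
The ROCm path performs a single profiled RunGpuConv call with an empty algorithm override, on the assumption that MIOpen's find step selects an algorithm internally, so there is no per-algorithm timing loop as on the CUDA side. The winner is still reported through the same tensorflow::AutotuneResult proto. A reader-side sketch (accessor names inferred from the setter calls above; not code from this commit):

    // Sketch: consuming the AutotuneResult populated by the ROCm path.
    const int64 algo_id = best_result.conv().algorithm();
    const bool tensor_ops = best_result.conv().tensor_ops_enabled();
    const int64 scratch_bytes = best_result.scratch_bytes();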
--- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_ALGORITHM_PICKER_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_ALGORITHM_PICKER_H_
 
 #include "absl/time/time.h"
 #include "absl/types/optional.h"
@@ -32,17 +32,17 @@ namespace gpu {
 
 // Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for
 // each and adding explicit scratch space to the CustomCalls.
-class CudnnConvAlgorithmPicker : public HloModulePass {
+class GpuConvAlgorithmPicker : public HloModulePass {
  public:
   // If the `allocator` parameter is not null, we will use it to allocate temp
   // memory while timing the various convolution algorithms. If it's null,
   // we'll use the default allocator on the StreamExecutor.
-  CudnnConvAlgorithmPicker(se::StreamExecutor* stream_exec,
-                           se::DeviceMemoryAllocator* allocator)
+  GpuConvAlgorithmPicker(se::StreamExecutor* stream_exec,
+                         se::DeviceMemoryAllocator* allocator)
       : stream_exec_(stream_exec), allocator_(allocator) {}
 
   absl::string_view name() const override {
-    return "cudnn-conv-algorithm-picker";
+    return "gpu-conv-algorithm-picker";
   }
 
   StatusOr<bool> Run(HloModule* module) override;
@@ -52,14 +52,19 @@ class CudnnConvAlgorithmPicker : public HloModulePass {
   StatusOr<bool> RunOnInstruction(HloInstruction* instr);
   StatusOr<tensorflow::AutotuneResult> PickBestAlgorithm(
       const HloCustomCallInstruction* instr);
-  StatusOr<tensorflow::AutotuneResult> PickBestAlgorithmNoCache(
-      const HloCustomCallInstruction* instr);
 
-  se::StreamExecutor* stream_exec_;       // never null
-  se::DeviceMemoryAllocator* allocator_;  // may be null
+  StatusOr<tensorflow::AutotuneResult> PickBestAlgorithmNoCacheCuda(
+      const HloCustomCallInstruction* instr,
+      se::DeviceMemoryAllocator* allocator, se::Stream* stream);
+
+  StatusOr<tensorflow::AutotuneResult> PickBestAlgorithmNoCacheRocm(
+      const HloCustomCallInstruction* instr,
+      se::DeviceMemoryAllocator* allocator, se::Stream* stream);
+
+  se::StreamExecutor* stream_exec_;       // never null
+  se::DeviceMemoryAllocator* allocator_;  // may be null
 };
 
 }  // namespace gpu
 }  // namespace xla
 
-#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONV_ALGORITHM_PICKER_H_
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONV_ALGORITHM_PICKER_H_
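
For orientation, a minimal sketch of driving the renamed pass directly (assumes an existing HloModule* module and se::StreamExecutor* stream_exec; a null allocator falls back to the StreamExecutor's default, per the constructor comment above):

    // Sketch: running GpuConvAlgorithmPicker standalone rather than in a pipeline.
    GpuConvAlgorithmPicker picker(stream_exec, /*allocator=*/nullptr);
    TF_ASSIGN_OR_RETURN(bool changed, picker.Run(module));
    // `changed` is true if any conv custom call was rewritten with a chosen
    // algorithm and explicit scratch space.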
--- a/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_conv_runner.cc
@@ -223,7 +223,17 @@ Status RunGpuConvImpl(const GpuConvParams& params,
   auto output_buf = se::DeviceMemory<OutputType>(params.output_buf);
   AlgorithmConfig algorithm = params.algorithm;
 
-  if (options.algo_override) {
+  // In ROCm mode, the first call to run the convolution needs to trigger the
+  // code that calls the miopenFind* API. That trigger is implicit: it is based
+  // on whether or not AlgorithmConfig::algorithm is empty. So for the first
+  // call we need to ensure that AlgorithmConfig::algorithm is empty; for all
+  // subsequent calls, we should use the value retrieved from the
+  // backend_config.
+  if ((stream->parent()->platform_kind() == se::PlatformKind::kROCm) &&
+      (options.algo_override.has_value()) &&
+      (*options.algo_override == se::dnn::AlgorithmDesc())) {
+    algorithm = AlgorithmConfig();
+  } else if (options.algo_override.has_value()) {
     algorithm = AlgorithmConfig(*options.algo_override);
   }
 
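
This hunk is the runner-side half of a handshake with GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm: the picker sets algo_override to a default-constructed AlgorithmDesc, and the runner maps that to an empty AlgorithmConfig so MIOpen performs its find step. A condensed sketch of the convention (`is_rocm` stands in for the platform_kind() check; simplified, not verbatim):

    // Picker side (ROCm): an empty desc means "search for an algorithm".
    RunConvOptions options;
    options.algo_override = se::dnn::AlgorithmDesc();

    // Runner side: empty override on ROCm => empty config => miopenFind*;
    // any non-empty override pins a specific algorithm.
    if (is_rocm && *options.algo_override == se::dnn::AlgorithmDesc()) {
      algorithm = se::dnn::AlgorithmConfig();
    } else if (options.algo_override.has_value()) {
      algorithm = se::dnn::AlgorithmConfig(*options.algo_override);
    }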
@@ -261,8 +271,13 @@ StatusOr<GpuConvParams> GetGpuConvParams(
   const Shape* filter_shape;
   const Shape* output_shape;
 
-  params.algorithm = se::dnn::AlgorithmConfig(se::dnn::AlgorithmDesc(
-      backend_config.algorithm(), backend_config.tensor_ops_enabled()));
+  // The third field is the scratch size, stored by conv_algorithm_picker.
+  // The operand is added to the shape field of the conv instruction in the
+  // GpuConvAlgorithmPicker::RunOnInstruction() call.
+  params.algorithm = se::dnn::AlgorithmConfig(
+      se::dnn::AlgorithmDesc(backend_config.algorithm(),
+                             backend_config.tensor_ops_enabled()),
+      conv->shape().tuple_shapes(1).dimensions(0));
   params.conv_result_scale = backend_config.conv_result_scale();
 
   switch (params.kind) {
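
Context for the new third constructor argument: after GpuConvAlgorithmPicker::RunOnInstruction rewrites a conv, the custom call returns a tuple whose second element is a u8[N] scratch buffer, so N can be recovered from the instruction's shape at run time. A sketch of that round trip (`custom_call` is a stand-in for the conv instruction; illustrative only):

    // Shape after autotuning: (conv_result, u8[scratch_bytes]).
    const int64 scratch_bytes =
        custom_call->shape().tuple_shapes(1).dimensions(0);
    // Rebuild the recorded choice, now carrying the scratch size as
    // AlgorithmConfig's third field.
    se::dnn::AlgorithmConfig config(
        se::dnn::AlgorithmDesc(backend_config.algorithm(),
                               backend_config.tensor_ops_enabled()),
        scratch_bytes);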
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -22,12 +22,12 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
 #include "tensorflow/compiler/xla/service/dump.h"
 #include "tensorflow/compiler/xla/service/gpu/cublas_gemm_pad_for_tensor_cores.h"
-#include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_fused_conv_rewriter.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_pad_for_convolutions.h"
 #include "tensorflow/compiler/xla/service/gpu/cusolver_rewriter.h"
 #include "tensorflow/compiler/xla/service/gpu/gemm_algorithm_picker.h"
 #include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"
+#include "tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_padding_legalization.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_rewriter.h"
 #include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h"
@@ -188,11 +188,11 @@ Status NVPTXCompiler::OptimizeHloPostLayoutAssignment(
   // The new tuple and gte instructions then be simplified away, because
   // nobody is expected to use the scratch value.
   //
-  // However, if we were to run CudnnConvAlgorithmPicker after fusion
+  // However, if we were to run GpuConvAlgorithmPicker after fusion
   // the gte(customcall, 0) would probably already be into a fusion node. We
   // can't simplify across HloComputation boundaries, so in this case we
   // wouldn't be able to simplify away the new_tuple bits.
-  pipeline.AddPass<CudnnConvAlgorithmPicker>(stream_exec, device_allocator);
+  pipeline.AddPass<GpuConvAlgorithmPicker>(stream_exec, device_allocator);
 
   // Find the fastest algorithm for GEMMs.
   pipeline.AddPass<GemmAlgorithmPicker>(stream_exec, device_allocator);