Merge pull request #35503 from ROCmSoftwarePlatform:google_upstream_rocm_miopen_immediate_mode

PiperOrigin-RevId: 289053613
Change-Id: I233d95adc3aa888460bd39a07fd7e168fea14846
This commit is contained in:
TensorFlower Gardener 2020-01-10 01:43:54 -08:00
commit c1971ab97c
18 changed files with 1002 additions and 381 deletions

View File

@ -593,6 +593,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/algorithm:container",
"@llvm-project//llvm:core",
],

View File

@ -117,6 +117,29 @@ std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
return algorithms;
}
StatusOr<std::vector<se::dnn::ProfileResult>> GetAlgorithms(
const HloCustomCallInstruction* conv,
absl::Span<se::DeviceMemoryBase> operand_buffers,
se::DeviceMemoryBase result_buffer, se::StreamExecutor* stream_exec,
se::Stream* stream) {
std::vector<se::dnn::ProfileResult> algorithms;
TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind,
GetDnnConvolutionKind(conv));
TF_ASSIGN_OR_RETURN(se::dnn::DataType dtype, GetDnnDataType(conv));
TF_ASSIGN_OR_RETURN(GpuConvParams params,
GetGpuConvParams(conv, operand_buffers, result_buffer));
bool succ = stream_exec->GetMIOpenConvolveAlgorithms(
kind, stream, dtype, params.input_descriptor, params.filter_descriptor,
params.conv_desc, params.output_descriptor, &algorithms);
DCHECK(succ);
return algorithms;
}
string AlgorithmToString(const AlgorithmDesc& algo) {
if (algo.tensor_ops_enabled()) {
return absl::StrCat(algo.algo_id(), "+TC");
@ -611,33 +634,72 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm(
ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0))));
initialize_buffer(result_buffer);
ScratchAllocator scratch_allocator(device_ordinal, allocator);
se::dnn::ProfileResult profile_result;
VLOG(3) << "Auto-tuning for " << instr->ToString();
RunConvOptions options;
options.profile_result = &profile_result;
TF_ASSIGN_OR_RETURN(std::vector<se::dnn::ProfileResult> algorithms,
GetAlgorithms(instr, absl::MakeSpan(operand_buffers),
result_buffer, stream_exec_, stream));
// ROCm: Set the overriding algorithm to empty to remind cudnn_conv_runner
// that the AlgorithmConfig in running convolution needs to be empty
options.algo_override = se::dnn::AlgorithmDesc();
std::vector<AutotuneResult> profile_results;
bool launch_ok =
RunGpuConv(instr, absl::MakeSpan(operand_buffers), result_buffer,
&scratch_allocator, stream, options)
.ok();
AutotuneResult best_result;
if (launch_ok && profile_result.is_valid()) {
best_result.mutable_conv()->set_algorithm(
profile_result.algorithm().algo_id());
best_result.mutable_conv()->set_tensor_ops_enabled(
if (algorithms.size() == 1) {
auto profile_result = algorithms[0];
profile_results.emplace_back();
auto& result = profile_results.back();
result.mutable_conv()->set_algorithm(profile_result.algorithm().algo_id());
result.mutable_conv()->set_tensor_ops_enabled(
profile_result.algorithm().tensor_ops_enabled());
int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
best_result.set_scratch_bytes(scratch_bytes_used);
*best_result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
return best_result;
result.set_scratch_bytes(profile_result.scratch_size());
*result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
} else {
for (const auto& miopen_alg : algorithms) {
const auto& alg = miopen_alg.algorithm();
XLA_SCOPED_LOGGING_TIMER_LEVEL(
absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ",
AlgorithmToString(alg)),
2);
ScratchAllocator scratch_allocator(device_ordinal, allocator);
se::dnn::ProfileResult profile_result;
VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
<< instr->ToString();
// Use assignment instead of brace-list to make GCC 4.9 happy.
RunConvOptions options;
options.profile_result = &profile_result;
options.algo_override = alg;
Status launch_status =
RunGpuConv(instr, absl::MakeSpan(operand_buffers), result_buffer,
&scratch_allocator, stream, options);
if (!launch_status.ok()) {
continue;
}
if (!profile_result.is_valid()) {
continue;
}
profile_results.emplace_back();
AutotuneResult& result = profile_results.back();
result.mutable_conv()->set_algorithm(alg.algo_id());
result.mutable_conv()->set_tensor_ops_enabled(alg.tensor_ops_enabled());
int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
result.set_scratch_bytes(scratch_bytes_used);
*result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
}
}
const auto& best_result = absl::c_min_element(
profile_results,
[&](const AutotuneResult& lhs, const AutotuneResult& rhs) {
return tensorflow::proto_utils::FromDurationProto(lhs.run_time()) <
tensorflow::proto_utils::FromDurationProto(rhs.run_time());
});
if (best_result != profile_results.end()) {
return *best_result;
}
return InternalError(

View File

@ -223,17 +223,7 @@ Status RunGpuConvImpl(const GpuConvParams& params,
auto output_buf = se::DeviceMemory<OutputType>(params.output_buf);
AlgorithmConfig algorithm = params.algorithm;
// in ROCm mode, the first call to run the convolution needs to trigger the
// code that calls miopenFind* API. That triggger is implicit, it is based
// on whether or not the AlgorithmConfig::algorithm is empty! So for the
// first call we need to ensure that the AlgorithmConfig::algorithm is
// empty. For all subsequent calls, we should use the value retrieved from
// the backend_config
if ((stream->parent()->platform_kind() == se::PlatformKind::kROCm) &&
(options.algo_override.has_value()) &&
(*options.algo_override == se::dnn::AlgorithmDesc())) {
algorithm = AlgorithmConfig();
} else if (options.algo_override.has_value()) {
if (options.algo_override.has_value()) {
algorithm = AlgorithmConfig(*options.algo_override);
}

View File

@ -427,6 +427,39 @@ StatusOr<CudnnConvKind> GetCudnnConvKind(
return InternalError("Unexpected call target: %s", target);
}
StatusOr<se::dnn::ConvolutionKind> GetDnnConvolutionKind(
const HloCustomCallInstruction* instr) {
absl::string_view target = instr->custom_call_target();
if (target == kCudnnConvForwardCallTarget) {
return se::dnn::ConvolutionKind::FORWARD;
}
if (target == kCudnnConvBackwardInputCallTarget) {
return se::dnn::ConvolutionKind::BACKWARD_DATA;
}
if (target == kCudnnConvBackwardFilterCallTarget) {
return se::dnn::ConvolutionKind::BACKWARD_FILTER;
}
return InternalError("Unexpected call target: %s", target);
}
StatusOr<se::dnn::DataType> GetDnnDataType(
const HloCustomCallInstruction* conv) {
PrimitiveType output_primitive_type =
conv->shape().tuple_shapes(0).element_type();
switch (output_primitive_type) {
case F16:
return se::dnn::ToDataType<Eigen::half>::value;
case F32:
return se::dnn::ToDataType<float>::value;
case F64:
return se::dnn::ToDataType<double>::value;
default:
break;
}
return InternalError("Unsupported convolution datatype : %s",
conv->ToString());
}
string CudnnConvKindToString(CudnnConvKind kind) {
switch (kind) {
case CudnnConvKind::kForward:

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
// TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they
// don't belong in "ir_emission_utils".
@ -53,6 +54,12 @@ enum class CudnnConvKind {
StatusOr<CudnnConvKind> GetCudnnConvKind(const HloCustomCallInstruction* instr);
StatusOr<se::dnn::ConvolutionKind> GetDnnConvolutionKind(
const HloCustomCallInstruction* instr);
StatusOr<se::dnn::DataType> GetDnnDataType(
const HloCustomCallInstruction* conv);
// Converts a CudnnConvKind value to a string.
string CudnnConvKindToString(CudnnConvKind kind);

View File

@ -20,10 +20,6 @@ load("@local_config_sycl//sycl:build_defs.bzl", "if_sycl")
load("//tensorflow:tensorflow.bzl", "if_nccl")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_tests")
load(
"//tensorflow/core/platform/default:cuda_build_defs.bzl",
"if_cuda_is_configured",
)
load(
"//tensorflow/core/platform:build_config.bzl",
"tf_kernel_tests_linkstatic",
@ -522,7 +518,7 @@ cc_library(
tf_cuda_library(
name = "gpu_utils",
srcs = if_cuda_is_configured(["gpu_utils.cc"]),
srcs = if_cuda_or_rocm(["gpu_utils.cc"]),
hdrs = ["gpu_utils.h"],
deps = [
":gpu_util_hdrs",

View File

@ -1033,28 +1033,66 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
CheckRedzones(rz_allocator, &result);
}
}
#elif TENSORFLOW_USE_ROCM
std::vector<ProfileResult> algorithms;
OP_REQUIRES(ctx,
stream->parent()->GetMIOpenConvolveAlgorithms(
se::dnn::ConvolutionKind::BACKWARD_FILTER, stream,
se::dnn::ToDataType<T>::value, input_desc, filter_desc,
conv_desc, output_desc, &algorithms),
errors::Unknown(
"Failed to get convolution algorithm. This is probably "
"because MIOpen failed to initialize, so try looking to "
"see if a warning log message was printed above."));
std::vector<tensorflow::AutotuneResult> results;
if (algorithms.size() == 1) {
auto profile_result = algorithms[0];
results.emplace_back();
auto& result = results.back();
result.mutable_conv()->set_algorithm(
profile_result.algorithm().algo_id());
result.mutable_conv()->set_tensor_ops_enabled(
profile_result.algorithm().tensor_ops_enabled());
result.set_scratch_bytes(profile_result.scratch_size());
*result.mutable_run_time() = proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
} else {
for (auto miopen_algorithm : algorithms) {
auto profile_algorithm = miopen_algorithm.algorithm();
DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
ctx);
ProfileResult profile_result;
bool miopen_launch_status = true;
miopen_launch_status =
stream
->ThenConvolveBackwardFilterWithAlgorithm(
input_desc, input_ptr, output_desc, out_backprop_ptr,
conv_desc, filter_desc, &filter_backprop_ptr,
&scratch_allocator, AlgorithmConfig(profile_algorithm),
&profile_result)
.ok();
if (miopen_launch_status && profile_result.is_valid()) {
results.emplace_back();
auto& result = results.back();
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
result.mutable_conv()->set_tensor_ops_enabled(
profile_algorithm.tensor_ops_enabled());
result.set_scratch_bytes(scratch_allocator.TotalByteSize());
*result.mutable_run_time() = proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
}
}
}
#endif
LogConvAutotuneResults(se::dnn::ConvolutionKind::BACKWARD_FILTER,
se::dnn::ToDataType<T>::value, input_ptr,
filter_backprop_ptr_rz, out_backprop_ptr, input_desc,
filter_backprop_ptr, out_backprop_ptr, input_desc,
filter_desc, output_desc, conv_desc,
stream->parent(), results);
OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
#elif TENSORFLOW_USE_ROCM
ProfileResult best_result;
DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
ctx);
bool miopen_find_status =
stream
->ThenConvolveBackwardFilterWithAlgorithm(
input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
filter_desc, &filter_backprop_ptr, &scratch_allocator,
AlgorithmConfig(), &best_result)
.ok();
OP_REQUIRES(ctx, miopen_find_status && best_result.is_valid(),
errors::NotFound("Failed to find backward filter algorithm!"));
algorithm_config.set_algorithm(best_result.algorithm());
algorithm_config.set_scratch_size(best_result.scratch_size());
#endif
AutoTuneConvBwdFilter::GetInstance()->Insert(conv_parameters,
algorithm_config);
}

View File

@ -1199,29 +1199,64 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
CheckRedzones(rz_allocator, &result);
}
}
#elif TENSORFLOW_USE_ROCM
std::vector<ProfileResult> algorithms;
OP_REQUIRES(ctx,
stream->parent()->GetMIOpenConvolveAlgorithms(
se::dnn::ConvolutionKind::BACKWARD_DATA, stream,
se::dnn::ToDataType<T>::value, input_desc, filter_desc,
conv_desc, output_desc, &algorithms),
errors::Unknown(
"Failed to get convolution algorithm. This is probably "
"because MIOpen failed to initialize, so try looking to "
"see if a warning log message was printed above."));
std::vector<tensorflow::AutotuneResult> results;
if (algorithms.size() == 1) {
auto profile_result = algorithms[0];
results.emplace_back();
auto& result = results.back();
result.mutable_conv()->set_algorithm(
profile_result.algorithm().algo_id());
result.mutable_conv()->set_tensor_ops_enabled(
profile_result.algorithm().tensor_ops_enabled());
result.set_scratch_bytes(profile_result.scratch_size());
*result.mutable_run_time() = proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
} else {
for (auto miopen_algorithm : algorithms) {
auto profile_algorithm = miopen_algorithm.algorithm();
DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
ctx);
ProfileResult profile_result;
bool miopen_launch_status = true;
miopen_launch_status =
stream
->ThenConvolveBackwardDataWithAlgorithm(
filter_desc, filter_ptr, output_desc, out_backprop_ptr,
conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
AlgorithmConfig(profile_algorithm), &profile_result)
.ok();
if (miopen_launch_status && profile_result.is_valid()) {
results.emplace_back();
auto& result = results.back();
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
result.mutable_conv()->set_tensor_ops_enabled(
profile_algorithm.tensor_ops_enabled());
result.set_scratch_bytes(scratch_allocator.TotalByteSize());
*result.mutable_run_time() = proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
}
}
}
#endif
LogConvAutotuneResults(
se::dnn::ConvolutionKind::BACKWARD_DATA, se::dnn::ToDataType<T>::value,
in_backprop_ptr, filter_ptr, out_backprop_ptr, input_desc, filter_desc,
output_desc, conv_desc, stream->parent(), results);
OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
#elif TENSORFLOW_USE_ROCM
// MIOpen has its own Find and autotuner so use it here, passing
// default AlgorithmConfig to force a search
DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, ctx);
ProfileResult best_result;
bool miopen_find_status =
stream
->ThenConvolveBackwardDataWithAlgorithm(
filter_desc, filter_ptr, output_desc, out_backprop_ptr,
conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
AlgorithmConfig(), &best_result)
.ok();
OP_REQUIRES(ctx, miopen_find_status && best_result.is_valid(),
errors::NotFound("Failed to find backwards-data algorithm!"));
algorithm_config.set_algorithm(best_result.algorithm());
algorithm_config.set_scratch_size(best_result.scratch_size());
#endif
AutoTuneConvBwdData::GetInstance()->Insert(conv_parameters,
algorithm_config);
}

View File

@ -1433,6 +1433,51 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
}
}
}
#elif TENSORFLOW_USE_ROCM
std::vector<ProfileResult> algorithms;
CHECK(stream->parent()->GetMIOpenConvolveAlgorithms(
se::dnn::ConvolutionKind::BACKWARD_DATA, stream,
se::dnn::ToDataType<T>::value, input_desc, filter_desc, conv_desc,
output_desc, &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
std::vector<tensorflow::AutotuneResult> results;
for (auto miopen_algorithm : algorithms) {
auto profile_algorithm = miopen_algorithm.algorithm();
DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
context);
ProfileResult profile_result;
bool miopen_launch_status =
stream
->ThenConvolveBackwardDataWithAlgorithm(
filter_desc, filter_ptr, output_desc, out_backprop_ptr,
conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
AlgorithmConfig(profile_algorithm), &profile_result)
.ok();
if (miopen_launch_status) {
if (profile_result.is_valid()) {
results.emplace_back();
auto& result = results.back();
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
result.mutable_conv()->set_tensor_ops_enabled(
profile_algorithm.tensor_ops_enabled());
result.set_scratch_bytes(scratch_allocator.TotalByteSize());
*result.mutable_run_time() = proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
if (profile_result.elapsed_time_in_ms() <
best_result.elapsed_time_in_ms()) {
best_result = profile_result;
}
if (scratch_allocator.TotalByteSize() == 0 &&
profile_result.elapsed_time_in_ms() <
best_result_no_scratch.elapsed_time_in_ms()) {
best_result_no_scratch = profile_result;
}
}
}
}
#endif
LogConvAutotuneResults(se::dnn::ConvolutionKind::BACKWARD_DATA,
se::dnn::ToDataType<T>::value, in_backprop_ptr,
filter_ptr, out_backprop_ptr, input_desc,
@ -1448,22 +1493,6 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
algorithm_config.set_algorithm_no_scratch(
best_result_no_scratch.algorithm());
}
#elif TENSORFLOW_USE_ROCM
DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
context);
ProfileResult best_result;
bool miopen_find_status =
stream
->ThenConvolveBackwardDataWithAlgorithm(
filter_desc, filter_ptr, output_desc, out_backprop_ptr,
conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
AlgorithmConfig(), &best_result)
.ok();
OP_REQUIRES(context, miopen_find_status && best_result.is_valid(),
errors::NotFound("Failed to find backward data algorithm!"));
algorithm_config.set_algorithm(best_result.algorithm());
algorithm_config.set_scratch_size(best_result.scratch_size());
#endif
AutoTuneConv3dBwdData::GetInstance()->Insert(conv_parameters,
algorithm_config);
}
@ -1864,6 +1893,46 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
}
}
}
#elif TENSORFLOW_USE_ROCM
std::vector<ProfileResult> algorithms;
CHECK(stream->parent()->GetMIOpenConvolveAlgorithms(
se::dnn::ConvolutionKind::BACKWARD_FILTER, stream,
se::dnn::ToDataType<T>::value, input_desc, filter_desc, conv_desc,
output_desc, &algorithms));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
if (algorithms.size() == 1) {
best_result = algorithms[0];
} else {
for (auto miopen_algorithm : algorithms) {
auto profile_algorithm = miopen_algorithm.algorithm();
DnnScratchAllocator scratch_allocator(
ConvolveBackwardFilterScratchSize, context);
ProfileResult profile_result;
bool cudnn_launch_status =
stream
->ThenConvolveBackwardFilterWithAlgorithm(
input_desc, input_ptr, output_desc, out_backprop_ptr,
conv_desc, filter_desc, &filter_backprop_ptr,
&scratch_allocator, AlgorithmConfig(profile_algorithm),
&profile_result)
.ok();
if (cudnn_launch_status) {
if (profile_result.is_valid()) {
if (profile_result.elapsed_time_in_ms() <
best_result.elapsed_time_in_ms()) {
best_result = profile_result;
}
if (scratch_allocator.TotalByteSize() == 0 &&
profile_result.elapsed_time_in_ms() <
best_result_no_scratch.elapsed_time_in_ms()) {
best_result_no_scratch = profile_result;
}
}
}
}
}
#endif
OP_REQUIRES(context,
best_result.is_valid() || best_result_no_scratch.is_valid(),
errors::NotFound("No algorithm worked!"));
@ -1874,23 +1943,6 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
algorithm_config.set_algorithm_no_scratch(
best_result_no_scratch.algorithm());
}
#elif TENSORFLOW_USE_ROCM
DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
context);
ProfileResult best_result;
bool miopen_find_status =
stream
->ThenConvolveBackwardFilterWithAlgorithm(
input_desc, input_ptr, output_desc, out_backprop_ptr,
conv_desc, filter_desc, &filter_backprop_ptr,
&scratch_allocator, AlgorithmConfig(), &best_result)
.ok();
OP_REQUIRES(
context, miopen_find_status && best_result.is_valid(),
errors::NotFound("Failed to find backward filter algorithm!"));
algorithm_config.set_algorithm(best_result.algorithm());
algorithm_config.set_scratch_size(best_result.scratch_size());
#endif
AutoTuneConv3dBwdFilter::GetInstance()->Insert(conv_parameters,
algorithm_config);
}

View File

@ -1039,28 +1039,65 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
CheckRedzones(rz_allocator, &result);
}
}
#elif TENSORFLOW_USE_ROCM
std::vector<ProfileResult> algorithms;
OP_REQUIRES(ctx,
stream->parent()->GetMIOpenConvolveAlgorithms(
se::dnn::ConvolutionKind::FORWARD, stream,
se::dnn::ToDataType<T>::value, input_desc, filter_desc,
conv_desc, output_desc, &algorithms),
errors::Unknown(
"Failed to get convolution algorithm. This is probably "
"because MIOpen failed to initialize, so try looking to "
"see if a warning log message was printed above."));
se::DeviceMemory<T> output_tensor = output_ptr;
std::vector<tensorflow::AutotuneResult> results;
if (algorithms.size() == 1) {
auto profile_result = algorithms[0];
results.emplace_back();
auto& result = results.back();
result.mutable_conv()->set_algorithm(
profile_result.algorithm().algo_id());
result.mutable_conv()->set_tensor_ops_enabled(
profile_result.algorithm().tensor_ops_enabled());
result.set_scratch_bytes(profile_result.scratch_size());
*result.mutable_run_time() = proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
} else {
for (auto miopen_algorithm : algorithms) {
auto profile_algorithm = miopen_algorithm.algorithm();
DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
ProfileResult profile_result;
bool miopen_launch_status = false;
miopen_launch_status =
stream
->ThenConvolveWithAlgorithm(
input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
output_desc, &output_ptr, &scratch_allocator,
AlgorithmConfig(profile_algorithm), &profile_result)
.ok();
if (miopen_launch_status && profile_result.is_valid()) {
results.emplace_back();
auto& result = results.back();
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
result.mutable_conv()->set_tensor_ops_enabled(
profile_algorithm.tensor_ops_enabled());
result.set_scratch_bytes(scratch_allocator.TotalByteSize());
*result.mutable_run_time() = proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
}
}
}
#endif
LogConvAutotuneResults(se::dnn::ConvolutionKind::FORWARD,
se::dnn::ToDataType<T>::value, input_ptr, filter_ptr,
output_tensor, input_desc, filter_desc, output_desc,
conv_desc, stream->parent(), results);
OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
#elif TENSORFLOW_USE_ROCM
DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
ProfileResult best_result;
bool miopen_find_status =
stream
->ThenConvolveWithAlgorithm(input_desc, input_ptr, filter_desc,
filter_ptr, conv_desc, output_desc,
&output_ptr, &scratch_allocator,
AlgorithmConfig(), &best_result)
.ok();
OP_REQUIRES(ctx, miopen_find_status && best_result.is_valid(),
errors::NotFound("Failed to find conv algorithm!"));
algorithm_config.set_algorithm(best_result.algorithm());
algorithm_config.set_scratch_size(best_result.scratch_size());
#endif
AutoTuneConv::GetInstance()->Insert(conv_parameters, algorithm_config);
}

View File

@ -504,26 +504,63 @@ struct LaunchConvOp<GPUDevice, T> {
}
}
}
#elif TENSORFLOW_USE_ROCM
std::vector<ProfileResult> algorithms;
OP_REQUIRES(ctx,
stream->parent()->GetMIOpenConvolveAlgorithms(
se::dnn::ConvolutionKind::FORWARD, stream,
se::dnn::ToDataType<T>::value, input_desc, filter_desc,
conv_desc, output_desc, &algorithms),
errors::Unknown(
"Failed to get convolution algorithm. This is probably "
"because MIOpen failed to initialize, so try looking to "
"see if a warning log message was printed above."));
std::vector<tensorflow::AutotuneResult> results;
if (algorithms.size() == 1) {
auto profile_result = algorithms[0];
results.emplace_back();
auto& result = results.back();
result.mutable_conv()->set_algorithm(
profile_result.algorithm().algo_id());
result.mutable_conv()->set_tensor_ops_enabled(
profile_result.algorithm().tensor_ops_enabled());
result.set_scratch_bytes(profile_result.scratch_size());
*result.mutable_run_time() = proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
} else {
for (auto miopen_algorithm : algorithms) {
auto profile_algorithm = miopen_algorithm.algorithm();
DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
ProfileResult profile_result;
bool miopen_launch_status =
stream
->ThenConvolveWithAlgorithm(
input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
output_desc, &output_ptr, &scratch_allocator,
AlgorithmConfig(profile_algorithm), &profile_result)
.ok();
if (miopen_launch_status) {
if (profile_result.is_valid()) {
results.emplace_back();
auto& result = results.back();
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
result.mutable_conv()->set_tensor_ops_enabled(
profile_algorithm.tensor_ops_enabled());
result.set_scratch_bytes(scratch_allocator.TotalByteSize());
*result.mutable_run_time() = proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
}
}
}
}
#endif
LogConvAutotuneResults(se::dnn::ConvolutionKind::FORWARD,
se::dnn::ToDataType<T>::value, input_ptr,
filter_ptr, output_ptr, input_desc, filter_desc,
output_desc, conv_desc, stream->parent(), results);
OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
#elif TENSORFLOW_USE_ROCM
ProfileResult best_result;
DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
bool miopen_find_status =
stream
->ThenConvolveWithAlgorithm(input_desc, input_ptr, filter_desc,
filter_ptr, conv_desc, output_desc,
&output_ptr, &scratch_allocator,
AlgorithmConfig(), &best_result)
.ok();
OP_REQUIRES(ctx, miopen_find_status && best_result.is_valid(),
errors::NotFound("Failed to find conv algorithm!"));
algorithm_config.set_algorithm(best_result.algorithm());
algorithm_config.set_scratch_size(best_result.scratch_size());
#endif
AutoTuneConv3d::GetInstance()->Insert(conv_parameters, algorithm_config);
}

View File

@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/core/kernels/gpu_utils.h"
#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include <iterator>
@ -249,4 +249,4 @@ Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
} // namespace tensorflow
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

View File

@ -41,6 +41,17 @@ bool DnnSupport::GetConvolveAlgorithms(
return false;
}
bool DnnSupport::GetMIOpenConvolveAlgorithms(
dnn::ConvolutionKind /*kind*/, Stream* /*stream*/,
dnn::DataType /*element_type*/,
const dnn::BatchDescriptor& /*input_descriptor*/,
const dnn::FilterDescriptor& /*filter_descriptor*/,
const dnn::ConvolutionDescriptor& /*convolution_descriptor*/,
const dnn::BatchDescriptor& /*output_descriptor*/,
std::vector<ProfileResult>* /*out_algorithms*/) {
return false;
}
bool DnnSupport::GetRnnAlgorithms(std::vector<AlgorithmDesc>* out_algorithms) {
return false;
}

View File

@ -1352,6 +1352,14 @@ class DnnSupport {
bool with_winograd_nonfused, int cc_major, int cc_minor,
std::vector<AlgorithmDesc>* out_algorithms);
virtual bool GetMIOpenConvolveAlgorithms(
dnn::ConvolutionKind kind, Stream* stream, dnn::DataType element_type,
const dnn::BatchDescriptor& input_descriptor,
const dnn::FilterDescriptor& filter_descriptor,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::BatchDescriptor& output_descriptor,
std::vector<ProfileResult>* out_algorithms);
// Returns a list of supported rnn algorithms.
virtual bool GetRnnAlgorithms(std::vector<AlgorithmDesc>* out_algorithms);

View File

@ -22,6 +22,7 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "rocm/include/miopen/miopen.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
#include "tensorflow/stream_executor/gpu/gpu_driver.h"
@ -53,6 +54,8 @@ NarrowT CheckedNarrowing(const WideT& wide) {
return narrow;
}
const int kImmediateModeVlogLevel = 3;
} // namespace
namespace stream_executor {
@ -91,6 +94,24 @@ string ToString(miopenStatus_t status) {
}
}
string ToString(miopenConvAlgorithm_t algorithm) {
string s;
switch (algorithm) {
case miopenConvolutionAlgoGEMM:
s = "GEMM";
break;
case miopenConvolutionAlgoDirect:
s = "Direct";
break;
case miopenConvolutionAlgoFFT:
s = "FFT";
break;
case miopenConvolutionAlgoWinograd:
s = "Winograd";
break;
}
return s;
}
// RAII wrapper for all calls to MIOpen with a MIOpen handle argument.
//
// See MIOpenAccess::GetHandle() for details.
@ -156,95 +177,110 @@ namespace wrap {
#endif
// clang-format off
#define MIOPEN_DNN_ROUTINE_EACH(__macro) \
__macro(miopenBatchNormalizationBackward) \
__macro(miopenBatchNormalizationForwardInference) \
__macro(miopenBatchNormalizationForwardTraining) \
__macro(miopenGetConvolutionForwardOutputDim) \
__macro(miopenGetConvolutionNdForwardOutputDim) \
__macro(miopenFindConvolutionForwardAlgorithm) \
__macro(miopenCreateTensorDescriptor) \
__macro(miopenDestroyTensorDescriptor) \
__macro(miopenSet2dPoolingDescriptor) \
__macro(miopenSetLRNDescriptor) \
__macro(miopenLRNGetWorkSpaceSize) \
__macro(miopenCreateConvolutionDescriptor) \
__macro(miopenCreatePoolingDescriptor) \
__macro(miopenDestroyPoolingDescriptor) \
__macro(miopenCreateLRNDescriptor) \
__macro(miopenDestroyLRNDescriptor) \
__macro(miopenDestroyConvolutionDescriptor) \
__macro(miopenCreateWithStream) \
__macro(miopenDestroy) \
__macro(miopenSetStream) \
__macro(miopenSetAllocator) \
__macro(miopenActivationForward) \
__macro(miopenConvolutionForward) \
__macro(miopenConvolutionBackwardBias) \
__macro(miopenConvolutionForwardGetWorkSpaceSize) \
__macro(miopenInitConvolutionDescriptor) \
__macro(miopenInitConvolutionNdDescriptor) \
__macro(miopenGetConvolutionDescriptor) \
__macro(miopenGetConvolutionNdDescriptor) \
__macro(miopenSetConvolutionGroupCount) \
__macro(miopenSet4dTensorDescriptor) \
__macro(miopenGetTensorDescriptor) \
__macro(miopenSetTensorDescriptor) \
__macro(miopenGetTensorDescriptorSize) \
__macro(miopenPoolingForward) \
__macro(miopenPoolingGetWorkSpaceSize) \
__macro(miopenPoolingBackward) \
__macro(miopenLRNForward) \
__macro(miopenLRNBackward) \
__macro(miopenOpTensor) \
__macro(miopenConvolutionBackwardData) \
__macro(miopenConvolutionBackwardWeights) \
__macro(miopenConvolutionBackwardWeightsGetWorkSpaceSize)\
__macro(miopenFindConvolutionBackwardDataAlgorithm) \
__macro(miopenFindConvolutionBackwardWeightsAlgorithm) \
__macro(miopenConvolutionBackwardDataGetWorkSpaceSize) \
__macro(miopenCreateRNNDescriptor) \
__macro(miopenSetRNNDescriptor) \
__macro(miopenDestroyRNNDescriptor) \
__macro(miopenGetRNNParamsSize) \
__macro(miopenGetRNNLayerParam) \
__macro(miopenGetRNNLayerBias) \
__macro(miopenGetRNNWorkspaceSize) \
__macro(miopenGetRNNTrainingReserveSize) \
__macro(miopenRNNForwardInference) \
__macro(miopenRNNForwardTraining) \
__macro(miopenRNNBackwardData) \
__macro(miopenRNNBackwardWeights) \
__macro(miopenGetRNNLayerParamOffset) \
__macro(miopenGetRNNLayerParamSize) \
__macro(miopenGetRNNLayerBiasOffset) \
__macro(miopenGetRNNLayerBiasSize) \
__macro(miopenGetRNNParamsDescriptor) \
__macro(miopenCreateActivationDescriptor) \
__macro(miopenSetActivationDescriptor) \
__macro(miopenGetActivationDescriptor) \
__macro(miopenDestroyActivationDescriptor) \
__macro(miopenCreateFusionPlan) \
__macro(miopenCreateOpConvForward) \
__macro(miopenCreateOpBiasForward) \
__macro(miopenCreateOpActivationForward) \
__macro(miopenCreateOpActivationBackward) \
__macro(miopenCreateOpBatchNormInference) \
__macro(miopenCreateOpBatchNormForward) \
__macro(miopenCreateOpBatchNormBackward) \
__macro(miopenCompileFusionPlan) \
__macro(miopenFusionPlanGetOp) \
__macro(miopenCreateOperatorArgs) \
__macro(miopenSetOpArgsConvForward) \
__macro(miopenSetOpArgsBiasForward) \
__macro(miopenSetOpArgsActivForward) \
__macro(miopenSetOpArgsActivBackward) \
__macro(miopenSetOpArgsBatchNormInference) \
__macro(miopenSetOpArgsBatchNormForward) \
__macro(miopenSetOpArgsBatchNormBackward) \
__macro(miopenExecuteFusionPlan) \
__macro(miopenDestroyOperatorArgs) \
__macro(miopenDestroyFusionPlan)
#define MIOPEN_DNN_ROUTINE_EACH(__macro) \
__macro(miopenBatchNormalizationBackward) \
__macro(miopenBatchNormalizationForwardInference) \
__macro(miopenBatchNormalizationForwardTraining) \
__macro(miopenGetConvolutionForwardOutputDim) \
__macro(miopenGetConvolutionNdForwardOutputDim) \
__macro(miopenFindConvolutionForwardAlgorithm) \
__macro(miopenCreateTensorDescriptor) \
__macro(miopenDestroyTensorDescriptor) \
__macro(miopenSet2dPoolingDescriptor) \
__macro(miopenSetLRNDescriptor) \
__macro(miopenLRNGetWorkSpaceSize) \
__macro(miopenCreateConvolutionDescriptor) \
__macro(miopenCreatePoolingDescriptor) \
__macro(miopenDestroyPoolingDescriptor) \
__macro(miopenCreateLRNDescriptor) \
__macro(miopenDestroyLRNDescriptor) \
__macro(miopenDestroyConvolutionDescriptor) \
__macro(miopenCreateWithStream) \
__macro(miopenDestroy) \
__macro(miopenSetStream) \
__macro(miopenSetAllocator) \
__macro(miopenActivationForward) \
__macro(miopenConvolutionForward) \
__macro(miopenConvolutionBackwardBias) \
__macro(miopenConvolutionForwardGetWorkSpaceSize) \
__macro(miopenInitConvolutionDescriptor) \
__macro(miopenInitConvolutionNdDescriptor) \
__macro(miopenGetConvolutionDescriptor) \
__macro(miopenGetConvolutionNdDescriptor) \
__macro(miopenSetConvolutionGroupCount) \
__macro(miopenSet4dTensorDescriptor) \
__macro(miopenGetTensorDescriptor) \
__macro(miopenSetTensorDescriptor) \
__macro(miopenGetTensorDescriptorSize) \
__macro(miopenPoolingForward) \
__macro(miopenPoolingGetWorkSpaceSize) \
__macro(miopenPoolingBackward) \
__macro(miopenLRNForward) \
__macro(miopenLRNBackward) \
__macro(miopenOpTensor) \
__macro(miopenConvolutionBackwardData) \
__macro(miopenConvolutionBackwardWeights) \
__macro(miopenConvolutionBackwardWeightsGetWorkSpaceSize) \
__macro(miopenFindConvolutionBackwardDataAlgorithm) \
__macro(miopenFindConvolutionBackwardWeightsAlgorithm) \
__macro(miopenConvolutionBackwardDataGetWorkSpaceSize) \
__macro(miopenCreateRNNDescriptor) \
__macro(miopenSetRNNDescriptor) \
__macro(miopenDestroyRNNDescriptor) \
__macro(miopenGetRNNParamsSize) \
__macro(miopenGetRNNLayerParam) \
__macro(miopenGetRNNLayerBias) \
__macro(miopenGetRNNWorkspaceSize) \
__macro(miopenGetRNNTrainingReserveSize) \
__macro(miopenRNNForwardInference) \
__macro(miopenRNNForwardTraining) \
__macro(miopenRNNBackwardData) \
__macro(miopenRNNBackwardWeights) \
__macro(miopenGetRNNLayerParamOffset) \
__macro(miopenGetRNNLayerParamSize) \
__macro(miopenGetRNNLayerBiasOffset) \
__macro(miopenGetRNNLayerBiasSize) \
__macro(miopenGetRNNParamsDescriptor) \
__macro(miopenCreateActivationDescriptor) \
__macro(miopenSetActivationDescriptor) \
__macro(miopenGetActivationDescriptor) \
__macro(miopenDestroyActivationDescriptor) \
__macro(miopenCreateFusionPlan) \
__macro(miopenCreateOpConvForward) \
__macro(miopenCreateOpBiasForward) \
__macro(miopenCreateOpActivationForward) \
__macro(miopenCreateOpActivationBackward) \
__macro(miopenCreateOpBatchNormInference) \
__macro(miopenCreateOpBatchNormForward) \
__macro(miopenCreateOpBatchNormBackward) \
__macro(miopenCompileFusionPlan) \
__macro(miopenFusionPlanGetOp) \
__macro(miopenCreateOperatorArgs) \
__macro(miopenSetOpArgsConvForward) \
__macro(miopenSetOpArgsBiasForward) \
__macro(miopenSetOpArgsActivForward) \
__macro(miopenSetOpArgsActivBackward) \
__macro(miopenSetOpArgsBatchNormInference) \
__macro(miopenSetOpArgsBatchNormForward) \
__macro(miopenSetOpArgsBatchNormBackward) \
__macro(miopenExecuteFusionPlan) \
__macro(miopenDestroyOperatorArgs) \
__macro(miopenDestroyFusionPlan) \
__macro(miopenConvolutionForwardGetSolutionCount) \
__macro(miopenConvolutionForwardGetSolution) \
__macro(miopenConvolutionForwardGetSolutionWorkspaceSize) \
__macro(miopenConvolutionForwardCompileSolution) \
__macro(miopenConvolutionForwardImmediate) \
__macro(miopenConvolutionBackwardDataGetSolutionCount) \
__macro(miopenConvolutionBackwardDataGetSolution) \
__macro(miopenConvolutionBackwardDataGetSolutionWorkspaceSize) \
__macro(miopenConvolutionBackwardDataCompileSolution) \
__macro(miopenConvolutionBackwardDataImmediate) \
__macro(miopenConvolutionBackwardWeightsGetSolutionCount) \
__macro(miopenConvolutionBackwardWeightsGetSolution) \
__macro(miopenConvolutionBackwardWeightsGetSolutionWorkspaceSize) \
__macro(miopenConvolutionBackwardWeightsCompileSolution) \
__macro(miopenConvolutionBackwardWeightsImmediate)
// clang-format on
@ -389,6 +425,15 @@ absl::Mutex CachedFusionPlans::cached_plans_mutex;
std::map<uint64, miopenFusionPlanDescriptor_t> CachedFusionPlans::cached_plans;
std::set<uint64> CachedFusionPlans::unsupported_plans;
dnn::ProfileResult GetProfileResultFromConvSolution(
miopenConvSolution_t solution) {
dnn::ProfileResult profile_result;
profile_result.set_algorithm({solution.solution_id, false});
profile_result.set_elapsed_time_in_ms(solution.time);
profile_result.set_scratch_size(solution.workspace_size);
return profile_result;
}
} // namespace
namespace {
@ -2617,126 +2662,73 @@ port::Status MIOpenSupport::DoPrepareForConvolution(
auto miopen = miopen_->GetHandle(parent_, stream);
absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
size_t scratch_memory_size;
absl::optional<dnn::AlgorithmDesc> input_algo_desc =
algorithm_config.algorithm();
if (!algo_desc.has_value()) {
// With the default algorithm, use MIOpen's heuristics.
assert(scratch_allocator);
assert(input_algo_desc.has_value());
DeviceMemory<uint8> scratch_memory_temp;
MIOpenAllocatorContext mac(scratch_allocator, stream);
wrap::miopenSetAllocator(miopen.handle(), MIOpenAllocatorCallback,
MIOpenDeallocatorCallback, &mac);
size_t size_in_bytes;
miopenStatus_t status = miopenStatusSuccess;
// An algorithm has been specified.
*algorithm_desc = *input_algo_desc;
switch (kind) {
case dnn::ConvolutionKind::FORWARD: {
status = wrap::miopenConvolutionForwardGetWorkSpaceSize(
miopen.handle(), /*filterDesc=*/filter.handle(),
/*srcDesc=*/input_nd.handle(), /*convDesc=*/conv.handle(),
/*destDesc=*/output_nd.handle(), /*sizeInBytes=*/&size_in_bytes);
break;
const uint64_t solution_id = algorithm_desc->algo_id();
size_t scratch_memory_size = 0;
switch (kind) {
case dnn::ConvolutionKind::FORWARD: {
auto status = wrap::miopenConvolutionForwardGetSolutionWorkspaceSize(
miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(),
output_nd.handle(), solution_id, &scratch_memory_size);
if (status != miopenStatusSuccess) {
return port::InternalError(absl::StrCat(
"call to miopenConvolutionForwardGetSolutionWorkspaceSize "
"failed: ",
ToString(status)));
}
case dnn::ConvolutionKind::BACKWARD_DATA: {
status = wrap::miopenConvolutionBackwardDataGetWorkSpaceSize(
miopen.handle(), /*diffDesc=*/output_nd.handle(),
/*filterDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
/*gradDesc=*/input_nd.handle(), /*sizeInBytes=*/&size_in_bytes);
break;
}
case dnn::ConvolutionKind::BACKWARD_FILTER: {
status = wrap::miopenConvolutionBackwardWeightsGetWorkSpaceSize(
miopen.handle(), /*diffDesc=*/output_nd.handle(),
/*srcDesc=*/input_nd.handle(), /*convDesc=*/conv.handle(),
/*gradDesc=*/filter.handle(), /*sizeInBytes=*/&size_in_bytes);
break;
}
default:
return port::InternalError(absl::StrCat("Unexpected convolution kind ",
static_cast<int>(kind)));
break;
}
if (status == miopenStatusSuccess && size_in_bytes != 0) {
auto allocated = scratch_allocator->AllocateBytes(size_in_bytes);
if (allocated.ok()) {
scratch_memory_temp = allocated.ValueOrDie();
case dnn::ConvolutionKind::BACKWARD_DATA: {
auto status = wrap::miopenConvolutionBackwardDataGetSolutionWorkspaceSize(
miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
input_nd.handle(), solution_id, &scratch_memory_size);
if (status != miopenStatusSuccess) {
return port::InternalError(absl::StrCat(
"call to miopenConvolutionabckwardDataGetSolutionWorkspaceSize "
"failed: ",
ToString(status)));
}
break;
}
miopenConvAlgoPerf_t preference;
int returnedAlgoCount;
case dnn::ConvolutionKind::BACKWARD_FILTER: {
auto status =
wrap::miopenConvolutionBackwardWeightsGetSolutionWorkspaceSize(
miopen.handle(), output_nd.handle(), input_nd.handle(),
conv.handle(), filter.handle(), solution_id,
&scratch_memory_size);
switch (kind) {
case dnn::ConvolutionKind::FORWARD: {
auto status = wrap::miopenFindConvolutionForwardAlgorithm(
miopen.handle(), input_nd.handle(), input_data.opaque(),
filter.handle(), filter_data.opaque(), conv.handle(),
output_nd.handle(), output_data.opaque(),
/*requestAlgoCount=*/1, &returnedAlgoCount,
/*preference=*/&preference,
/*workspace*/ scratch_memory_temp.opaque(),
/*WorkSpaceSize*/ scratch_memory_temp.size(),
/*exhaustiveSearch*/ false);
CHECK_EQ(status, miopenStatusSuccess) << "Unable to find a suitable "
"algorithm for doing forward "
"convolution";
*algorithm_desc = dnn::AlgorithmDesc(preference.fwd_algo, false);
break;
if (status != miopenStatusSuccess) {
return port::InternalError(absl::StrCat(
"call to miopenConvolutionabckwardWeightsGetSolutionWorkspaceSize "
"failed: ",
ToString(status)));
}
case dnn::ConvolutionKind::BACKWARD_DATA: {
auto status = wrap::miopenFindConvolutionBackwardDataAlgorithm(
miopen.handle(),
/*diffDesc=*/output_nd.handle(), output_data.opaque(),
/*filterDesc=*/filter.handle(), filter_data.opaque(),
/*convDesc=*/conv.handle(),
/*gradDesc=*/input_nd.handle(), input_data.opaque(),
/*requestCount=*/1, /*returnedAlgoCount=*/&returnedAlgoCount,
/*preference=*/&preference,
/*WorkSpace=*/scratch_memory_temp.opaque(),
/*WorkSpaceSize=*/scratch_memory_temp.size(),
/*exhaustiveSearch=*/false);
CHECK_EQ(status, miopenStatusSuccess) << "Unable to find a suitable "
"algorithm for doing backward "
"data convolution";
*algorithm_desc = dnn::AlgorithmDesc(preference.bwd_data_algo, false);
break;
}
case dnn::ConvolutionKind::BACKWARD_FILTER: {
auto status = wrap::miopenFindConvolutionBackwardWeightsAlgorithm(
miopen.handle(),
/*diffDesc=*/output_nd.handle(), output_data.opaque(),
/*srcDesc=*/input_nd.handle(), input_data.opaque(),
/*convDesc=*/conv.handle(),
/*gradDesc=*/filter.handle(), filter_data.opaque(),
/*requestAlgoCount=*/1, /*returnedAlgoCount=*/&returnedAlgoCount,
/*preference=*/&preference,
/*WorkSpace=*/scratch_memory_temp.opaque(),
/*WorkSpaceSize=*/scratch_memory_temp.size(),
/*exhaustiveSearch=*/false);
CHECK_EQ(status, miopenStatusSuccess) << "Unable to find a suitable "
"algorithm for doing backward "
"filter convolution";
*algorithm_desc =
dnn::AlgorithmDesc(preference.bwd_weights_algo, false);
break;
}
default:
return port::InternalError(absl::StrCat("Unexpected convolution kind ",
static_cast<int>(kind)));
break;
}
// Restore default allocator, note mac is stack temp
wrap::miopenSetAllocator(miopen.handle(), nullptr, nullptr, nullptr);
scratch_memory_size = preference.memory;
} else {
// An algorithm has been specified.
*algorithm_desc = *algo_desc;
scratch_memory_size = *(algorithm_config.scratch_size());
default: {
return port::InternalError(
absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
break;
}
}
VLOG(2) << "miopen...GetSolutionWorkspaceSize returned "
<< scratch_memory_size << " for solution_id " << solution_id;
// allocate scratch memory
if (scratch_memory_size != 0) {
if (scratch_allocator == nullptr) {
@ -2745,12 +2737,18 @@ port::Status MIOpenSupport::DoPrepareForConvolution(
"needed"));
}
auto allocated = scratch_allocator->AllocateBytes(scratch_memory_size);
if (!allocated.ok()) {
return port::InternalError(absl::StrCat(
"Failed to allocate scratch memory of size: ", scratch_memory_size));
}
if (allocated.ok()) {
*scratch_memory = allocated.ValueOrDie();
} else {
LOG(ERROR)
<< "Failed to allocate scratch memory - "
<< allocated.status().error_message() << "\n"
<< "\tYou can set the env var TF_CUDNN_WORKSPACE_LIMIT_IN_MB to a "
"larger number (e.g. 8192) to increase the max memory limit.\n"
<< "\tIncreasing the max memory limit might help resolve this "
"error";
return port::InternalError(absl::StrCat(
"Failed to allocate scratch memory of size: ", scratch_memory_size));
}
}
@ -2846,20 +2844,17 @@ port::Status MIOpenSupport::DoConvolve(
}
}
const uint64_t solution_id = algorithm_desc.algo_id();
miopenStatus_t status = miopenStatusSuccess;
switch (kind) {
case dnn::ConvolutionKind::FORWARD: {
status = wrap::miopenConvolutionForward(
miopen.handle(),
/*alpha=*/&alpha, /*srcDesc=*/input_nd.handle(),
/*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(),
/*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
/*algo=*/
static_cast<miopenConvFwdAlgorithm_t>(algorithm_desc.algo_id()),
/*beta=*/&beta, /*destDesc=*/output_nd.handle(),
/*destData=*/output_data.opaque(),
/*workSpace=*/scratch_memory.opaque(),
/*workSpaceSizeInBytes=*/scratch_memory.size());
status = wrap::miopenConvolutionForwardImmediate(
miopen.handle(), filter.handle(), filter_data.opaque(),
input_nd.handle(), input_data.opaque(), conv.handle(),
output_nd.handle(), output_data.opaque(), scratch_memory.opaque(),
scratch_memory.size(), solution_id);
break;
}
case dnn::ConvolutionKind::BACKWARD_DATA: {
@ -2871,21 +2866,11 @@ port::Status MIOpenSupport::DoConvolve(
stream, miopen.handle(), ToMIOpenDataType(element_type),
&output_back_descriptor, output_data, &transform_scratch);
status = wrap::miopenConvolutionBackwardData(
miopen.handle(),
/*alpha=*/&alpha,
/*diffDesc=*/output_nd.handle(),
/*diffData=*/output_data.opaque(),
/*filterDesc=*/filter.handle(),
/*filterData=*/filter_data.opaque(),
/*convDesc=*/conv.handle(),
/*algo=*/
static_cast<miopenConvBwdDataAlgorithm_t>(algorithm_desc.algo_id()),
/*beta=*/&beta,
/*gradDesc=*/input_nd.handle(),
/*gradData=*/input_data.opaque(),
/*workSpace=*/scratch_memory.opaque(),
/*workSpaceSizeInBytes=*/scratch_memory.size());
status = wrap::miopenConvolutionBackwardDataImmediate(
miopen.handle(), output_nd.handle(), output_data.opaque(),
filter.handle(), filter_data.opaque(), conv.handle(),
input_nd.handle(), input_data.opaque(), scratch_memory.opaque(),
scratch_memory.size(), solution_id);
break;
}
case dnn::ConvolutionKind::BACKWARD_FILTER: {
@ -2897,22 +2882,11 @@ port::Status MIOpenSupport::DoConvolve(
stream, miopen.handle(), ToMIOpenDataType(element_type),
&output_back_descriptor, output_data, &transform_scratch);
status = wrap::miopenConvolutionBackwardWeights(
miopen.handle(),
/*alpha=*/&alpha,
/*diffDesc=*/output_nd.handle(),
/*diffData=*/output_data.opaque(),
/*srcDesc=*/input_nd.handle(),
/*srcData=*/input_data.opaque(),
/*convDesc=*/conv.handle(),
/*algo=*/
static_cast<miopenConvBwdWeightsAlgorithm_t>(
algorithm_desc.algo_id()),
/*beta=*/&beta,
/*gradDesc=*/filter.handle(),
/*gradData=*/filter_data.opaque(),
/*workSpace=*/scratch_memory.opaque(),
/*workSpaceSizeInBytes=*/scratch_memory.size());
status = wrap::miopenConvolutionBackwardWeightsImmediate(
miopen.handle(), output_nd.handle(), output_data.opaque(),
input_nd.handle(), input_data.opaque(), conv.handle(),
filter.handle(), filter_data.opaque(), scratch_memory.opaque(),
scratch_memory.size(), solution_id);
break;
}
default:
@ -2958,6 +2932,312 @@ bool MIOpenSupport::GetConvolveAlgorithms(
return true;
}
bool MIOpenSupport::GetMIOpenConvolveAlgorithms(
dnn::ConvolutionKind kind, Stream* stream, dnn::DataType element_type,
const dnn::BatchDescriptor& input_descriptor,
const dnn::FilterDescriptor& filter_descriptor,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::BatchDescriptor& output_descriptor,
std::vector<dnn::ProfileResult>* out_algorithms) {
auto miopen = miopen_->GetHandle(parent_, stream);
ScopedTensorDescriptor input_nd{input_descriptor,
ToMIOpenDataType(element_type)};
ScopedTensorDescriptor output_nd{output_descriptor,
ToMIOpenDataType(element_type)};
ScopedFilterDescriptor filter{filter_descriptor, input_descriptor,
ToMIOpenDataType(element_type)};
ScopedConvolutionDescriptor conv{convolution_descriptor,
ToMIOpenDataType(element_type)};
// First determine the number of algorityhms available
size_t maxSolutionCount = 0;
switch (kind) {
case dnn::ConvolutionKind::FORWARD: {
auto status = wrap::miopenConvolutionForwardGetSolutionCount(
miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(),
output_nd.handle(), &maxSolutionCount);
if (status != miopenStatusSuccess) {
LOG(FATAL)
<< "call to miopenConvolutionForwardGetSolutionCount failed: "
<< ToString(status);
return false;
}
break;
}
case dnn::ConvolutionKind::BACKWARD_DATA: {
auto status = wrap::miopenConvolutionBackwardDataGetSolutionCount(
miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
input_nd.handle(), &maxSolutionCount);
if (status != miopenStatusSuccess) {
LOG(FATAL)
<< "call to miopenConvolutionBackwardDataGetSolutionCount failed: "
<< ToString(status);
return false;
}
break;
}
case dnn::ConvolutionKind::BACKWARD_FILTER: {
auto status = wrap::miopenConvolutionBackwardWeightsGetSolutionCount(
miopen.handle(), output_nd.handle(), input_nd.handle(), conv.handle(),
filter.handle(), &maxSolutionCount);
if (status != miopenStatusSuccess) {
LOG(FATAL)
<< "call to miopenConvolutionBackwardWeightsGetSolutionCount "
"failed: "
<< ToString(status);
return false;
}
break;
}
default: {
LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
return false;
break;
}
}
VLOG(kImmediateModeVlogLevel)
<< "Number of conv solutions max: " << maxSolutionCount;
// if the env var TF_ROCM_MIMIC_FIND_MODE is set, determine the best solution
// as per the "runtime" information for each solution (returned by the prior
// call to the *GetSolution api), and then return only the best solution
// The idea here is to mimic the old "find" mode, in which we relied upon
// the miopen api to determine the best solution, and use that solution
// without doing any further measurement in the TF layer
bool mimic_find_mode = false;
tensorflow::ReadBoolFromEnvVar("TF_ROCM_MIMIC_FIND_MODE", false,
&mimic_find_mode);
size_t solutionCount = 0;
std::unique_ptr<miopenConvSolution_t[]> solutions(
new miopenConvSolution_t[maxSolutionCount]);
switch (kind) {
case dnn::ConvolutionKind::FORWARD: {
auto status = wrap::miopenConvolutionForwardGetSolution(
miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(),
output_nd.handle(), maxSolutionCount, &solutionCount,
solutions.get());
if (status != miopenStatusSuccess) {
LOG(FATAL) << "call to miopenConvolutionForwardGetSolution failed: "
<< ToString(status);
return false;
}
VLOG(kImmediateModeVlogLevel)
<< "Number of conv solutions actual: " << solutionCount;
if (mimic_find_mode) {
miopenConvSolution_t best_solution = solutions[0];
for (int i = 1; i < solutionCount; i++) {
miopenConvSolution_t solution = solutions[i];
if (solution.time < best_solution.time) {
best_solution = solution;
}
}
VLOG(kImmediateModeVlogLevel)
<< "Best Solution (id, algo) = " << best_solution.solution_id
<< ", " << ToString(best_solution.algorithm);
status = wrap::miopenConvolutionForwardCompileSolution(
miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(),
output_nd.handle(), best_solution.solution_id);
if (status != miopenStatusSuccess) {
LOG(FATAL) << "call to miopenConvolutionForwardCompileSolution "
"failed: "
<< ToString(status);
return false;
}
out_algorithms->emplace_back(
GetProfileResultFromConvSolution(best_solution));
} else {
for (int i = 0; i < solutionCount; i++) {
miopenConvSolution_t solution = solutions[i];
VLOG(kImmediateModeVlogLevel)
<< "solution " << i
<< " (time, mem, id, algo) = " << solution.time << ", "
<< solution.workspace_size << ", " << solution.solution_id << ", "
<< ToString(solution.algorithm);
status = wrap::miopenConvolutionForwardCompileSolution(
miopen.handle(), filter.handle(), input_nd.handle(),
conv.handle(), output_nd.handle(), solution.solution_id);
if (status != miopenStatusSuccess) {
LOG(FATAL)
<< "call to miopenConvolutionForwardCompileSolution failed: "
<< ToString(status);
return false;
}
out_algorithms->emplace_back(
GetProfileResultFromConvSolution(solution));
}
}
break;
}
case dnn::ConvolutionKind::BACKWARD_DATA: {
auto status = wrap::miopenConvolutionBackwardDataGetSolution(
miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
input_nd.handle(), maxSolutionCount, &solutionCount, solutions.get());
if (status != miopenStatusSuccess) {
LOG(FATAL)
<< "call to miopenConvolutionBackwardDataGetSolution failed: "
<< ToString(status);
return false;
}
VLOG(kImmediateModeVlogLevel)
<< "Number of conv solutions actual: " << solutionCount;
if (mimic_find_mode) {
miopenConvSolution_t best_solution = solutions[0];
for (int i = 1; i < solutionCount; i++) {
miopenConvSolution_t solution = solutions[i];
if (solution.time < best_solution.time) {
best_solution = solution;
}
}
VLOG(kImmediateModeVlogLevel)
<< "Best Solution (id, algo) = " << best_solution.solution_id
<< ", " << ToString(best_solution.algorithm);
status = wrap::miopenConvolutionBackwardDataCompileSolution(
miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
input_nd.handle(), best_solution.solution_id);
if (status != miopenStatusSuccess) {
LOG(FATAL) << "call to miopenConvolutionBackwardDataCompileSolution "
"failed: "
<< ToString(status);
return false;
}
out_algorithms->emplace_back(
GetProfileResultFromConvSolution(best_solution));
} else {
for (int i = 0; i < solutionCount; i++) {
miopenConvSolution_t solution = solutions[i];
VLOG(kImmediateModeVlogLevel)
<< "solution " << i
<< " (time, mem, id, algo) = " << solution.time << ", "
<< solution.workspace_size << ", " << solution.solution_id << ", "
<< ToString(solution.algorithm);
status = wrap::miopenConvolutionBackwardDataCompileSolution(
miopen.handle(), output_nd.handle(), filter.handle(),
conv.handle(), input_nd.handle(), solution.solution_id);
if (status != miopenStatusSuccess) {
LOG(FATAL)
<< " call to miopenConvolutionBackwardDataCompileSolution "
"failed: "
<< ToString(status);
return false;
}
out_algorithms->emplace_back(
GetProfileResultFromConvSolution(solution));
}
}
break;
}
case dnn::ConvolutionKind::BACKWARD_FILTER: {
auto status = wrap::miopenConvolutionBackwardWeightsGetSolution(
miopen.handle(), output_nd.handle(), input_nd.handle(), conv.handle(),
filter.handle(), maxSolutionCount, &solutionCount, solutions.get());
if (status != miopenStatusSuccess) {
LOG(FATAL)
<< "call to miopenConvolutionBackwardWeightsGetSolution failed: "
<< ToString(status);
return false;
}
VLOG(kImmediateModeVlogLevel)
<< "Number of conv solutions actual: " << solutionCount;
if (mimic_find_mode) {
miopenConvSolution_t best_solution = solutions[0];
for (int i = 1; i < solutionCount; i++) {
miopenConvSolution_t solution = solutions[i];
if (solution.time < best_solution.time) {
best_solution = solution;
}
}
VLOG(kImmediateModeVlogLevel)
<< "Best Solution (id, algo) = " << best_solution.solution_id
<< ", " << ToString(best_solution.algorithm);
status = wrap::miopenConvolutionBackwardWeightsCompileSolution(
miopen.handle(), output_nd.handle(), input_nd.handle(),
conv.handle(), filter.handle(), best_solution.solution_id);
if (status != miopenStatusSuccess) {
LOG(FATAL)
<< "call to miopenConvolutionBackwardWeightsCompileSolution "
"failed: "
<< ToString(status);
return false;
}
out_algorithms->emplace_back(
GetProfileResultFromConvSolution(best_solution));
} else {
for (int i = 0; i < solutionCount; i++) {
miopenConvSolution_t solution = solutions[i];
VLOG(kImmediateModeVlogLevel)
<< "solution " << i
<< " (time, mem, id, algo) = " << solution.time << ", "
<< solution.workspace_size << ", " << solution.solution_id << ", "
<< ToString(solution.algorithm);
status = wrap::miopenConvolutionBackwardWeightsCompileSolution(
miopen.handle(), output_nd.handle(), input_nd.handle(),
conv.handle(), filter.handle(), solution.solution_id);
if (status != miopenStatusSuccess) {
LOG(FATAL)
<< "call to miopenConvolutionBackwardWeightsCompileSolution "
"failed: "
<< ToString(status);
return false;
}
out_algorithms->emplace_back(
GetProfileResultFromConvSolution(solution));
}
}
break;
}
default: {
LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
return false;
break;
}
}
return true;
}
bool MIOpenSupport::GetRnnAlgorithms(
std::vector<dnn::AlgorithmDesc>* out_algorithms) {
// ROCM TODO: implement this with proper MIOpen API

View File

@ -195,6 +195,14 @@ class MIOpenSupport : public dnn::DnnSupport {
bool with_winograd_nonfused, int cc_major, int cc_minor,
std::vector<dnn::AlgorithmDesc>* out_algorithms) override;
bool GetMIOpenConvolveAlgorithms(
dnn::ConvolutionKind kind, Stream* stream, dnn::DataType element_type,
const dnn::BatchDescriptor& input_descriptor,
const dnn::FilterDescriptor& filter_descriptor,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::BatchDescriptor& output_descriptor,
std::vector<dnn::ProfileResult>* out_algorithms) override;
bool GetRnnAlgorithms(
std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

View File

@ -290,6 +290,22 @@ bool StreamExecutor::GetConvolveAlgorithms(
cc_minor, out_algorithms);
}
bool StreamExecutor::GetMIOpenConvolveAlgorithms(
dnn::ConvolutionKind kind, Stream *stream, dnn::DataType element_type,
const dnn::BatchDescriptor &input_descriptor,
const dnn::FilterDescriptor &filter_descriptor,
const dnn::ConvolutionDescriptor &convolution_descriptor,
const dnn::BatchDescriptor &output_descriptor,
std::vector<dnn::ProfileResult> *out_algorithms) {
dnn::DnnSupport *dnn_support = AsDnn();
if (!dnn_support) {
return false;
}
return dnn_support->GetMIOpenConvolveAlgorithms(
kind, stream, element_type, input_descriptor, filter_descriptor,
convolution_descriptor, output_descriptor, out_algorithms);
}
bool StreamExecutor::GetRnnAlgorithms(
std::vector<dnn::AlgorithmDesc> *out_algorithms) {
dnn::DnnSupport *dnn_support = AsDnn();

View File

@ -372,6 +372,16 @@ class StreamExecutor {
bool GetConvolveAlgorithms(bool with_winograd_nonfused,
std::vector<dnn::AlgorithmDesc> *out_algorithms);
// Returns the list of supported algorithms for the forward convolution
// operation.
bool GetMIOpenConvolveAlgorithms(
dnn::ConvolutionKind kind, Stream *stream, dnn::DataType element_type,
const dnn::BatchDescriptor &input_descriptor,
const dnn::FilterDescriptor &filter_descriptor,
const dnn::ConvolutionDescriptor &convolution_descriptor,
const dnn::BatchDescriptor &output_descriptor,
std::vector<dnn::ProfileResult> *out_algorithms);
// Returns the list of supported algorithms for rnn operation.
bool GetRnnAlgorithms(std::vector<dnn::AlgorithmDesc> *out_algorithms);