Merge pull request #35503 from ROCmSoftwarePlatform:google_upstream_rocm_miopen_immediate_mode
PiperOrigin-RevId: 289053613
Change-Id: I233d95adc3aa888460bd39a07fd7e168fea14846
commit c1971ab97c
@ -593,6 +593,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"@com_google_absl//absl/algorithm:container",
"@llvm-project//llvm:core",
],
@ -117,6 +117,29 @@ std::vector<AlgorithmDesc> GetAlgorithms(CudnnConvKind kind,
return algorithms;
}

StatusOr<std::vector<se::dnn::ProfileResult>> GetAlgorithms(
const HloCustomCallInstruction* conv,
absl::Span<se::DeviceMemoryBase> operand_buffers,
se::DeviceMemoryBase result_buffer, se::StreamExecutor* stream_exec,
se::Stream* stream) {
std::vector<se::dnn::ProfileResult> algorithms;

TF_ASSIGN_OR_RETURN(se::dnn::ConvolutionKind kind,
GetDnnConvolutionKind(conv));

TF_ASSIGN_OR_RETURN(se::dnn::DataType dtype, GetDnnDataType(conv));

TF_ASSIGN_OR_RETURN(GpuConvParams params,
GetGpuConvParams(conv, operand_buffers, result_buffer));

bool succ = stream_exec->GetMIOpenConvolveAlgorithms(
kind, stream, dtype, params.input_descriptor, params.filter_descriptor,
params.conv_desc, params.output_descriptor, &algorithms);
DCHECK(succ);

return algorithms;
}

string AlgorithmToString(const AlgorithmDesc& algo) {
if (algo.tensor_ops_enabled()) {
return absl::StrCat(algo.algo_id(), "+TC");
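For illustration only, here is a minimal, self-contained sketch (not the XLA code above; the Candidate struct and PickFastest helper are hypothetical stand-ins for se::dnn::ProfileResult and the picker logic) of how a caller might consume the immediate-mode list that GetMIOpenConvolveAlgorithms fills in — each entry carries a solution id, a run time, and a workspace size, and the fastest valid entry wins:

#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Hypothetical stand-in for se::dnn::ProfileResult as filled in by
// GetMIOpenConvolveAlgorithms: one MIOpen immediate-mode solution.
struct Candidate {
  uint64_t solution_id;    // becomes AlgorithmDesc::algo_id()
  float elapsed_time_ms;   // MIOpen's reported run time
  size_t workspace_bytes;  // scratch memory the solution needs
};

// Pick the fastest candidate, mirroring the "minimum run time wins" rule
// the autotuner applies after profiling.
std::optional<Candidate> PickFastest(const std::vector<Candidate>& cands) {
  std::optional<Candidate> best;
  for (const Candidate& c : cands) {
    if (!best || c.elapsed_time_ms < best->elapsed_time_ms) best = c;
  }
  return best;
}

int main() {
  std::vector<Candidate> cands = {{1, 0.42f, 1 << 20}, {3, 0.31f, 0}};
  if (auto best = PickFastest(cands)) {
    std::cout << "solution " << best->solution_id << " wins\n";
  }
}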
@ -611,33 +634,72 @@ GpuConvAlgorithmPicker::PickBestAlgorithmNoCacheRocm(
ShapeUtil::ByteSizeOf(instr->shape().tuple_shapes(0))));
initialize_buffer(result_buffer);

ScratchAllocator scratch_allocator(device_ordinal, allocator);
se::dnn::ProfileResult profile_result;
VLOG(3) << "Auto-tuning for " << instr->ToString();
RunConvOptions options;
options.profile_result = &profile_result;
TF_ASSIGN_OR_RETURN(std::vector<se::dnn::ProfileResult> algorithms,
GetAlgorithms(instr, absl::MakeSpan(operand_buffers),
result_buffer, stream_exec_, stream));

// ROCm: Set the overriding algorithm to empty to remind cudnn_conv_runner
// that the AlgorithmConfig in running convolution needs to be empty
options.algo_override = se::dnn::AlgorithmDesc();
std::vector<AutotuneResult> profile_results;

bool launch_ok =
RunGpuConv(instr, absl::MakeSpan(operand_buffers), result_buffer,
&scratch_allocator, stream, options)
.ok();

AutotuneResult best_result;
if (launch_ok && profile_result.is_valid()) {
best_result.mutable_conv()->set_algorithm(
profile_result.algorithm().algo_id());
best_result.mutable_conv()->set_tensor_ops_enabled(
if (algorithms.size() == 1) {
auto profile_result = algorithms[0];
profile_results.emplace_back();
auto& result = profile_results.back();
result.mutable_conv()->set_algorithm(profile_result.algorithm().algo_id());
result.mutable_conv()->set_tensor_ops_enabled(
profile_result.algorithm().tensor_ops_enabled());
int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
best_result.set_scratch_bytes(scratch_bytes_used);
*best_result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));

return best_result;
result.set_scratch_bytes(profile_result.scratch_size());
*result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
} else {
for (const auto& miopen_alg : algorithms) {
const auto& alg = miopen_alg.algorithm();
XLA_SCOPED_LOGGING_TIMER_LEVEL(
absl::StrCat("CudnnConvAlgorithmPicker::PickBestAlgorithm algo ",
AlgorithmToString(alg)),
2);

ScratchAllocator scratch_allocator(device_ordinal, allocator);
se::dnn::ProfileResult profile_result;
VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for "
<< instr->ToString();

// Use assignment instead of brace-list to make GCC 4.9 happy.
RunConvOptions options;
options.profile_result = &profile_result;
options.algo_override = alg;
Status launch_status =
RunGpuConv(instr, absl::MakeSpan(operand_buffers), result_buffer,
&scratch_allocator, stream, options);

if (!launch_status.ok()) {
continue;
}

if (!profile_result.is_valid()) {
continue;
}

profile_results.emplace_back();
AutotuneResult& result = profile_results.back();
result.mutable_conv()->set_algorithm(alg.algo_id());
result.mutable_conv()->set_tensor_ops_enabled(alg.tensor_ops_enabled());

int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes();
result.set_scratch_bytes(scratch_bytes_used);
*result.mutable_run_time() = tensorflow::proto_utils::ToDurationProto(
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
}
}
const auto& best_result = absl::c_min_element(
profile_results,
[&](const AutotuneResult& lhs, const AutotuneResult& rhs) {
return tensorflow::proto_utils::FromDurationProto(lhs.run_time()) <
tensorflow::proto_utils::FromDurationProto(rhs.run_time());
});

if (best_result != profile_results.end()) {
return *best_result;
}

return InternalError(
@ -223,17 +223,7 @@ Status RunGpuConvImpl(const GpuConvParams& params,
auto output_buf = se::DeviceMemory<OutputType>(params.output_buf);
AlgorithmConfig algorithm = params.algorithm;

// In ROCm mode, the first call to run the convolution needs to trigger the
// code that calls the miopenFind* API. That trigger is implicit; it is based
// on whether or not the AlgorithmConfig::algorithm is empty. So for the
// first call we need to ensure that the AlgorithmConfig::algorithm is
// empty. For all subsequent calls, we should use the value retrieved from
// the backend_config.
if ((stream->parent()->platform_kind() == se::PlatformKind::kROCm) &&
(options.algo_override.has_value()) &&
(*options.algo_override == se::dnn::AlgorithmDesc())) {
algorithm = AlgorithmConfig();
} else if (options.algo_override.has_value()) {
if (options.algo_override.has_value()) {
algorithm = AlgorithmConfig(*options.algo_override);
}
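The three-way handling described by the removed comment above (an explicitly empty AlgorithmDesc override means "pass a default AlgorithmConfig so MIOpen runs its find path", a concrete override means "use exactly that algorithm", and no override means "keep whatever the params already carry") can be modeled with a small standalone sketch. The ResolveConfig helper and the simplified types below are hypothetical, not TensorFlow API:

#include <cassert>
#include <optional>

// Hypothetical, simplified stand-ins for se::dnn::AlgorithmDesc and
// se::dnn::AlgorithmConfig, just to illustrate the override semantics.
struct AlgorithmDesc {
  int algo_id = -1;  // -1 models the "empty" default-constructed descriptor
  bool operator==(const AlgorithmDesc& o) const { return algo_id == o.algo_id; }
};
struct AlgorithmConfig {
  std::optional<AlgorithmDesc> algorithm;  // empty => backend runs "find"
  AlgorithmConfig() = default;
  explicit AlgorithmConfig(AlgorithmDesc d) : algorithm(d) {}
};

// Mirrors the intent of the ROCm special case: an override equal to the empty
// descriptor forces an empty config (triggering MIOpen find); a concrete
// override selects that algorithm; no override keeps the existing config.
AlgorithmConfig ResolveConfig(const AlgorithmConfig& from_params,
                              const std::optional<AlgorithmDesc>& override_) {
  if (override_ && *override_ == AlgorithmDesc()) return AlgorithmConfig();
  if (override_) return AlgorithmConfig(*override_);
  return from_params;
}

int main() {
  AlgorithmConfig backend(AlgorithmDesc{7});
  assert(!ResolveConfig(backend, AlgorithmDesc{}).algorithm.has_value());
  assert(ResolveConfig(backend, AlgorithmDesc{3}).algorithm->algo_id == 3);
  assert(ResolveConfig(backend, std::nullopt).algorithm->algo_id == 7);
}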
@ -427,6 +427,39 @@ StatusOr<CudnnConvKind> GetCudnnConvKind(
return InternalError("Unexpected call target: %s", target);
}

StatusOr<se::dnn::ConvolutionKind> GetDnnConvolutionKind(
const HloCustomCallInstruction* instr) {
absl::string_view target = instr->custom_call_target();
if (target == kCudnnConvForwardCallTarget) {
return se::dnn::ConvolutionKind::FORWARD;
}
if (target == kCudnnConvBackwardInputCallTarget) {
return se::dnn::ConvolutionKind::BACKWARD_DATA;
}
if (target == kCudnnConvBackwardFilterCallTarget) {
return se::dnn::ConvolutionKind::BACKWARD_FILTER;
}
return InternalError("Unexpected call target: %s", target);
}

StatusOr<se::dnn::DataType> GetDnnDataType(
const HloCustomCallInstruction* conv) {
PrimitiveType output_primitive_type =
conv->shape().tuple_shapes(0).element_type();
switch (output_primitive_type) {
case F16:
return se::dnn::ToDataType<Eigen::half>::value;
case F32:
return se::dnn::ToDataType<float>::value;
case F64:
return se::dnn::ToDataType<double>::value;
default:
break;
}
return InternalError("Unsupported convolution datatype : %s",
conv->ToString());
}

string CudnnConvKindToString(CudnnConvKind kind) {
switch (kind) {
case CudnnConvKind::kForward:
@ -22,6 +22,7 @@ limitations under the License.
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"

// TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they
// don't belong in "ir_emission_utils".
@ -53,6 +54,12 @@ enum class CudnnConvKind {

StatusOr<CudnnConvKind> GetCudnnConvKind(const HloCustomCallInstruction* instr);

StatusOr<se::dnn::ConvolutionKind> GetDnnConvolutionKind(
const HloCustomCallInstruction* instr);

StatusOr<se::dnn::DataType> GetDnnDataType(
const HloCustomCallInstruction* conv);

// Converts a CudnnConvKind value to a string.
string CudnnConvKindToString(CudnnConvKind kind);
@ -20,10 +20,6 @@ load("@local_config_sycl//sycl:build_defs.bzl", "if_sycl")
load("//tensorflow:tensorflow.bzl", "if_nccl")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_tests")
load(
"//tensorflow/core/platform/default:cuda_build_defs.bzl",
"if_cuda_is_configured",
)
load(
"//tensorflow/core/platform:build_config.bzl",
"tf_kernel_tests_linkstatic",
@ -522,7 +518,7 @@ cc_library(

tf_cuda_library(
name = "gpu_utils",
srcs = if_cuda_is_configured(["gpu_utils.cc"]),
srcs = if_cuda_or_rocm(["gpu_utils.cc"]),
hdrs = ["gpu_utils.h"],
deps = [
":gpu_util_hdrs",
@ -1033,28 +1033,66 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
|
||||
CheckRedzones(rz_allocator, &result);
|
||||
}
|
||||
}
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
std::vector<ProfileResult> algorithms;
|
||||
OP_REQUIRES(ctx,
|
||||
stream->parent()->GetMIOpenConvolveAlgorithms(
|
||||
se::dnn::ConvolutionKind::BACKWARD_FILTER, stream,
|
||||
se::dnn::ToDataType<T>::value, input_desc, filter_desc,
|
||||
conv_desc, output_desc, &algorithms),
|
||||
errors::Unknown(
|
||||
"Failed to get convolution algorithm. This is probably "
|
||||
"because MIOpen failed to initialize, so try looking to "
|
||||
"see if a warning log message was printed above."));
|
||||
|
||||
std::vector<tensorflow::AutotuneResult> results;
|
||||
if (algorithms.size() == 1) {
|
||||
auto profile_result = algorithms[0];
|
||||
results.emplace_back();
|
||||
auto& result = results.back();
|
||||
result.mutable_conv()->set_algorithm(
|
||||
profile_result.algorithm().algo_id());
|
||||
result.mutable_conv()->set_tensor_ops_enabled(
|
||||
profile_result.algorithm().tensor_ops_enabled());
|
||||
|
||||
result.set_scratch_bytes(profile_result.scratch_size());
|
||||
*result.mutable_run_time() = proto_utils::ToDurationProto(
|
||||
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
|
||||
} else {
|
||||
for (auto miopen_algorithm : algorithms) {
|
||||
auto profile_algorithm = miopen_algorithm.algorithm();
|
||||
DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
|
||||
ctx);
|
||||
ProfileResult profile_result;
|
||||
bool miopen_launch_status = true;
|
||||
miopen_launch_status =
|
||||
stream
|
||||
->ThenConvolveBackwardFilterWithAlgorithm(
|
||||
input_desc, input_ptr, output_desc, out_backprop_ptr,
|
||||
conv_desc, filter_desc, &filter_backprop_ptr,
|
||||
&scratch_allocator, AlgorithmConfig(profile_algorithm),
|
||||
&profile_result)
|
||||
.ok();
|
||||
|
||||
if (miopen_launch_status && profile_result.is_valid()) {
|
||||
results.emplace_back();
|
||||
auto& result = results.back();
|
||||
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
|
||||
result.mutable_conv()->set_tensor_ops_enabled(
|
||||
profile_algorithm.tensor_ops_enabled());
|
||||
result.set_scratch_bytes(scratch_allocator.TotalByteSize());
|
||||
*result.mutable_run_time() = proto_utils::ToDurationProto(
|
||||
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
LogConvAutotuneResults(se::dnn::ConvolutionKind::BACKWARD_FILTER,
|
||||
se::dnn::ToDataType<T>::value, input_ptr,
|
||||
filter_backprop_ptr_rz, out_backprop_ptr, input_desc,
|
||||
filter_backprop_ptr, out_backprop_ptr, input_desc,
|
||||
filter_desc, output_desc, conv_desc,
|
||||
stream->parent(), results);
|
||||
OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
ProfileResult best_result;
|
||||
DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
|
||||
ctx);
|
||||
bool miopen_find_status =
|
||||
stream
|
||||
->ThenConvolveBackwardFilterWithAlgorithm(
|
||||
input_desc, input_ptr, output_desc, out_backprop_ptr, conv_desc,
|
||||
filter_desc, &filter_backprop_ptr, &scratch_allocator,
|
||||
AlgorithmConfig(), &best_result)
|
||||
.ok();
|
||||
OP_REQUIRES(ctx, miopen_find_status && best_result.is_valid(),
|
||||
errors::NotFound("Failed to find backward filter algorithm!"));
|
||||
algorithm_config.set_algorithm(best_result.algorithm());
|
||||
algorithm_config.set_scratch_size(best_result.scratch_size());
|
||||
#endif
|
||||
AutoTuneConvBwdFilter::GetInstance()->Insert(conv_parameters,
|
||||
algorithm_config);
|
||||
}
|
||||
|
@ -1199,29 +1199,64 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
|
||||
CheckRedzones(rz_allocator, &result);
|
||||
}
|
||||
}
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
std::vector<ProfileResult> algorithms;
|
||||
OP_REQUIRES(ctx,
|
||||
stream->parent()->GetMIOpenConvolveAlgorithms(
|
||||
se::dnn::ConvolutionKind::BACKWARD_DATA, stream,
|
||||
se::dnn::ToDataType<T>::value, input_desc, filter_desc,
|
||||
conv_desc, output_desc, &algorithms),
|
||||
errors::Unknown(
|
||||
"Failed to get convolution algorithm. This is probably "
|
||||
"because MIOpen failed to initialize, so try looking to "
|
||||
"see if a warning log message was printed above."));
|
||||
|
||||
std::vector<tensorflow::AutotuneResult> results;
|
||||
if (algorithms.size() == 1) {
|
||||
auto profile_result = algorithms[0];
|
||||
results.emplace_back();
|
||||
auto& result = results.back();
|
||||
result.mutable_conv()->set_algorithm(
|
||||
profile_result.algorithm().algo_id());
|
||||
result.mutable_conv()->set_tensor_ops_enabled(
|
||||
profile_result.algorithm().tensor_ops_enabled());
|
||||
|
||||
result.set_scratch_bytes(profile_result.scratch_size());
|
||||
*result.mutable_run_time() = proto_utils::ToDurationProto(
|
||||
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
|
||||
} else {
|
||||
for (auto miopen_algorithm : algorithms) {
|
||||
auto profile_algorithm = miopen_algorithm.algorithm();
|
||||
DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
|
||||
ctx);
|
||||
ProfileResult profile_result;
|
||||
bool miopen_launch_status = true;
|
||||
miopen_launch_status =
|
||||
stream
|
||||
->ThenConvolveBackwardDataWithAlgorithm(
|
||||
filter_desc, filter_ptr, output_desc, out_backprop_ptr,
|
||||
conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
|
||||
AlgorithmConfig(profile_algorithm), &profile_result)
|
||||
.ok();
|
||||
|
||||
if (miopen_launch_status && profile_result.is_valid()) {
|
||||
results.emplace_back();
|
||||
auto& result = results.back();
|
||||
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
|
||||
result.mutable_conv()->set_tensor_ops_enabled(
|
||||
profile_algorithm.tensor_ops_enabled());
|
||||
result.set_scratch_bytes(scratch_allocator.TotalByteSize());
|
||||
*result.mutable_run_time() = proto_utils::ToDurationProto(
|
||||
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
LogConvAutotuneResults(
|
||||
se::dnn::ConvolutionKind::BACKWARD_DATA, se::dnn::ToDataType<T>::value,
|
||||
in_backprop_ptr, filter_ptr, out_backprop_ptr, input_desc, filter_desc,
|
||||
output_desc, conv_desc, stream->parent(), results);
|
||||
OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
// MIOpen has its own Find and autotuner so use it here, passing
|
||||
// default AlgorithmConfig to force a search
|
||||
DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize, ctx);
|
||||
ProfileResult best_result;
|
||||
bool miopen_find_status =
|
||||
stream
|
||||
->ThenConvolveBackwardDataWithAlgorithm(
|
||||
filter_desc, filter_ptr, output_desc, out_backprop_ptr,
|
||||
conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
|
||||
AlgorithmConfig(), &best_result)
|
||||
.ok();
|
||||
OP_REQUIRES(ctx, miopen_find_status && best_result.is_valid(),
|
||||
errors::NotFound("Failed to find backwards-data algorithm!"));
|
||||
|
||||
algorithm_config.set_algorithm(best_result.algorithm());
|
||||
algorithm_config.set_scratch_size(best_result.scratch_size());
|
||||
#endif
|
||||
AutoTuneConvBwdData::GetInstance()->Insert(conv_parameters,
|
||||
algorithm_config);
|
||||
}
|
||||
|
@ -1433,6 +1433,51 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
std::vector<ProfileResult> algorithms;
|
||||
CHECK(stream->parent()->GetMIOpenConvolveAlgorithms(
|
||||
se::dnn::ConvolutionKind::BACKWARD_DATA, stream,
|
||||
se::dnn::ToDataType<T>::value, input_desc, filter_desc, conv_desc,
|
||||
output_desc, &algorithms));
|
||||
ProfileResult best_result;
|
||||
ProfileResult best_result_no_scratch;
|
||||
std::vector<tensorflow::AutotuneResult> results;
|
||||
for (auto miopen_algorithm : algorithms) {
|
||||
auto profile_algorithm = miopen_algorithm.algorithm();
|
||||
DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
|
||||
context);
|
||||
ProfileResult profile_result;
|
||||
bool miopen_launch_status =
|
||||
stream
|
||||
->ThenConvolveBackwardDataWithAlgorithm(
|
||||
filter_desc, filter_ptr, output_desc, out_backprop_ptr,
|
||||
conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
|
||||
AlgorithmConfig(profile_algorithm), &profile_result)
|
||||
.ok();
|
||||
if (miopen_launch_status) {
|
||||
if (profile_result.is_valid()) {
|
||||
results.emplace_back();
|
||||
auto& result = results.back();
|
||||
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
|
||||
result.mutable_conv()->set_tensor_ops_enabled(
|
||||
profile_algorithm.tensor_ops_enabled());
|
||||
result.set_scratch_bytes(scratch_allocator.TotalByteSize());
|
||||
*result.mutable_run_time() = proto_utils::ToDurationProto(
|
||||
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
|
||||
|
||||
if (profile_result.elapsed_time_in_ms() <
|
||||
best_result.elapsed_time_in_ms()) {
|
||||
best_result = profile_result;
|
||||
}
|
||||
if (scratch_allocator.TotalByteSize() == 0 &&
|
||||
profile_result.elapsed_time_in_ms() <
|
||||
best_result_no_scratch.elapsed_time_in_ms()) {
|
||||
best_result_no_scratch = profile_result;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
LogConvAutotuneResults(se::dnn::ConvolutionKind::BACKWARD_DATA,
|
||||
se::dnn::ToDataType<T>::value, in_backprop_ptr,
|
||||
filter_ptr, out_backprop_ptr, input_desc,
|
||||
@ -1448,22 +1493,6 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
|
||||
algorithm_config.set_algorithm_no_scratch(
|
||||
best_result_no_scratch.algorithm());
|
||||
}
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
DnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
|
||||
context);
|
||||
ProfileResult best_result;
|
||||
bool miopen_find_status =
|
||||
stream
|
||||
->ThenConvolveBackwardDataWithAlgorithm(
|
||||
filter_desc, filter_ptr, output_desc, out_backprop_ptr,
|
||||
conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator,
|
||||
AlgorithmConfig(), &best_result)
|
||||
.ok();
|
||||
OP_REQUIRES(context, miopen_find_status && best_result.is_valid(),
|
||||
errors::NotFound("Failed to find backward data algorithm!"));
|
||||
algorithm_config.set_algorithm(best_result.algorithm());
|
||||
algorithm_config.set_scratch_size(best_result.scratch_size());
|
||||
#endif
|
||||
AutoTuneConv3dBwdData::GetInstance()->Insert(conv_parameters,
|
||||
algorithm_config);
|
||||
}
|
||||
@ -1864,6 +1893,46 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
std::vector<ProfileResult> algorithms;
|
||||
CHECK(stream->parent()->GetMIOpenConvolveAlgorithms(
|
||||
se::dnn::ConvolutionKind::BACKWARD_FILTER, stream,
|
||||
se::dnn::ToDataType<T>::value, input_desc, filter_desc, conv_desc,
|
||||
output_desc, &algorithms));
|
||||
ProfileResult best_result;
|
||||
ProfileResult best_result_no_scratch;
|
||||
if (algorithms.size() == 1) {
|
||||
best_result = algorithms[0];
|
||||
} else {
|
||||
for (auto miopen_algorithm : algorithms) {
|
||||
auto profile_algorithm = miopen_algorithm.algorithm();
|
||||
DnnScratchAllocator scratch_allocator(
|
||||
ConvolveBackwardFilterScratchSize, context);
|
||||
ProfileResult profile_result;
|
||||
bool cudnn_launch_status =
|
||||
stream
|
||||
->ThenConvolveBackwardFilterWithAlgorithm(
|
||||
input_desc, input_ptr, output_desc, out_backprop_ptr,
|
||||
conv_desc, filter_desc, &filter_backprop_ptr,
|
||||
&scratch_allocator, AlgorithmConfig(profile_algorithm),
|
||||
&profile_result)
|
||||
.ok();
|
||||
if (cudnn_launch_status) {
|
||||
if (profile_result.is_valid()) {
|
||||
if (profile_result.elapsed_time_in_ms() <
|
||||
best_result.elapsed_time_in_ms()) {
|
||||
best_result = profile_result;
|
||||
}
|
||||
if (scratch_allocator.TotalByteSize() == 0 &&
|
||||
profile_result.elapsed_time_in_ms() <
|
||||
best_result_no_scratch.elapsed_time_in_ms()) {
|
||||
best_result_no_scratch = profile_result;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
OP_REQUIRES(context,
|
||||
best_result.is_valid() || best_result_no_scratch.is_valid(),
|
||||
errors::NotFound("No algorithm worked!"));
|
||||
@ -1874,23 +1943,6 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
|
||||
algorithm_config.set_algorithm_no_scratch(
|
||||
best_result_no_scratch.algorithm());
|
||||
}
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
|
||||
context);
|
||||
ProfileResult best_result;
|
||||
bool miopen_find_status =
|
||||
stream
|
||||
->ThenConvolveBackwardFilterWithAlgorithm(
|
||||
input_desc, input_ptr, output_desc, out_backprop_ptr,
|
||||
conv_desc, filter_desc, &filter_backprop_ptr,
|
||||
&scratch_allocator, AlgorithmConfig(), &best_result)
|
||||
.ok();
|
||||
OP_REQUIRES(
|
||||
context, miopen_find_status && best_result.is_valid(),
|
||||
errors::NotFound("Failed to find backward filter algorithm!"));
|
||||
algorithm_config.set_algorithm(best_result.algorithm());
|
||||
algorithm_config.set_scratch_size(best_result.scratch_size());
|
||||
#endif
|
||||
AutoTuneConv3dBwdFilter::GetInstance()->Insert(conv_parameters,
|
||||
algorithm_config);
|
||||
}
|
||||
|
@ -1039,28 +1039,65 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
|
||||
CheckRedzones(rz_allocator, &result);
|
||||
}
|
||||
}
|
||||
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
std::vector<ProfileResult> algorithms;
|
||||
OP_REQUIRES(ctx,
|
||||
stream->parent()->GetMIOpenConvolveAlgorithms(
|
||||
se::dnn::ConvolutionKind::FORWARD, stream,
|
||||
se::dnn::ToDataType<T>::value, input_desc, filter_desc,
|
||||
conv_desc, output_desc, &algorithms),
|
||||
errors::Unknown(
|
||||
"Failed to get convolution algorithm. This is probably "
|
||||
"because MIOpen failed to initialize, so try looking to "
|
||||
"see if a warning log message was printed above."));
|
||||
se::DeviceMemory<T> output_tensor = output_ptr;
|
||||
|
||||
std::vector<tensorflow::AutotuneResult> results;
|
||||
if (algorithms.size() == 1) {
|
||||
auto profile_result = algorithms[0];
|
||||
results.emplace_back();
|
||||
auto& result = results.back();
|
||||
result.mutable_conv()->set_algorithm(
|
||||
profile_result.algorithm().algo_id());
|
||||
result.mutable_conv()->set_tensor_ops_enabled(
|
||||
profile_result.algorithm().tensor_ops_enabled());
|
||||
|
||||
result.set_scratch_bytes(profile_result.scratch_size());
|
||||
*result.mutable_run_time() = proto_utils::ToDurationProto(
|
||||
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
|
||||
} else {
|
||||
for (auto miopen_algorithm : algorithms) {
|
||||
auto profile_algorithm = miopen_algorithm.algorithm();
|
||||
DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
|
||||
ProfileResult profile_result;
|
||||
bool miopen_launch_status = false;
|
||||
miopen_launch_status =
|
||||
stream
|
||||
->ThenConvolveWithAlgorithm(
|
||||
input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
|
||||
output_desc, &output_ptr, &scratch_allocator,
|
||||
AlgorithmConfig(profile_algorithm), &profile_result)
|
||||
.ok();
|
||||
if (miopen_launch_status && profile_result.is_valid()) {
|
||||
results.emplace_back();
|
||||
auto& result = results.back();
|
||||
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
|
||||
result.mutable_conv()->set_tensor_ops_enabled(
|
||||
profile_algorithm.tensor_ops_enabled());
|
||||
|
||||
result.set_scratch_bytes(scratch_allocator.TotalByteSize());
|
||||
*result.mutable_run_time() = proto_utils::ToDurationProto(
|
||||
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
LogConvAutotuneResults(se::dnn::ConvolutionKind::FORWARD,
|
||||
se::dnn::ToDataType<T>::value, input_ptr, filter_ptr,
|
||||
output_tensor, input_desc, filter_desc, output_desc,
|
||||
conv_desc, stream->parent(), results);
|
||||
OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
|
||||
ProfileResult best_result;
|
||||
bool miopen_find_status =
|
||||
stream
|
||||
->ThenConvolveWithAlgorithm(input_desc, input_ptr, filter_desc,
|
||||
filter_ptr, conv_desc, output_desc,
|
||||
&output_ptr, &scratch_allocator,
|
||||
AlgorithmConfig(), &best_result)
|
||||
.ok();
|
||||
|
||||
OP_REQUIRES(ctx, miopen_find_status && best_result.is_valid(),
|
||||
errors::NotFound("Failed to find conv algorithm!"));
|
||||
|
||||
algorithm_config.set_algorithm(best_result.algorithm());
|
||||
algorithm_config.set_scratch_size(best_result.scratch_size());
|
||||
#endif
|
||||
AutoTuneConv::GetInstance()->Insert(conv_parameters, algorithm_config);
|
||||
}
|
||||
|
||||
|
@ -504,26 +504,63 @@ struct LaunchConvOp<GPUDevice, T> {
|
||||
}
|
||||
}
|
||||
}
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
std::vector<ProfileResult> algorithms;
|
||||
OP_REQUIRES(ctx,
|
||||
stream->parent()->GetMIOpenConvolveAlgorithms(
|
||||
se::dnn::ConvolutionKind::FORWARD, stream,
|
||||
se::dnn::ToDataType<T>::value, input_desc, filter_desc,
|
||||
conv_desc, output_desc, &algorithms),
|
||||
errors::Unknown(
|
||||
"Failed to get convolution algorithm. This is probably "
|
||||
"because MIOpen failed to initialize, so try looking to "
|
||||
"see if a warning log message was printed above."));
|
||||
std::vector<tensorflow::AutotuneResult> results;
|
||||
if (algorithms.size() == 1) {
|
||||
auto profile_result = algorithms[0];
|
||||
results.emplace_back();
|
||||
auto& result = results.back();
|
||||
result.mutable_conv()->set_algorithm(
|
||||
profile_result.algorithm().algo_id());
|
||||
result.mutable_conv()->set_tensor_ops_enabled(
|
||||
profile_result.algorithm().tensor_ops_enabled());
|
||||
|
||||
result.set_scratch_bytes(profile_result.scratch_size());
|
||||
*result.mutable_run_time() = proto_utils::ToDurationProto(
|
||||
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
|
||||
} else {
|
||||
for (auto miopen_algorithm : algorithms) {
|
||||
auto profile_algorithm = miopen_algorithm.algorithm();
|
||||
DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
|
||||
ProfileResult profile_result;
|
||||
bool miopen_launch_status =
|
||||
stream
|
||||
->ThenConvolveWithAlgorithm(
|
||||
input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
|
||||
output_desc, &output_ptr, &scratch_allocator,
|
||||
AlgorithmConfig(profile_algorithm), &profile_result)
|
||||
.ok();
|
||||
if (miopen_launch_status) {
|
||||
if (profile_result.is_valid()) {
|
||||
results.emplace_back();
|
||||
auto& result = results.back();
|
||||
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
|
||||
result.mutable_conv()->set_tensor_ops_enabled(
|
||||
profile_algorithm.tensor_ops_enabled());
|
||||
result.set_scratch_bytes(scratch_allocator.TotalByteSize());
|
||||
*result.mutable_run_time() = proto_utils::ToDurationProto(
|
||||
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
LogConvAutotuneResults(se::dnn::ConvolutionKind::FORWARD,
|
||||
se::dnn::ToDataType<T>::value, input_ptr,
|
||||
filter_ptr, output_ptr, input_desc, filter_desc,
|
||||
output_desc, conv_desc, stream->parent(), results);
|
||||
OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
|
||||
#elif TENSORFLOW_USE_ROCM
|
||||
ProfileResult best_result;
|
||||
DnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
|
||||
bool miopen_find_status =
|
||||
stream
|
||||
->ThenConvolveWithAlgorithm(input_desc, input_ptr, filter_desc,
|
||||
filter_ptr, conv_desc, output_desc,
|
||||
&output_ptr, &scratch_allocator,
|
||||
AlgorithmConfig(), &best_result)
|
||||
.ok();
|
||||
OP_REQUIRES(ctx, miopen_find_status && best_result.is_valid(),
|
||||
errors::NotFound("Failed to find conv algorithm!"));
|
||||
algorithm_config.set_algorithm(best_result.algorithm());
|
||||
algorithm_config.set_scratch_size(best_result.scratch_size());
|
||||
#endif
|
||||
AutoTuneConv3d::GetInstance()->Insert(conv_parameters, algorithm_config);
|
||||
}
|
||||
|
||||
|
@ -15,7 +15,7 @@ limitations under the License.

#include "tensorflow/core/kernels/gpu_utils.h"

#if GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include <iterator>

@ -249,4 +249,4 @@ Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,

} // namespace tensorflow

#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
@ -41,6 +41,17 @@ bool DnnSupport::GetConvolveAlgorithms(
return false;
}

bool DnnSupport::GetMIOpenConvolveAlgorithms(
dnn::ConvolutionKind /*kind*/, Stream* /*stream*/,
dnn::DataType /*element_type*/,
const dnn::BatchDescriptor& /*input_descriptor*/,
const dnn::FilterDescriptor& /*filter_descriptor*/,
const dnn::ConvolutionDescriptor& /*convolution_descriptor*/,
const dnn::BatchDescriptor& /*output_descriptor*/,
std::vector<ProfileResult>* /*out_algorithms*/) {
return false;
}

bool DnnSupport::GetRnnAlgorithms(std::vector<AlgorithmDesc>* out_algorithms) {
return false;
}
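As a rough illustration of the pattern above (a base-class hook that reports "not supported" by default and a platform backend that overrides it to return candidates), here is a self-contained sketch; DnnBase, FakeMIOpenDnn, and Candidate are hypothetical stand-ins, not StreamExecutor types:

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for one immediate-mode candidate (solution id + time).
struct Candidate {
  uint64_t solution_id;
  float time_ms;
};

// Models DnnSupport: the query is unsupported unless a backend overrides it,
// so the default implementation just reports failure.
class DnnBase {
 public:
  virtual ~DnnBase() = default;
  virtual bool GetConvolveCandidates(std::vector<Candidate>* out) {
    return false;  // not supported by this backend
  }
};

// Models the ROCm/MIOpen backend: it fills in the candidate list and
// returns true.
class FakeMIOpenDnn : public DnnBase {
 public:
  bool GetConvolveCandidates(std::vector<Candidate>* out) override {
    *out = {{1, 0.40f}, {5, 0.25f}};
    return true;
  }
};

int main() {
  FakeMIOpenDnn dnn;
  std::vector<Candidate> candidates;
  if (dnn.GetConvolveCandidates(&candidates)) {
    for (const auto& c : candidates)
      std::cout << "solution " << c.solution_id << ": " << c.time_ms << " ms\n";
  }
}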
@ -1352,6 +1352,14 @@ class DnnSupport {
bool with_winograd_nonfused, int cc_major, int cc_minor,
std::vector<AlgorithmDesc>* out_algorithms);

virtual bool GetMIOpenConvolveAlgorithms(
dnn::ConvolutionKind kind, Stream* stream, dnn::DataType element_type,
const dnn::BatchDescriptor& input_descriptor,
const dnn::FilterDescriptor& filter_descriptor,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::BatchDescriptor& output_descriptor,
std::vector<ProfileResult>* out_algorithms);

// Returns a list of supported rnn algorithms.
virtual bool GetRnnAlgorithms(std::vector<AlgorithmDesc>* out_algorithms);
@ -22,6 +22,7 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "rocm/include/miopen/miopen.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/gpu/gpu_activation.h"
#include "tensorflow/stream_executor/gpu/gpu_driver.h"
@ -53,6 +54,8 @@ NarrowT CheckedNarrowing(const WideT& wide) {
return narrow;
}

const int kImmediateModeVlogLevel = 3;

} // namespace

namespace stream_executor {
@ -91,6 +94,24 @@ string ToString(miopenStatus_t status) {
}
}

string ToString(miopenConvAlgorithm_t algorithm) {
string s;
switch (algorithm) {
case miopenConvolutionAlgoGEMM:
s = "GEMM";
break;
case miopenConvolutionAlgoDirect:
s = "Direct";
break;
case miopenConvolutionAlgoFFT:
s = "FFT";
break;
case miopenConvolutionAlgoWinograd:
s = "Winograd";
break;
}
return s;
}
// RAII wrapper for all calls to MIOpen with a MIOpen handle argument.
//
// See MIOpenAccess::GetHandle() for details.
@ -156,95 +177,110 @@ namespace wrap {
|
||||
#endif
|
||||
|
||||
// clang-format off
|
||||
#define MIOPEN_DNN_ROUTINE_EACH(__macro) \
|
||||
__macro(miopenBatchNormalizationBackward) \
|
||||
__macro(miopenBatchNormalizationForwardInference) \
|
||||
__macro(miopenBatchNormalizationForwardTraining) \
|
||||
__macro(miopenGetConvolutionForwardOutputDim) \
|
||||
__macro(miopenGetConvolutionNdForwardOutputDim) \
|
||||
__macro(miopenFindConvolutionForwardAlgorithm) \
|
||||
__macro(miopenCreateTensorDescriptor) \
|
||||
__macro(miopenDestroyTensorDescriptor) \
|
||||
__macro(miopenSet2dPoolingDescriptor) \
|
||||
__macro(miopenSetLRNDescriptor) \
|
||||
__macro(miopenLRNGetWorkSpaceSize) \
|
||||
__macro(miopenCreateConvolutionDescriptor) \
|
||||
__macro(miopenCreatePoolingDescriptor) \
|
||||
__macro(miopenDestroyPoolingDescriptor) \
|
||||
__macro(miopenCreateLRNDescriptor) \
|
||||
__macro(miopenDestroyLRNDescriptor) \
|
||||
__macro(miopenDestroyConvolutionDescriptor) \
|
||||
__macro(miopenCreateWithStream) \
|
||||
__macro(miopenDestroy) \
|
||||
__macro(miopenSetStream) \
|
||||
__macro(miopenSetAllocator) \
|
||||
__macro(miopenActivationForward) \
|
||||
__macro(miopenConvolutionForward) \
|
||||
__macro(miopenConvolutionBackwardBias) \
|
||||
__macro(miopenConvolutionForwardGetWorkSpaceSize) \
|
||||
__macro(miopenInitConvolutionDescriptor) \
|
||||
__macro(miopenInitConvolutionNdDescriptor) \
|
||||
__macro(miopenGetConvolutionDescriptor) \
|
||||
__macro(miopenGetConvolutionNdDescriptor) \
|
||||
__macro(miopenSetConvolutionGroupCount) \
|
||||
__macro(miopenSet4dTensorDescriptor) \
|
||||
__macro(miopenGetTensorDescriptor) \
|
||||
__macro(miopenSetTensorDescriptor) \
|
||||
__macro(miopenGetTensorDescriptorSize) \
|
||||
__macro(miopenPoolingForward) \
|
||||
__macro(miopenPoolingGetWorkSpaceSize) \
|
||||
__macro(miopenPoolingBackward) \
|
||||
__macro(miopenLRNForward) \
|
||||
__macro(miopenLRNBackward) \
|
||||
__macro(miopenOpTensor) \
|
||||
__macro(miopenConvolutionBackwardData) \
|
||||
__macro(miopenConvolutionBackwardWeights) \
|
||||
__macro(miopenConvolutionBackwardWeightsGetWorkSpaceSize)\
|
||||
__macro(miopenFindConvolutionBackwardDataAlgorithm) \
|
||||
__macro(miopenFindConvolutionBackwardWeightsAlgorithm) \
|
||||
__macro(miopenConvolutionBackwardDataGetWorkSpaceSize) \
|
||||
__macro(miopenCreateRNNDescriptor) \
|
||||
__macro(miopenSetRNNDescriptor) \
|
||||
__macro(miopenDestroyRNNDescriptor) \
|
||||
__macro(miopenGetRNNParamsSize) \
|
||||
__macro(miopenGetRNNLayerParam) \
|
||||
__macro(miopenGetRNNLayerBias) \
|
||||
__macro(miopenGetRNNWorkspaceSize) \
|
||||
__macro(miopenGetRNNTrainingReserveSize) \
|
||||
__macro(miopenRNNForwardInference) \
|
||||
__macro(miopenRNNForwardTraining) \
|
||||
__macro(miopenRNNBackwardData) \
|
||||
__macro(miopenRNNBackwardWeights) \
|
||||
__macro(miopenGetRNNLayerParamOffset) \
|
||||
__macro(miopenGetRNNLayerParamSize) \
|
||||
__macro(miopenGetRNNLayerBiasOffset) \
|
||||
__macro(miopenGetRNNLayerBiasSize) \
|
||||
__macro(miopenGetRNNParamsDescriptor) \
|
||||
__macro(miopenCreateActivationDescriptor) \
|
||||
__macro(miopenSetActivationDescriptor) \
|
||||
__macro(miopenGetActivationDescriptor) \
|
||||
__macro(miopenDestroyActivationDescriptor) \
|
||||
__macro(miopenCreateFusionPlan) \
|
||||
__macro(miopenCreateOpConvForward) \
|
||||
__macro(miopenCreateOpBiasForward) \
|
||||
__macro(miopenCreateOpActivationForward) \
|
||||
__macro(miopenCreateOpActivationBackward) \
|
||||
__macro(miopenCreateOpBatchNormInference) \
|
||||
__macro(miopenCreateOpBatchNormForward) \
|
||||
__macro(miopenCreateOpBatchNormBackward) \
|
||||
__macro(miopenCompileFusionPlan) \
|
||||
__macro(miopenFusionPlanGetOp) \
|
||||
__macro(miopenCreateOperatorArgs) \
|
||||
__macro(miopenSetOpArgsConvForward) \
|
||||
__macro(miopenSetOpArgsBiasForward) \
|
||||
__macro(miopenSetOpArgsActivForward) \
|
||||
__macro(miopenSetOpArgsActivBackward) \
|
||||
__macro(miopenSetOpArgsBatchNormInference) \
|
||||
__macro(miopenSetOpArgsBatchNormForward) \
|
||||
__macro(miopenSetOpArgsBatchNormBackward) \
|
||||
__macro(miopenExecuteFusionPlan) \
|
||||
__macro(miopenDestroyOperatorArgs) \
|
||||
__macro(miopenDestroyFusionPlan)
|
||||
#define MIOPEN_DNN_ROUTINE_EACH(__macro) \
|
||||
__macro(miopenBatchNormalizationBackward) \
|
||||
__macro(miopenBatchNormalizationForwardInference) \
|
||||
__macro(miopenBatchNormalizationForwardTraining) \
|
||||
__macro(miopenGetConvolutionForwardOutputDim) \
|
||||
__macro(miopenGetConvolutionNdForwardOutputDim) \
|
||||
__macro(miopenFindConvolutionForwardAlgorithm) \
|
||||
__macro(miopenCreateTensorDescriptor) \
|
||||
__macro(miopenDestroyTensorDescriptor) \
|
||||
__macro(miopenSet2dPoolingDescriptor) \
|
||||
__macro(miopenSetLRNDescriptor) \
|
||||
__macro(miopenLRNGetWorkSpaceSize) \
|
||||
__macro(miopenCreateConvolutionDescriptor) \
|
||||
__macro(miopenCreatePoolingDescriptor) \
|
||||
__macro(miopenDestroyPoolingDescriptor) \
|
||||
__macro(miopenCreateLRNDescriptor) \
|
||||
__macro(miopenDestroyLRNDescriptor) \
|
||||
__macro(miopenDestroyConvolutionDescriptor) \
|
||||
__macro(miopenCreateWithStream) \
|
||||
__macro(miopenDestroy) \
|
||||
__macro(miopenSetStream) \
|
||||
__macro(miopenSetAllocator) \
|
||||
__macro(miopenActivationForward) \
|
||||
__macro(miopenConvolutionForward) \
|
||||
__macro(miopenConvolutionBackwardBias) \
|
||||
__macro(miopenConvolutionForwardGetWorkSpaceSize) \
|
||||
__macro(miopenInitConvolutionDescriptor) \
|
||||
__macro(miopenInitConvolutionNdDescriptor) \
|
||||
__macro(miopenGetConvolutionDescriptor) \
|
||||
__macro(miopenGetConvolutionNdDescriptor) \
|
||||
__macro(miopenSetConvolutionGroupCount) \
|
||||
__macro(miopenSet4dTensorDescriptor) \
|
||||
__macro(miopenGetTensorDescriptor) \
|
||||
__macro(miopenSetTensorDescriptor) \
|
||||
__macro(miopenGetTensorDescriptorSize) \
|
||||
__macro(miopenPoolingForward) \
|
||||
__macro(miopenPoolingGetWorkSpaceSize) \
|
||||
__macro(miopenPoolingBackward) \
|
||||
__macro(miopenLRNForward) \
|
||||
__macro(miopenLRNBackward) \
|
||||
__macro(miopenOpTensor) \
|
||||
__macro(miopenConvolutionBackwardData) \
|
||||
__macro(miopenConvolutionBackwardWeights) \
|
||||
__macro(miopenConvolutionBackwardWeightsGetWorkSpaceSize) \
|
||||
__macro(miopenFindConvolutionBackwardDataAlgorithm) \
|
||||
__macro(miopenFindConvolutionBackwardWeightsAlgorithm) \
|
||||
__macro(miopenConvolutionBackwardDataGetWorkSpaceSize) \
|
||||
__macro(miopenCreateRNNDescriptor) \
|
||||
__macro(miopenSetRNNDescriptor) \
|
||||
__macro(miopenDestroyRNNDescriptor) \
|
||||
__macro(miopenGetRNNParamsSize) \
|
||||
__macro(miopenGetRNNLayerParam) \
|
||||
__macro(miopenGetRNNLayerBias) \
|
||||
__macro(miopenGetRNNWorkspaceSize) \
|
||||
__macro(miopenGetRNNTrainingReserveSize) \
|
||||
__macro(miopenRNNForwardInference) \
|
||||
__macro(miopenRNNForwardTraining) \
|
||||
__macro(miopenRNNBackwardData) \
|
||||
__macro(miopenRNNBackwardWeights) \
|
||||
__macro(miopenGetRNNLayerParamOffset) \
|
||||
__macro(miopenGetRNNLayerParamSize) \
|
||||
__macro(miopenGetRNNLayerBiasOffset) \
|
||||
__macro(miopenGetRNNLayerBiasSize) \
|
||||
__macro(miopenGetRNNParamsDescriptor) \
|
||||
__macro(miopenCreateActivationDescriptor) \
|
||||
__macro(miopenSetActivationDescriptor) \
|
||||
__macro(miopenGetActivationDescriptor) \
|
||||
__macro(miopenDestroyActivationDescriptor) \
|
||||
__macro(miopenCreateFusionPlan) \
|
||||
__macro(miopenCreateOpConvForward) \
|
||||
__macro(miopenCreateOpBiasForward) \
|
||||
__macro(miopenCreateOpActivationForward) \
|
||||
__macro(miopenCreateOpActivationBackward) \
|
||||
__macro(miopenCreateOpBatchNormInference) \
|
||||
__macro(miopenCreateOpBatchNormForward) \
|
||||
__macro(miopenCreateOpBatchNormBackward) \
|
||||
__macro(miopenCompileFusionPlan) \
|
||||
__macro(miopenFusionPlanGetOp) \
|
||||
__macro(miopenCreateOperatorArgs) \
|
||||
__macro(miopenSetOpArgsConvForward) \
|
||||
__macro(miopenSetOpArgsBiasForward) \
|
||||
__macro(miopenSetOpArgsActivForward) \
|
||||
__macro(miopenSetOpArgsActivBackward) \
|
||||
__macro(miopenSetOpArgsBatchNormInference) \
|
||||
__macro(miopenSetOpArgsBatchNormForward) \
|
||||
__macro(miopenSetOpArgsBatchNormBackward) \
|
||||
__macro(miopenExecuteFusionPlan) \
|
||||
__macro(miopenDestroyOperatorArgs) \
|
||||
__macro(miopenDestroyFusionPlan) \
|
||||
__macro(miopenConvolutionForwardGetSolutionCount) \
|
||||
__macro(miopenConvolutionForwardGetSolution) \
|
||||
__macro(miopenConvolutionForwardGetSolutionWorkspaceSize) \
|
||||
__macro(miopenConvolutionForwardCompileSolution) \
|
||||
__macro(miopenConvolutionForwardImmediate) \
|
||||
__macro(miopenConvolutionBackwardDataGetSolutionCount) \
|
||||
__macro(miopenConvolutionBackwardDataGetSolution) \
|
||||
__macro(miopenConvolutionBackwardDataGetSolutionWorkspaceSize) \
|
||||
__macro(miopenConvolutionBackwardDataCompileSolution) \
|
||||
__macro(miopenConvolutionBackwardDataImmediate) \
|
||||
__macro(miopenConvolutionBackwardWeightsGetSolutionCount) \
|
||||
__macro(miopenConvolutionBackwardWeightsGetSolution) \
|
||||
__macro(miopenConvolutionBackwardWeightsGetSolutionWorkspaceSize) \
|
||||
__macro(miopenConvolutionBackwardWeightsCompileSolution) \
|
||||
__macro(miopenConvolutionBackwardWeightsImmediate)
|
||||
|
||||
// clang-format on
|
||||
|
||||
@ -389,6 +425,15 @@ absl::Mutex CachedFusionPlans::cached_plans_mutex;
std::map<uint64, miopenFusionPlanDescriptor_t> CachedFusionPlans::cached_plans;
std::set<uint64> CachedFusionPlans::unsupported_plans;

dnn::ProfileResult GetProfileResultFromConvSolution(
miopenConvSolution_t solution) {
dnn::ProfileResult profile_result;
profile_result.set_algorithm({solution.solution_id, false});
profile_result.set_elapsed_time_in_ms(solution.time);
profile_result.set_scratch_size(solution.workspace_size);
return profile_result;
}

} // namespace

namespace {
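The helper above is where MIOpen's immediate-mode metadata enters the StreamExecutor world: the solution id is stored as the AlgorithmDesc's algo_id (with tensor ops disabled) and later handed back to the miopenConvolution*Immediate calls. A self-contained sketch of that round-trip, using hypothetical ConvSolution/ProfileEntry stand-ins rather than the real MIOpen/StreamExecutor types:

#include <cassert>
#include <cstdint>

// Hypothetical mirror of miopenConvSolution_t: what MIOpen reports per solution.
struct ConvSolution {
  uint64_t solution_id;
  float time;             // reported run time in ms
  size_t workspace_size;  // required scratch bytes
};

// Hypothetical mirror of dnn::ProfileResult / dnn::AlgorithmDesc fields.
struct ProfileEntry {
  uint64_t algo_id;         // carries the MIOpen solution id
  bool tensor_ops_enabled;  // always false for MIOpen solutions
  float elapsed_time_in_ms;
  size_t scratch_size;
};

// Models GetProfileResultFromConvSolution: pack the solution metadata.
ProfileEntry FromSolution(const ConvSolution& s) {
  return {s.solution_id, /*tensor_ops_enabled=*/false, s.time, s.workspace_size};
}

// Models the later immediate-mode dispatch: the algo_id chosen by the
// autotuner is reinterpreted as the solution id to execute.
uint64_t SolutionIdForExecution(const ProfileEntry& chosen) {
  return chosen.algo_id;
}

int main() {
  ConvSolution s{42, 0.37f, 1 << 20};
  ProfileEntry e = FromSolution(s);
  assert(SolutionIdForExecution(e) == 42);  // round-trip preserved
}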
@ -2617,126 +2662,73 @@ port::Status MIOpenSupport::DoPrepareForConvolution(
|
||||
|
||||
auto miopen = miopen_->GetHandle(parent_, stream);
|
||||
|
||||
absl::optional<dnn::AlgorithmDesc> algo_desc = algorithm_config.algorithm();
|
||||
size_t scratch_memory_size;
|
||||
absl::optional<dnn::AlgorithmDesc> input_algo_desc =
|
||||
algorithm_config.algorithm();
|
||||
|
||||
if (!algo_desc.has_value()) {
|
||||
// With the default algorithm, use MIOpen's heuristics.
|
||||
assert(scratch_allocator);
|
||||
assert(input_algo_desc.has_value());
|
||||
|
||||
DeviceMemory<uint8> scratch_memory_temp;
|
||||
MIOpenAllocatorContext mac(scratch_allocator, stream);
|
||||
wrap::miopenSetAllocator(miopen.handle(), MIOpenAllocatorCallback,
|
||||
MIOpenDeallocatorCallback, &mac);
|
||||
size_t size_in_bytes;
|
||||
miopenStatus_t status = miopenStatusSuccess;
|
||||
// An algorithm has been specified.
|
||||
*algorithm_desc = *input_algo_desc;
|
||||
|
||||
switch (kind) {
|
||||
case dnn::ConvolutionKind::FORWARD: {
|
||||
status = wrap::miopenConvolutionForwardGetWorkSpaceSize(
|
||||
miopen.handle(), /*filterDesc=*/filter.handle(),
|
||||
/*srcDesc=*/input_nd.handle(), /*convDesc=*/conv.handle(),
|
||||
/*destDesc=*/output_nd.handle(), /*sizeInBytes=*/&size_in_bytes);
|
||||
break;
|
||||
const uint64_t solution_id = algorithm_desc->algo_id();
|
||||
|
||||
size_t scratch_memory_size = 0;
|
||||
|
||||
switch (kind) {
|
||||
case dnn::ConvolutionKind::FORWARD: {
|
||||
auto status = wrap::miopenConvolutionForwardGetSolutionWorkspaceSize(
|
||||
miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(),
|
||||
output_nd.handle(), solution_id, &scratch_memory_size);
|
||||
|
||||
if (status != miopenStatusSuccess) {
|
||||
return port::InternalError(absl::StrCat(
|
||||
"call to miopenConvolutionForwardGetSolutionWorkspaceSize "
|
||||
"failed: ",
|
||||
ToString(status)));
|
||||
}
|
||||
case dnn::ConvolutionKind::BACKWARD_DATA: {
|
||||
status = wrap::miopenConvolutionBackwardDataGetWorkSpaceSize(
|
||||
miopen.handle(), /*diffDesc=*/output_nd.handle(),
|
||||
/*filterDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
|
||||
/*gradDesc=*/input_nd.handle(), /*sizeInBytes=*/&size_in_bytes);
|
||||
break;
|
||||
}
|
||||
case dnn::ConvolutionKind::BACKWARD_FILTER: {
|
||||
status = wrap::miopenConvolutionBackwardWeightsGetWorkSpaceSize(
|
||||
miopen.handle(), /*diffDesc=*/output_nd.handle(),
|
||||
/*srcDesc=*/input_nd.handle(), /*convDesc=*/conv.handle(),
|
||||
/*gradDesc=*/filter.handle(), /*sizeInBytes=*/&size_in_bytes);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return port::InternalError(absl::StrCat("Unexpected convolution kind ",
|
||||
static_cast<int>(kind)));
|
||||
break;
|
||||
}
|
||||
|
||||
if (status == miopenStatusSuccess && size_in_bytes != 0) {
|
||||
auto allocated = scratch_allocator->AllocateBytes(size_in_bytes);
|
||||
if (allocated.ok()) {
|
||||
scratch_memory_temp = allocated.ValueOrDie();
|
||||
case dnn::ConvolutionKind::BACKWARD_DATA: {
|
||||
auto status = wrap::miopenConvolutionBackwardDataGetSolutionWorkspaceSize(
|
||||
miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
|
||||
input_nd.handle(), solution_id, &scratch_memory_size);
|
||||
|
||||
if (status != miopenStatusSuccess) {
|
||||
return port::InternalError(absl::StrCat(
|
||||
"call to miopenConvolutionabckwardDataGetSolutionWorkspaceSize "
|
||||
"failed: ",
|
||||
ToString(status)));
|
||||
}
|
||||
break;
|
||||
}
|
||||
|
||||
miopenConvAlgoPerf_t preference;
|
||||
int returnedAlgoCount;
|
||||
case dnn::ConvolutionKind::BACKWARD_FILTER: {
|
||||
auto status =
|
||||
wrap::miopenConvolutionBackwardWeightsGetSolutionWorkspaceSize(
|
||||
miopen.handle(), output_nd.handle(), input_nd.handle(),
|
||||
conv.handle(), filter.handle(), solution_id,
|
||||
&scratch_memory_size);
|
||||
|
||||
switch (kind) {
|
||||
case dnn::ConvolutionKind::FORWARD: {
|
||||
auto status = wrap::miopenFindConvolutionForwardAlgorithm(
|
||||
miopen.handle(), input_nd.handle(), input_data.opaque(),
|
||||
filter.handle(), filter_data.opaque(), conv.handle(),
|
||||
output_nd.handle(), output_data.opaque(),
|
||||
/*requestAlgoCount=*/1, &returnedAlgoCount,
|
||||
/*preference=*/&preference,
|
||||
/*workspace*/ scratch_memory_temp.opaque(),
|
||||
/*WorkSpaceSize*/ scratch_memory_temp.size(),
|
||||
/*exhaustiveSearch*/ false);
|
||||
CHECK_EQ(status, miopenStatusSuccess) << "Unable to find a suitable "
|
||||
"algorithm for doing forward "
|
||||
"convolution";
|
||||
*algorithm_desc = dnn::AlgorithmDesc(preference.fwd_algo, false);
|
||||
break;
|
||||
if (status != miopenStatusSuccess) {
|
||||
return port::InternalError(absl::StrCat(
|
||||
"call to miopenConvolutionabckwardWeightsGetSolutionWorkspaceSize "
|
||||
"failed: ",
|
||||
ToString(status)));
|
||||
}
|
||||
case dnn::ConvolutionKind::BACKWARD_DATA: {
|
||||
auto status = wrap::miopenFindConvolutionBackwardDataAlgorithm(
|
||||
miopen.handle(),
|
||||
/*diffDesc=*/output_nd.handle(), output_data.opaque(),
|
||||
/*filterDesc=*/filter.handle(), filter_data.opaque(),
|
||||
/*convDesc=*/conv.handle(),
|
||||
/*gradDesc=*/input_nd.handle(), input_data.opaque(),
|
||||
/*requestCount=*/1, /*returnedAlgoCount=*/&returnedAlgoCount,
|
||||
/*preference=*/&preference,
|
||||
/*WorkSpace=*/scratch_memory_temp.opaque(),
|
||||
/*WorkSpaceSize=*/scratch_memory_temp.size(),
|
||||
/*exhaustiveSearch=*/false);
|
||||
CHECK_EQ(status, miopenStatusSuccess) << "Unable to find a suitable "
|
||||
"algorithm for doing backward "
|
||||
"data convolution";
|
||||
*algorithm_desc = dnn::AlgorithmDesc(preference.bwd_data_algo, false);
|
||||
break;
|
||||
}
|
||||
case dnn::ConvolutionKind::BACKWARD_FILTER: {
|
||||
auto status = wrap::miopenFindConvolutionBackwardWeightsAlgorithm(
|
||||
miopen.handle(),
|
||||
/*diffDesc=*/output_nd.handle(), output_data.opaque(),
|
||||
/*srcDesc=*/input_nd.handle(), input_data.opaque(),
|
||||
/*convDesc=*/conv.handle(),
|
||||
/*gradDesc=*/filter.handle(), filter_data.opaque(),
|
||||
/*requestAlgoCount=*/1, /*returnedAlgoCount=*/&returnedAlgoCount,
|
||||
/*preference=*/&preference,
|
||||
/*WorkSpace=*/scratch_memory_temp.opaque(),
|
||||
/*WorkSpaceSize=*/scratch_memory_temp.size(),
|
||||
/*exhaustiveSearch=*/false);
|
||||
CHECK_EQ(status, miopenStatusSuccess) << "Unable to find a suitable "
|
||||
"algorithm for doing backward "
|
||||
"filter convolution";
|
||||
*algorithm_desc =
|
||||
dnn::AlgorithmDesc(preference.bwd_weights_algo, false);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return port::InternalError(absl::StrCat("Unexpected convolution kind ",
|
||||
static_cast<int>(kind)));
|
||||
break;
|
||||
}
|
||||
|
||||
// Restore default allocator, note mac is stack temp
|
||||
wrap::miopenSetAllocator(miopen.handle(), nullptr, nullptr, nullptr);
|
||||
|
||||
scratch_memory_size = preference.memory;
|
||||
} else {
|
||||
// An algorithm has been specified.
|
||||
*algorithm_desc = *algo_desc;
|
||||
scratch_memory_size = *(algorithm_config.scratch_size());
|
||||
default: {
|
||||
return port::InternalError(
|
||||
absl::StrCat("Unexpected convolution kind ", static_cast<int>(kind)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
VLOG(2) << "miopen...GetSolutionWorkspaceSize returned "
|
||||
<< scratch_memory_size << " for solution_id " << solution_id;
|
||||
|
||||
// allocate scratch memory
|
||||
if (scratch_memory_size != 0) {
|
||||
if (scratch_allocator == nullptr) {
|
||||
@ -2745,12 +2737,18 @@ port::Status MIOpenSupport::DoPrepareForConvolution(
|
||||
"needed"));
|
||||
}
|
||||
auto allocated = scratch_allocator->AllocateBytes(scratch_memory_size);
|
||||
if (!allocated.ok()) {
|
||||
return port::InternalError(absl::StrCat(
|
||||
"Failed to allocate scratch memory of size: ", scratch_memory_size));
|
||||
}
|
||||
if (allocated.ok()) {
|
||||
*scratch_memory = allocated.ValueOrDie();
|
||||
} else {
|
||||
LOG(ERROR)
|
||||
<< "Failed to allocate scratch memory - "
|
||||
<< allocated.status().error_message() << "\n"
|
||||
<< "\tYou can set the env var TF_CUDNN_WORKSPACE_LIMIT_IN_MB to a "
|
||||
"larger number (e.g. 8192) to increase the max memory limit.\n"
|
||||
<< "\tIncreasing the max memory limit might help resolve this "
|
||||
"error";
|
||||
return port::InternalError(absl::StrCat(
|
||||
"Failed to allocate scratch memory of size: ", scratch_memory_size));
|
||||
}
|
||||
}
|
||||
|
||||
@ -2846,20 +2844,17 @@ port::Status MIOpenSupport::DoConvolve(
|
||||
}
|
||||
}
|
||||
|
||||
const uint64_t solution_id = algorithm_desc.algo_id();
|
||||
|
||||
miopenStatus_t status = miopenStatusSuccess;
|
||||
switch (kind) {
|
||||
case dnn::ConvolutionKind::FORWARD: {
|
||||
status = wrap::miopenConvolutionForward(
|
||||
miopen.handle(),
|
||||
/*alpha=*/&alpha, /*srcDesc=*/input_nd.handle(),
|
||||
/*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(),
|
||||
/*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
|
||||
/*algo=*/
|
||||
static_cast<miopenConvFwdAlgorithm_t>(algorithm_desc.algo_id()),
|
||||
/*beta=*/&beta, /*destDesc=*/output_nd.handle(),
|
||||
/*destData=*/output_data.opaque(),
|
||||
/*workSpace=*/scratch_memory.opaque(),
|
||||
/*workSpaceSizeInBytes=*/scratch_memory.size());
|
||||
status = wrap::miopenConvolutionForwardImmediate(
|
||||
miopen.handle(), filter.handle(), filter_data.opaque(),
|
||||
input_nd.handle(), input_data.opaque(), conv.handle(),
          output_nd.handle(), output_data.opaque(), scratch_memory.opaque(),
          scratch_memory.size(), solution_id);

      break;
    }
    case dnn::ConvolutionKind::BACKWARD_DATA: {
@ -2871,21 +2866,11 @@ port::Status MIOpenSupport::DoConvolve(
          stream, miopen.handle(), ToMIOpenDataType(element_type),
          &output_back_descriptor, output_data, &transform_scratch);

      status = wrap::miopenConvolutionBackwardData(
          miopen.handle(),
          /*alpha=*/&alpha,
          /*diffDesc=*/output_nd.handle(),
          /*diffData=*/output_data.opaque(),
          /*filterDesc=*/filter.handle(),
          /*filterData=*/filter_data.opaque(),
          /*convDesc=*/conv.handle(),
          /*algo=*/
          static_cast<miopenConvBwdDataAlgorithm_t>(algorithm_desc.algo_id()),
          /*beta=*/&beta,
          /*gradDesc=*/input_nd.handle(),
          /*gradData=*/input_data.opaque(),
          /*workSpace=*/scratch_memory.opaque(),
          /*workSpaceSizeInBytes=*/scratch_memory.size());
      status = wrap::miopenConvolutionBackwardDataImmediate(
          miopen.handle(), output_nd.handle(), output_data.opaque(),
          filter.handle(), filter_data.opaque(), conv.handle(),
          input_nd.handle(), input_data.opaque(), scratch_memory.opaque(),
          scratch_memory.size(), solution_id);
      break;
    }
    case dnn::ConvolutionKind::BACKWARD_FILTER: {
@ -2897,22 +2882,11 @@ port::Status MIOpenSupport::DoConvolve(
          stream, miopen.handle(), ToMIOpenDataType(element_type),
          &output_back_descriptor, output_data, &transform_scratch);

      status = wrap::miopenConvolutionBackwardWeights(
          miopen.handle(),
          /*alpha=*/&alpha,
          /*diffDesc=*/output_nd.handle(),
          /*diffData=*/output_data.opaque(),
          /*srcDesc=*/input_nd.handle(),
          /*srcData=*/input_data.opaque(),
          /*convDesc=*/conv.handle(),
          /*algo=*/
          static_cast<miopenConvBwdWeightsAlgorithm_t>(
              algorithm_desc.algo_id()),
          /*beta=*/&beta,
          /*gradDesc=*/filter.handle(),
          /*gradData=*/filter_data.opaque(),
          /*workSpace=*/scratch_memory.opaque(),
          /*workSpaceSizeInBytes=*/scratch_memory.size());
      status = wrap::miopenConvolutionBackwardWeightsImmediate(
          miopen.handle(), output_nd.handle(), output_data.opaque(),
          input_nd.handle(), input_data.opaque(), conv.handle(),
          filter.handle(), filter_data.opaque(), scratch_memory.opaque(),
          scratch_memory.size(), solution_id);
      break;
    }
    default:
@ -2958,6 +2932,312 @@ bool MIOpenSupport::GetConvolveAlgorithms(
  return true;
}

bool MIOpenSupport::GetMIOpenConvolveAlgorithms(
    dnn::ConvolutionKind kind, Stream* stream, dnn::DataType element_type,
    const dnn::BatchDescriptor& input_descriptor,
    const dnn::FilterDescriptor& filter_descriptor,
    const dnn::ConvolutionDescriptor& convolution_descriptor,
    const dnn::BatchDescriptor& output_descriptor,
    std::vector<dnn::ProfileResult>* out_algorithms) {
  auto miopen = miopen_->GetHandle(parent_, stream);

  ScopedTensorDescriptor input_nd{input_descriptor,
                                  ToMIOpenDataType(element_type)};
  ScopedTensorDescriptor output_nd{output_descriptor,
                                   ToMIOpenDataType(element_type)};
  ScopedFilterDescriptor filter{filter_descriptor, input_descriptor,
                                ToMIOpenDataType(element_type)};
  ScopedConvolutionDescriptor conv{convolution_descriptor,
                                   ToMIOpenDataType(element_type)};

  // First determine the number of algorithms available
  size_t maxSolutionCount = 0;

  switch (kind) {
    case dnn::ConvolutionKind::FORWARD: {
      auto status = wrap::miopenConvolutionForwardGetSolutionCount(
          miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(),
          output_nd.handle(), &maxSolutionCount);
      if (status != miopenStatusSuccess) {
        LOG(FATAL)
            << "call to miopenConvolutionForwardGetSolutionCount failed: "
            << ToString(status);
        return false;
      }
      break;
    }
    case dnn::ConvolutionKind::BACKWARD_DATA: {
      auto status = wrap::miopenConvolutionBackwardDataGetSolutionCount(
          miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
          input_nd.handle(), &maxSolutionCount);
      if (status != miopenStatusSuccess) {
        LOG(FATAL)
            << "call to miopenConvolutionBackwardDataGetSolutionCount failed: "
            << ToString(status);
        return false;
      }
      break;
    }
    case dnn::ConvolutionKind::BACKWARD_FILTER: {
      auto status = wrap::miopenConvolutionBackwardWeightsGetSolutionCount(
          miopen.handle(), output_nd.handle(), input_nd.handle(), conv.handle(),
          filter.handle(), &maxSolutionCount);
      if (status != miopenStatusSuccess) {
        LOG(FATAL)
            << "call to miopenConvolutionBackwardWeightsGetSolutionCount "
               "failed: "
            << ToString(status);
        return false;
      }
      break;
    }
    default: {
      LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
      return false;
      break;
    }
  }

  VLOG(kImmediateModeVlogLevel)
      << "Number of conv solutions max: " << maxSolutionCount;

  // if the env var TF_ROCM_MIMIC_FIND_MODE is set, determine the best solution
  // as per the "runtime" information for each solution (returned by the prior
  // call to the *GetSolution api), and then return only the best solution
  // The idea here is to mimic the old "find" mode, in which we relied upon
  // the miopen api to determine the best solution, and use that solution
  // without doing any further measurement in the TF layer
  bool mimic_find_mode = false;
  tensorflow::ReadBoolFromEnvVar("TF_ROCM_MIMIC_FIND_MODE", false,
                                 &mimic_find_mode);

  size_t solutionCount = 0;
  std::unique_ptr<miopenConvSolution_t[]> solutions(
      new miopenConvSolution_t[maxSolutionCount]);

  switch (kind) {
    case dnn::ConvolutionKind::FORWARD: {
      auto status = wrap::miopenConvolutionForwardGetSolution(
          miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(),
          output_nd.handle(), maxSolutionCount, &solutionCount,
          solutions.get());

      if (status != miopenStatusSuccess) {
        LOG(FATAL) << "call to miopenConvolutionForwardGetSolution failed: "
                   << ToString(status);
        return false;
      }

      VLOG(kImmediateModeVlogLevel)
          << "Number of conv solutions actual: " << solutionCount;

      if (mimic_find_mode) {
        miopenConvSolution_t best_solution = solutions[0];

        for (int i = 1; i < solutionCount; i++) {
          miopenConvSolution_t solution = solutions[i];
          if (solution.time < best_solution.time) {
            best_solution = solution;
          }
        }

        VLOG(kImmediateModeVlogLevel)
            << "Best Solution (id, algo) = " << best_solution.solution_id
            << ", " << ToString(best_solution.algorithm);

        status = wrap::miopenConvolutionForwardCompileSolution(
            miopen.handle(), filter.handle(), input_nd.handle(), conv.handle(),
            output_nd.handle(), best_solution.solution_id);

        if (status != miopenStatusSuccess) {
          LOG(FATAL) << "call to miopenConvolutionForwardCompileSolution "
                        "failed: "
                     << ToString(status);
          return false;
        }

        out_algorithms->emplace_back(
            GetProfileResultFromConvSolution(best_solution));

      } else {
        for (int i = 0; i < solutionCount; i++) {
          miopenConvSolution_t solution = solutions[i];

          VLOG(kImmediateModeVlogLevel)
              << "solution " << i
              << " (time, mem, id, algo) = " << solution.time << ", "
              << solution.workspace_size << ", " << solution.solution_id << ", "
              << ToString(solution.algorithm);

          status = wrap::miopenConvolutionForwardCompileSolution(
              miopen.handle(), filter.handle(), input_nd.handle(),
              conv.handle(), output_nd.handle(), solution.solution_id);

          if (status != miopenStatusSuccess) {
            LOG(FATAL)
                << "call to miopenConvolutionForwardCompileSolution failed: "
                << ToString(status);
            return false;
          }

          out_algorithms->emplace_back(
              GetProfileResultFromConvSolution(solution));
        }
      }
      break;
    }

    case dnn::ConvolutionKind::BACKWARD_DATA: {
      auto status = wrap::miopenConvolutionBackwardDataGetSolution(
          miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
          input_nd.handle(), maxSolutionCount, &solutionCount, solutions.get());
      if (status != miopenStatusSuccess) {
        LOG(FATAL)
            << "call to miopenConvolutionBackwardDataGetSolution failed: "
            << ToString(status);
        return false;
      }

      VLOG(kImmediateModeVlogLevel)
          << "Number of conv solutions actual: " << solutionCount;

      if (mimic_find_mode) {
        miopenConvSolution_t best_solution = solutions[0];

        for (int i = 1; i < solutionCount; i++) {
          miopenConvSolution_t solution = solutions[i];
          if (solution.time < best_solution.time) {
            best_solution = solution;
          }
        }

        VLOG(kImmediateModeVlogLevel)
            << "Best Solution (id, algo) = " << best_solution.solution_id
            << ", " << ToString(best_solution.algorithm);

        status = wrap::miopenConvolutionBackwardDataCompileSolution(
            miopen.handle(), output_nd.handle(), filter.handle(), conv.handle(),
            input_nd.handle(), best_solution.solution_id);

        if (status != miopenStatusSuccess) {
          LOG(FATAL) << "call to miopenConvolutionBackwardDataCompileSolution "
                        "failed: "
                     << ToString(status);
          return false;
        }

        out_algorithms->emplace_back(
            GetProfileResultFromConvSolution(best_solution));

      } else {
        for (int i = 0; i < solutionCount; i++) {
          miopenConvSolution_t solution = solutions[i];

          VLOG(kImmediateModeVlogLevel)
              << "solution " << i
              << " (time, mem, id, algo) = " << solution.time << ", "
              << solution.workspace_size << ", " << solution.solution_id << ", "
              << ToString(solution.algorithm);

          status = wrap::miopenConvolutionBackwardDataCompileSolution(
              miopen.handle(), output_nd.handle(), filter.handle(),
              conv.handle(), input_nd.handle(), solution.solution_id);

          if (status != miopenStatusSuccess) {
            LOG(FATAL)
                << "call to miopenConvolutionBackwardDataCompileSolution "
                   "failed: "
                << ToString(status);
            return false;
          }

          out_algorithms->emplace_back(
              GetProfileResultFromConvSolution(solution));
        }
      }
      break;
    }
    case dnn::ConvolutionKind::BACKWARD_FILTER: {
      auto status = wrap::miopenConvolutionBackwardWeightsGetSolution(
          miopen.handle(), output_nd.handle(), input_nd.handle(), conv.handle(),
          filter.handle(), maxSolutionCount, &solutionCount, solutions.get());
      if (status != miopenStatusSuccess) {
        LOG(FATAL)
            << "call to miopenConvolutionBackwardWeightsGetSolution failed: "
            << ToString(status);
        return false;
      }

      VLOG(kImmediateModeVlogLevel)
          << "Number of conv solutions actual: " << solutionCount;

      if (mimic_find_mode) {
        miopenConvSolution_t best_solution = solutions[0];

        for (int i = 1; i < solutionCount; i++) {
          miopenConvSolution_t solution = solutions[i];
          if (solution.time < best_solution.time) {
            best_solution = solution;
          }
        }

        VLOG(kImmediateModeVlogLevel)
            << "Best Solution (id, algo) = " << best_solution.solution_id
            << ", " << ToString(best_solution.algorithm);

        status = wrap::miopenConvolutionBackwardWeightsCompileSolution(
            miopen.handle(), output_nd.handle(), input_nd.handle(),
            conv.handle(), filter.handle(), best_solution.solution_id);

        if (status != miopenStatusSuccess) {
          LOG(FATAL)
              << "call to miopenConvolutionBackwardWeightsCompileSolution "
                 "failed: "
              << ToString(status);
          return false;
        }

        out_algorithms->emplace_back(
            GetProfileResultFromConvSolution(best_solution));

      } else {
        for (int i = 0; i < solutionCount; i++) {
          miopenConvSolution_t solution = solutions[i];

          VLOG(kImmediateModeVlogLevel)
              << "solution " << i
              << " (time, mem, id, algo) = " << solution.time << ", "
              << solution.workspace_size << ", " << solution.solution_id << ", "
              << ToString(solution.algorithm);

          status = wrap::miopenConvolutionBackwardWeightsCompileSolution(
              miopen.handle(), output_nd.handle(), input_nd.handle(),
              conv.handle(), filter.handle(), solution.solution_id);

          if (status != miopenStatusSuccess) {
            LOG(FATAL)
                << "call to miopenConvolutionBackwardWeightsCompileSolution "
                   "failed: "
                << ToString(status);
            return false;
          }

          out_algorithms->emplace_back(
              GetProfileResultFromConvSolution(solution));
        }
      }
      break;
    }
    default: {
      LOG(FATAL) << "Unexpected convolution kind " << static_cast<int>(kind);
      return false;
      break;
    }
  }

  return true;
}
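Editor's note: the mimic-find-mode branch above repeats the same selection step in every case of the switch: scan the solutions returned by the corresponding *GetSolution call, keep the one with the smallest reported time, compile it, and return only that one. Below is a minimal sketch of that selection step; it assumes the MIOpen headers are available and that the candidates have been copied into a std::vector<miopenConvSolution_t>, and the helper name PickFastestSolution is illustrative only, not part of this change.

#include <algorithm>
#include <vector>

// Illustrative only: returns the candidate with the smallest reported
// runtime, which is what the per-kind loops above compute with an explicit
// scan over the raw solutions array.
miopenConvSolution_t PickFastestSolution(
    const std::vector<miopenConvSolution_t>& solutions) {
  return *std::min_element(
      solutions.begin(), solutions.end(),
      [](const miopenConvSolution_t& a, const miopenConvSolution_t& b) {
        return a.time < b.time;
      });
}

When TF_ROCM_MIMIC_FIND_MODE is unset, the else branch instead compiles and returns every solution, leaving further measurement to the TF layer.
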
bool MIOpenSupport::GetRnnAlgorithms(
    std::vector<dnn::AlgorithmDesc>* out_algorithms) {
  // ROCM TODO: implement this with proper MIOpen API

@ -195,6 +195,14 @@ class MIOpenSupport : public dnn::DnnSupport {
      bool with_winograd_nonfused, int cc_major, int cc_minor,
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

  bool GetMIOpenConvolveAlgorithms(
      dnn::ConvolutionKind kind, Stream* stream, dnn::DataType element_type,
      const dnn::BatchDescriptor& input_descriptor,
      const dnn::FilterDescriptor& filter_descriptor,
      const dnn::ConvolutionDescriptor& convolution_descriptor,
      const dnn::BatchDescriptor& output_descriptor,
      std::vector<dnn::ProfileResult>* out_algorithms) override;

  bool GetRnnAlgorithms(
      std::vector<dnn::AlgorithmDesc>* out_algorithms) override;

@ -290,6 +290,22 @@ bool StreamExecutor::GetConvolveAlgorithms(
                                        cc_minor, out_algorithms);
}

bool StreamExecutor::GetMIOpenConvolveAlgorithms(
    dnn::ConvolutionKind kind, Stream *stream, dnn::DataType element_type,
    const dnn::BatchDescriptor &input_descriptor,
    const dnn::FilterDescriptor &filter_descriptor,
    const dnn::ConvolutionDescriptor &convolution_descriptor,
    const dnn::BatchDescriptor &output_descriptor,
    std::vector<dnn::ProfileResult> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();
  if (!dnn_support) {
    return false;
  }
  return dnn_support->GetMIOpenConvolveAlgorithms(
      kind, stream, element_type, input_descriptor, filter_descriptor,
      convolution_descriptor, output_descriptor, out_algorithms);
}
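Editor's note: the new StreamExecutor entry point above could be exercised from client code roughly as follows. This is a minimal caller-side sketch, not part of the change: stream_exec, stream, and the four descriptor objects (input_desc, filter_desc, conv_desc, output_desc) are assumed to be already set up, and the convolution kind and element type are chosen arbitrarily for illustration.

std::vector<se::dnn::ProfileResult> results;
if (stream_exec->GetMIOpenConvolveAlgorithms(
        se::dnn::ConvolutionKind::FORWARD, stream, se::dnn::DataType::kFloat,
        input_desc, filter_desc, conv_desc, output_desc, &results)) {
  for (const se::dnn::ProfileResult& r : results) {
    // Each ProfileResult carries the solution id, its reported runtime, and
    // the workspace it requires.
    VLOG(2) << "algo " << r.algorithm().algo_id()
            << " time_ms=" << r.elapsed_time_in_ms()
            << " scratch_bytes=" << r.scratch_size();
  }
}

The call returns false when no DNN support is registered for the executor (AsDnn() returns null), so callers can fall back to other tuning paths in that case.
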
bool StreamExecutor::GetRnnAlgorithms(
    std::vector<dnn::AlgorithmDesc> *out_algorithms) {
  dnn::DnnSupport *dnn_support = AsDnn();

@ -372,6 +372,16 @@ class StreamExecutor {
  bool GetConvolveAlgorithms(bool with_winograd_nonfused,
                             std::vector<dnn::AlgorithmDesc> *out_algorithms);

  // Returns the list of supported algorithms for the specified convolution
  // operation (forward, backward data, or backward filter).
  bool GetMIOpenConvolveAlgorithms(
      dnn::ConvolutionKind kind, Stream *stream, dnn::DataType element_type,
      const dnn::BatchDescriptor &input_descriptor,
      const dnn::FilterDescriptor &filter_descriptor,
      const dnn::ConvolutionDescriptor &convolution_descriptor,
      const dnn::BatchDescriptor &output_descriptor,
      std::vector<dnn::ProfileResult> *out_algorithms);

  // Returns the list of supported algorithms for rnn operation.
  bool GetRnnAlgorithms(std::vector<dnn::AlgorithmDesc> *out_algorithms);