Roll-forward:
Log convolutions during Tensorflow GPU conv autotuning. Also removed the same functionality from StreamExecutor. We decided to move the loggings from SE to TF and XLA for several reasons: * Proto formats already exist in TF and XLA that are suitable for logging. No need to create a third proto. * In TF and XLA autotuning stage, we also do/plan to do correctness checking. We want to log the checking results. * We are considering simplifying SE, so we prefer to keep SE simple for now. The original patch fails on Windows because the Windows linker crashes if it links an object file generated from an empty source file. In this CL, such empty source file is gpu_utils.cc, in the case where everything is #ifdef'ed out by GOOGLE_CUDA. To work-around it, simply don't compile such empty file at all for non-CUDA builds. PiperOrigin-RevId: 237284443
This commit is contained in:
parent
4d1c0202aa
commit
94be8f012a
@ -68,6 +68,8 @@ tf_kernel_library(
|
||||
prefix = "fused_conv2d_bias_activation_op",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"//tensorflow/core:autotuning_proto_cc",
|
||||
"//tensorflow/core:conv_autotuning_proto_cc",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
@ -92,6 +94,8 @@ tf_custom_op_library(
|
||||
"ops/fused_conv2d_bias_activation_op.cc",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/core:autotuning_proto_cc",
|
||||
"//tensorflow/core:conv_autotuning_proto_cc",
|
||||
"//tensorflow/core:lib_proto_parsing",
|
||||
"//tensorflow/core/kernels:bounds_check_lib",
|
||||
"//tensorflow/core/kernels:conv_2d_hdrs",
|
||||
|
@ -34,9 +34,16 @@ limitations under the License.
|
||||
#include "tensorflow/core/util/use_cudnn.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "google/protobuf/duration.pb.h"
|
||||
#include "absl/time/time.h"
|
||||
#include "cuda/include/cudnn.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/kernels/conv_ops_gpu.h"
|
||||
#include "tensorflow/core/platform/logger.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#include "tensorflow/core/protobuf/autotuning.pb.h"
|
||||
#include "tensorflow/core/protobuf/conv_autotuning.pb.h"
|
||||
#include "tensorflow/core/util/activation_mode.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
@ -252,6 +259,131 @@ class FusedConv2DBiasActivationOp : public OpKernel {
|
||||
#if GOOGLE_CUDA
|
||||
namespace dnn = se::dnn;
|
||||
|
||||
// Several functions are copyed over from tensorflow/core/kernels/gpu_utils,
|
||||
// since this file may be compiled down to a tf_custom_op_library .so file,
|
||||
// which can't depend on basic dependencies like tensorflow/core:lib. Instead,
|
||||
// the code has to depend on whatever is the same in libtensorflow_framework.so.
|
||||
//
|
||||
// In theory, we can lift the dependencies of gpu_utils by turning it into a
|
||||
// template library that provides duck typing, but I think duplication is the
|
||||
// lesser of two evils.
|
||||
namespace internal {
|
||||
namespace {
|
||||
|
||||
tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
|
||||
tensorflow::CudnnVersion cudnn_version;
|
||||
if (auto* dnn = stream_executor->AsDnn()) {
|
||||
se::port::StatusOr<se::dnn::VersionInfo> version_or = dnn->GetVersion();
|
||||
if (version_or.ok()) {
|
||||
const auto& version = version_or.ValueOrDie();
|
||||
cudnn_version.set_major(version.major_version());
|
||||
cudnn_version.set_minor(version.minor_version());
|
||||
cudnn_version.set_patch(version.patch());
|
||||
}
|
||||
}
|
||||
return cudnn_version;
|
||||
}
|
||||
|
||||
// Converts an absl::Duration to a google::protobuf::Duration.
|
||||
inline google::protobuf::Duration ToDurationProto(absl::Duration duration) {
|
||||
google::protobuf::Duration proto;
|
||||
proto.set_seconds(absl::IDivDuration(duration, absl::Seconds(1), &duration));
|
||||
proto.set_nanos(
|
||||
absl::IDivDuration(duration, absl::Nanoseconds(1), &duration));
|
||||
return proto;
|
||||
}
|
||||
|
||||
// Converts a google::protobuf::Duration to an absl::Duration.
|
||||
inline absl::Duration FromDurationProto(google::protobuf::Duration proto) {
|
||||
return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos());
|
||||
}
|
||||
|
||||
tensorflow::ComputeCapability GetComputeCapability(
|
||||
se::StreamExecutor* stream_executor) {
|
||||
tensorflow::ComputeCapability cc;
|
||||
int cc_major, cc_minor;
|
||||
stream_executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
|
||||
&cc_minor);
|
||||
cc.set_major(cc_major);
|
||||
cc.set_minor(cc_minor);
|
||||
return cc;
|
||||
}
|
||||
|
||||
void LogFusedConvAutotuneResults(const NodeDef& node, const Tensor& input,
|
||||
const Tensor& filter, const Tensor& output,
|
||||
const Tensor& bias, const Tensor* side_input,
|
||||
se::StreamExecutor* stream_exec,
|
||||
absl::Span<const AutotuneResult> results) {
|
||||
AutotuningLog log;
|
||||
ConvNodeDef instr;
|
||||
*instr.mutable_conv() = node;
|
||||
input.shape().AsProto(instr.mutable_input()->mutable_tensor_shape());
|
||||
instr.mutable_input()->set_dtype(input.dtype());
|
||||
filter.shape().AsProto(instr.mutable_filter()->mutable_tensor_shape());
|
||||
instr.mutable_filter()->set_dtype(filter.dtype());
|
||||
output.shape().AsProto(instr.mutable_output()->mutable_tensor_shape());
|
||||
instr.mutable_output()->set_dtype(output.dtype());
|
||||
bias.shape().AsProto(instr.mutable_bias()->mutable_tensor_shape());
|
||||
instr.mutable_bias()->set_dtype(bias.dtype());
|
||||
if (side_input) {
|
||||
side_input->shape().AsProto(
|
||||
instr.mutable_side_input()->mutable_tensor_shape());
|
||||
instr.mutable_side_input()->set_dtype(side_input->dtype());
|
||||
}
|
||||
log.mutable_instr()->PackFrom(std::move(instr));
|
||||
*log.mutable_cudnn_version() = internal::GetCudnnVersion(stream_exec);
|
||||
*log.mutable_compute_capability() =
|
||||
internal::GetComputeCapability(stream_exec);
|
||||
for (const auto& result : results) {
|
||||
*log.add_results() = result;
|
||||
}
|
||||
Logger::Singleton()->LogProto(log);
|
||||
}
|
||||
|
||||
Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
|
||||
se::dnn::AlgorithmConfig* algo) {
|
||||
// For the "!xhs.has_success()" below, this is because we want successful ones
|
||||
// to order first, therefore they need a smaller key per "min_element".
|
||||
const AutotuneResult* best_result = std::min_element(
|
||||
results.begin(), results.end(),
|
||||
[](const AutotuneResult& lhs, const AutotuneResult& rhs) {
|
||||
return std::make_tuple(
|
||||
!lhs.has_success(),
|
||||
internal::FromDurationProto(lhs.success().run_time())) <
|
||||
std::make_tuple(
|
||||
!rhs.has_success(),
|
||||
internal::FromDurationProto(rhs.success().run_time()));
|
||||
});
|
||||
|
||||
const AutotuneResult* best_result_no_scratch = std::min_element(
|
||||
results.begin(), results.end(),
|
||||
[](const AutotuneResult& lhs, const AutotuneResult& rhs) {
|
||||
return std::make_tuple(
|
||||
!lhs.has_success(), lhs.success().scratch_bytes(),
|
||||
internal::FromDurationProto(lhs.success().run_time())) <
|
||||
std::make_tuple(
|
||||
!rhs.has_success(), rhs.success().scratch_bytes(),
|
||||
internal::FromDurationProto(rhs.success().run_time()));
|
||||
});
|
||||
|
||||
if (best_result == results.end() || !best_result->has_success()) {
|
||||
return errors::NotFound("No algorithm worked!");
|
||||
}
|
||||
algo->set_algorithm({best_result->conv().algorithm(),
|
||||
best_result->conv().tensor_ops_enabled()});
|
||||
if (best_result_no_scratch != results.end() &&
|
||||
best_result_no_scratch->has_success() &&
|
||||
best_result_no_scratch->success().scratch_bytes() == 0) {
|
||||
algo->set_algorithm_no_scratch(
|
||||
{best_result_no_scratch->conv().algorithm(),
|
||||
best_result_no_scratch->conv().tensor_ops_enabled()});
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace internal
|
||||
|
||||
// A dummy type to group forward convolution autotune results together.
|
||||
struct ConvBiasActivationAutoTuneGroup {
|
||||
static string name() { return "ConvBiasActivation"; }
|
||||
@ -579,8 +711,7 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
|
||||
}),
|
||||
algorithms.end());
|
||||
}
|
||||
dnn::ProfileResult best_result;
|
||||
dnn::ProfileResult best_result_no_scratch;
|
||||
std::vector<tensorflow::AutotuneResult> results;
|
||||
for (auto profile_algorithm : algorithms) {
|
||||
// TODO(zhengxq): profile each algorithm multiple times to better
|
||||
// accuracy.
|
||||
@ -597,28 +728,24 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
|
||||
.ok();
|
||||
if (cudnn_launch_status) {
|
||||
if (profile_result.is_valid()) {
|
||||
if (profile_result.elapsed_time_in_ms() <
|
||||
best_result.elapsed_time_in_ms()) {
|
||||
best_result = profile_result;
|
||||
}
|
||||
if (scratch_allocator.TotalByteSize() == 0 &&
|
||||
profile_result.elapsed_time_in_ms() <
|
||||
best_result_no_scratch.elapsed_time_in_ms()) {
|
||||
best_result_no_scratch = profile_result;
|
||||
}
|
||||
results.emplace_back();
|
||||
auto& result = results.back();
|
||||
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
|
||||
result.mutable_conv()->set_tensor_ops_enabled(
|
||||
profile_algorithm.tensor_ops_enabled());
|
||||
result.mutable_success()->set_scratch_bytes(
|
||||
scratch_allocator.TotalByteSize());
|
||||
*result.mutable_success()->mutable_run_time() =
|
||||
internal::ToDurationProto(
|
||||
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
|
||||
}
|
||||
}
|
||||
}
|
||||
OP_REQUIRES(ctx,
|
||||
best_result.is_valid() || best_result_no_scratch.is_valid(),
|
||||
errors::NotFound("No algorithm worked!"));
|
||||
if (best_result.is_valid()) {
|
||||
algorithm_config.set_algorithm(best_result.algorithm());
|
||||
}
|
||||
if (best_result_no_scratch.is_valid()) {
|
||||
algorithm_config.set_algorithm_no_scratch(
|
||||
best_result_no_scratch.algorithm());
|
||||
}
|
||||
internal::LogFusedConvAutotuneResults(ctx->op_kernel().def(), *conv_input,
|
||||
*filter, *output, bias, side_input,
|
||||
stream->parent(), results);
|
||||
OP_REQUIRES_OK(
|
||||
ctx, internal::BestCudnnConvAlgorithm(results, &algorithm_config));
|
||||
AutoTuneConvBiasActivation::GetInstance()->Insert(fused_conv_parameters,
|
||||
algorithm_config);
|
||||
}
|
||||
|
@ -2178,6 +2178,18 @@ tf_proto_library(
|
||||
],
|
||||
)
|
||||
|
||||
tf_proto_library(
|
||||
name = "conv_autotuning_proto",
|
||||
srcs = ["protobuf/conv_autotuning.proto"],
|
||||
cc_api_version = 2,
|
||||
default_header = True,
|
||||
protodeps = tf_additional_all_protos(),
|
||||
provide_cc_alias = True,
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
)
|
||||
|
||||
tf_proto_library_cc(
|
||||
name = "worker_proto",
|
||||
srcs = ["protobuf/worker.proto"],
|
||||
|
@ -46,6 +46,7 @@ load(
|
||||
load("@local_config_sycl//sycl:build_defs.bzl", "if_sycl")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_tests")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
|
||||
load(
|
||||
"//tensorflow/core:platform/default/build_config.bzl",
|
||||
"tf_kernel_tests_linkstatic",
|
||||
@ -269,7 +270,7 @@ tf_kernel_library(
|
||||
deps = [
|
||||
":eigen_helpers",
|
||||
":fill_functor",
|
||||
":gpu_util_hdrs",
|
||||
":gpu_utils",
|
||||
":image_resizer_state",
|
||||
":ops_util",
|
||||
"//tensorflow/core:core_cpu",
|
||||
@ -463,11 +464,31 @@ cc_library(
|
||||
hdrs = ["conv_ops_gpu.h"],
|
||||
)
|
||||
|
||||
# We keep this target only because some contrib/ targets depend on it. The
|
||||
# reason why the contrib/ targets can't depend on gpu_utils is that, some
|
||||
# of the targets are tf_custom_op_library. tf_custom_op_library forbids the
|
||||
# dependency to tensorflow/core:lib, which gpu_utils certainly depends on.
|
||||
cc_library(
|
||||
name = "gpu_util_hdrs",
|
||||
hdrs = ["gpu_utils.h"],
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
name = "gpu_utils",
|
||||
srcs = if_cuda_is_configured(["gpu_utils.cc"]),
|
||||
hdrs = ["gpu_utils.h"],
|
||||
deps = [
|
||||
":gpu_util_hdrs",
|
||||
"//tensorflow/core:autotuning_proto_cc",
|
||||
"//tensorflow/core:conv_autotuning_proto_cc",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:logger",
|
||||
"//tensorflow/core:stream_executor",
|
||||
"//tensorflow/core/util/proto:proto_utils",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "ops_util_test",
|
||||
size = "small",
|
||||
@ -1252,7 +1273,7 @@ tf_kernel_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":bounds_check_lib",
|
||||
":gpu_util_hdrs",
|
||||
":gpu_utils",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
@ -3338,7 +3359,7 @@ tf_kernel_library(
|
||||
}),
|
||||
deps = MATH_DEPS + [
|
||||
":eigen_contraction_kernel",
|
||||
":gpu_util_hdrs",
|
||||
":gpu_utils",
|
||||
] + select({
|
||||
":xsmm": ["@libxsmm_archive//:xsmm_avx"],
|
||||
"//conditions:default": [],
|
||||
@ -3760,6 +3781,7 @@ tf_kernel_library(
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core/util/proto:proto_utils",
|
||||
] + select({
|
||||
":xsmm_convolutions": [
|
||||
"@libxsmm_archive//:xsmm_avx",
|
||||
|
@ -51,6 +51,8 @@ limitations under the License.
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/core/kernels/conv_ops_gpu.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#include "tensorflow/core/protobuf/autotuning.pb.h"
|
||||
#include "tensorflow/core/util/proto/proto_utils.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace {
|
||||
@ -841,8 +843,7 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
|
||||
CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
|
||||
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),
|
||||
&algorithms));
|
||||
ProfileResult best_result;
|
||||
ProfileResult best_result_no_scratch;
|
||||
std::vector<tensorflow::AutotuneResult> results;
|
||||
for (auto profile_algorithm : algorithms) {
|
||||
// TODO(zhengxq): profile each algorithm multiple times to better
|
||||
// accuracy.
|
||||
@ -859,28 +860,23 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
|
||||
.ok();
|
||||
if (cudnn_launch_status) {
|
||||
if (profile_result.is_valid()) {
|
||||
if (profile_result.elapsed_time_in_ms() <
|
||||
best_result.elapsed_time_in_ms()) {
|
||||
best_result = profile_result;
|
||||
}
|
||||
if (scratch_allocator.TotalByteSize() == 0 &&
|
||||
profile_result.elapsed_time_in_ms() <
|
||||
best_result_no_scratch.elapsed_time_in_ms()) {
|
||||
best_result_no_scratch = profile_result;
|
||||
}
|
||||
results.emplace_back();
|
||||
auto& result = results.back();
|
||||
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
|
||||
result.mutable_conv()->set_tensor_ops_enabled(
|
||||
profile_algorithm.tensor_ops_enabled());
|
||||
result.mutable_success()->set_scratch_bytes(
|
||||
scratch_allocator.TotalByteSize());
|
||||
*result.mutable_success()->mutable_run_time() =
|
||||
proto_utils::ToDurationProto(
|
||||
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
|
||||
}
|
||||
}
|
||||
}
|
||||
OP_REQUIRES(ctx,
|
||||
best_result.is_valid() || best_result_no_scratch.is_valid(),
|
||||
errors::NotFound("No algorithm worked!"));
|
||||
if (best_result.is_valid()) {
|
||||
algorithm_config.set_algorithm(best_result.algorithm());
|
||||
}
|
||||
if (best_result_no_scratch.is_valid()) {
|
||||
algorithm_config.set_algorithm_no_scratch(
|
||||
best_result_no_scratch.algorithm());
|
||||
}
|
||||
LogConvAutotuneResults(ctx->op_kernel().def(), transformed_input,
|
||||
pre_transformed_filter_backprop,
|
||||
transformed_out_backprop, stream->parent(), results);
|
||||
OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
|
||||
AutoTuneConvBwdFilter::GetInstance()->Insert(conv_parameters,
|
||||
algorithm_config);
|
||||
}
|
||||
|
@ -51,6 +51,8 @@ limitations under the License.
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/core/kernels/conv_ops_gpu.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#include "tensorflow/core/protobuf/autotuning.pb.h"
|
||||
#include "tensorflow/core/util/proto/proto_utils.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace {
|
||||
@ -953,8 +955,7 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
|
||||
CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
|
||||
conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),
|
||||
&algorithms));
|
||||
ProfileResult best_result;
|
||||
ProfileResult best_result_no_scratch;
|
||||
std::vector<tensorflow::AutotuneResult> results;
|
||||
for (auto profile_algorithm : algorithms) {
|
||||
// TODO(zhengxq): profile each algorithm multiple times to better
|
||||
// accuracy.
|
||||
@ -970,28 +971,23 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
|
||||
.ok();
|
||||
if (cudnn_launch_status) {
|
||||
if (profile_result.is_valid()) {
|
||||
if (profile_result.elapsed_time_in_ms() <
|
||||
best_result.elapsed_time_in_ms()) {
|
||||
best_result = profile_result;
|
||||
}
|
||||
if (scratch_allocator.TotalByteSize() == 0 &&
|
||||
profile_result.elapsed_time_in_ms() <
|
||||
best_result_no_scratch.elapsed_time_in_ms()) {
|
||||
best_result_no_scratch = profile_result;
|
||||
}
|
||||
results.emplace_back();
|
||||
auto& result = results.back();
|
||||
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
|
||||
result.mutable_conv()->set_tensor_ops_enabled(
|
||||
profile_algorithm.tensor_ops_enabled());
|
||||
result.mutable_success()->set_scratch_bytes(
|
||||
scratch_allocator.TotalByteSize());
|
||||
*result.mutable_success()->mutable_run_time() =
|
||||
proto_utils::ToDurationProto(
|
||||
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
|
||||
}
|
||||
}
|
||||
}
|
||||
OP_REQUIRES(ctx,
|
||||
best_result.is_valid() || best_result_no_scratch.is_valid(),
|
||||
errors::NotFound("No algorithm worked!"));
|
||||
if (best_result.is_valid()) {
|
||||
algorithm_config.set_algorithm(best_result.algorithm());
|
||||
}
|
||||
if (best_result_no_scratch.is_valid()) {
|
||||
algorithm_config.set_algorithm_no_scratch(
|
||||
best_result_no_scratch.algorithm());
|
||||
}
|
||||
LogConvAutotuneResults(ctx->op_kernel().def(), pre_transformed_in_backprop,
|
||||
transformed_filter, transformed_out_backprop,
|
||||
stream->parent(), results);
|
||||
OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
|
||||
AutoTuneConvBwdData::GetInstance()->Insert(conv_parameters,
|
||||
algorithm_config);
|
||||
}
|
||||
|
@ -55,6 +55,8 @@ limitations under the License.
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/core/kernels/conv_ops_gpu.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#include "tensorflow/core/protobuf/autotuning.pb.h"
|
||||
#include "tensorflow/core/util/proto/proto_utils.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
@ -855,8 +857,7 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
|
||||
errors::Unknown("Failed to get convolution algorithm. This is probably "
|
||||
"because cuDNN failed to initialize, so try looking to "
|
||||
"see if a warning log message was printed above."));
|
||||
ProfileResult best_result;
|
||||
ProfileResult best_result_no_scratch;
|
||||
std::vector<tensorflow::AutotuneResult> results;
|
||||
for (auto profile_algorithm : algorithms) {
|
||||
// TODO(zhengxq): profile each algorithm multiple times to better
|
||||
// accuracy.
|
||||
@ -871,30 +872,22 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
|
||||
.ok();
|
||||
if (cudnn_launch_status) {
|
||||
if (profile_result.is_valid()) {
|
||||
if (profile_result.elapsed_time_in_ms() <
|
||||
best_result.elapsed_time_in_ms()) {
|
||||
best_result = profile_result;
|
||||
}
|
||||
if (scratch_allocator.TotalByteSize() == 0 &&
|
||||
profile_result.elapsed_time_in_ms() <
|
||||
best_result_no_scratch.elapsed_time_in_ms()) {
|
||||
best_result_no_scratch = profile_result;
|
||||
}
|
||||
results.emplace_back();
|
||||
auto& result = results.back();
|
||||
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
|
||||
result.mutable_conv()->set_tensor_ops_enabled(
|
||||
profile_algorithm.tensor_ops_enabled());
|
||||
result.mutable_success()->set_scratch_bytes(
|
||||
scratch_allocator.TotalByteSize());
|
||||
*result.mutable_success()->mutable_run_time() =
|
||||
proto_utils::ToDurationProto(
|
||||
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
|
||||
}
|
||||
}
|
||||
}
|
||||
// TODO(yangzihao): refactor the profile result checking code into a common
|
||||
// utility function.
|
||||
OP_REQUIRES(ctx,
|
||||
best_result.is_valid() || best_result_no_scratch.is_valid(),
|
||||
errors::NotFound("No algorithm worked!"));
|
||||
if (best_result.is_valid()) {
|
||||
algorithm_config.set_algorithm(best_result.algorithm());
|
||||
}
|
||||
if (best_result_no_scratch.is_valid()) {
|
||||
algorithm_config.set_algorithm_no_scratch(
|
||||
best_result_no_scratch.algorithm());
|
||||
}
|
||||
LogConvAutotuneResults(ctx->op_kernel().def(), input, transformed_filter,
|
||||
transformed_output, stream->parent(), results);
|
||||
OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
|
||||
AutoTuneConv::GetInstance()->Insert(conv_parameters, algorithm_config);
|
||||
}
|
||||
|
||||
|
@ -34,6 +34,8 @@ limitations under the License.
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#include "tensorflow/core/protobuf/autotuning.pb.h"
|
||||
#include "tensorflow/core/util/proto/proto_utils.h"
|
||||
using stream_executor::dnn::DimIndex;
|
||||
#endif
|
||||
|
||||
@ -445,8 +447,7 @@ struct LaunchConvOp<GPUDevice, T> {
|
||||
"because cuDNN failed to initialize, so try looking to "
|
||||
"see if a warning log message was printed above."));
|
||||
|
||||
ProfileResult best_result;
|
||||
ProfileResult best_result_no_scratch;
|
||||
std::vector<tensorflow::AutotuneResult> results;
|
||||
for (auto profile_algorithm : algorithms) {
|
||||
// TODO(zhengxq): profile each algorithm multiple times to better
|
||||
// accuracy.
|
||||
@ -461,28 +462,22 @@ struct LaunchConvOp<GPUDevice, T> {
|
||||
.ok();
|
||||
if (cudnn_launch_status) {
|
||||
if (profile_result.is_valid()) {
|
||||
if (profile_result.elapsed_time_in_ms() <
|
||||
best_result.elapsed_time_in_ms()) {
|
||||
best_result = profile_result;
|
||||
}
|
||||
if (scratch_allocator.TotalByteSize() == 0 &&
|
||||
profile_result.elapsed_time_in_ms() <
|
||||
best_result_no_scratch.elapsed_time_in_ms()) {
|
||||
best_result_no_scratch = profile_result;
|
||||
}
|
||||
results.emplace_back();
|
||||
auto& result = results.back();
|
||||
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
|
||||
result.mutable_conv()->set_tensor_ops_enabled(
|
||||
profile_algorithm.tensor_ops_enabled());
|
||||
result.mutable_success()->set_scratch_bytes(
|
||||
scratch_allocator.TotalByteSize());
|
||||
*result.mutable_success()->mutable_run_time() =
|
||||
proto_utils::ToDurationProto(
|
||||
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
|
||||
}
|
||||
}
|
||||
}
|
||||
OP_REQUIRES(ctx,
|
||||
best_result.is_valid() || best_result_no_scratch.is_valid(),
|
||||
errors::NotFound("No algorithm worked!"));
|
||||
if (best_result.is_valid()) {
|
||||
algorithm_config.set_algorithm(best_result.algorithm());
|
||||
}
|
||||
if (best_result_no_scratch.is_valid()) {
|
||||
algorithm_config.set_algorithm_no_scratch(
|
||||
best_result_no_scratch.algorithm());
|
||||
}
|
||||
LogConvAutotuneResults(ctx->op_kernel().def(), input, filter, *output,
|
||||
stream->parent(), results);
|
||||
OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
|
||||
AutoTuneConv3d::GetInstance()->Insert(conv_parameters, algorithm_config);
|
||||
}
|
||||
|
||||
|
@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This include can't be in the conv_ops_fused_impl.h headers. See b/62899350.
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/core/protobuf/autotuning.pb.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
#include "tensorflow/core/kernels/conv_ops_fused_impl.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This include can't be in the conv_ops_fused_impl.h headers. See b/62899350.
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/core/protobuf/autotuning.pb.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
#include "tensorflow/core/kernels/conv_ops_fused_impl.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// This include can't be in the conv_ops_fused_impl.h headers. See b/62899350.
|
||||
#if GOOGLE_CUDA
|
||||
#include "tensorflow/core/protobuf/autotuning.pb.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
#include "tensorflow/core/kernels/conv_ops_fused_impl.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
@ -59,10 +59,13 @@ limitations under the License.
|
||||
#include "cuda/include/cudnn.h"
|
||||
#include "tensorflow/core/kernels/conv_ops_gpu.h"
|
||||
#include "tensorflow/core/platform/stream_executor.h"
|
||||
#include "tensorflow/core/util/proto/proto_utils.h"
|
||||
#endif // GOOGLE_CUDA
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class AutotuneResult;
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
typedef Eigen::GpuDevice GPUDevice;
|
||||
|
||||
@ -497,10 +500,11 @@ inline int64 ConvolveScratchSize() {
|
||||
// algorithms and measuring execution time.
|
||||
// TODO(ezhulenev): Move it to conv_ops_gpu.h and share with conv_ops.cc.
|
||||
template <typename T, typename ConvLaunch>
|
||||
Status FindBestConvolveAlgorithm(const FusedConvParameters& params,
|
||||
const ConvLaunch launch,
|
||||
OpKernelContext* context, se::Stream* stream,
|
||||
se::dnn::AlgorithmConfig* algorithm_config) {
|
||||
Status FindBestConvolveAlgorithm(
|
||||
const FusedConvParameters& params, const ConvLaunch launch,
|
||||
OpKernelContext* context, se::Stream* stream,
|
||||
se::dnn::AlgorithmConfig* algorithm_config,
|
||||
std::vector<tensorflow::AutotuneResult>* results) {
|
||||
// Check if we already have an algorithm selected for the given parameters.
|
||||
if (AutoTuneFusedConv::GetInstance()->Find(params, algorithm_config)) {
|
||||
return Status::OK();
|
||||
@ -517,9 +521,6 @@ Status FindBestConvolveAlgorithm(const FusedConvParameters& params,
|
||||
"see if a warning log message was printed above.");
|
||||
}
|
||||
|
||||
se::dnn::ProfileResult best_result;
|
||||
se::dnn::ProfileResult best_result_no_scratch;
|
||||
|
||||
for (auto profile_algorithm : algorithms) {
|
||||
DnnScratchAllocator scratch_allocator(ConvolveScratchSize(), context);
|
||||
se::dnn::ProfileResult profile_result;
|
||||
@ -529,29 +530,19 @@ Status FindBestConvolveAlgorithm(const FusedConvParameters& params,
|
||||
&profile_result);
|
||||
|
||||
if (cudnn_launch_status && profile_result.is_valid()) {
|
||||
if (profile_result.elapsed_time_in_ms() <
|
||||
best_result.elapsed_time_in_ms()) {
|
||||
best_result = profile_result;
|
||||
}
|
||||
if (scratch_allocator.TotalByteSize() == 0 &&
|
||||
profile_result.elapsed_time_in_ms() <
|
||||
best_result_no_scratch.elapsed_time_in_ms()) {
|
||||
best_result_no_scratch = profile_result;
|
||||
}
|
||||
results->emplace_back();
|
||||
auto& result = results->back();
|
||||
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
|
||||
result.mutable_conv()->set_tensor_ops_enabled(
|
||||
profile_algorithm.tensor_ops_enabled());
|
||||
result.mutable_success()->set_scratch_bytes(
|
||||
scratch_allocator.TotalByteSize());
|
||||
*result.mutable_success()->mutable_run_time() =
|
||||
proto_utils::ToDurationProto(
|
||||
absl::Milliseconds(profile_result.elapsed_time_in_ms()));
|
||||
}
|
||||
}
|
||||
|
||||
if (!best_result.is_valid() && !best_result_no_scratch.is_valid()) {
|
||||
return errors::NotFound("No algorithm worked!");
|
||||
}
|
||||
if (best_result.is_valid()) {
|
||||
algorithm_config->set_algorithm(best_result.algorithm());
|
||||
}
|
||||
if (best_result_no_scratch.is_valid()) {
|
||||
algorithm_config->set_algorithm_no_scratch(
|
||||
best_result_no_scratch.algorithm());
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(BestCudnnConvAlgorithm(*results, algorithm_config));
|
||||
AutoTuneFusedConv::GetInstance()->Insert(params, *algorithm_config);
|
||||
return Status::OK();
|
||||
}
|
||||
@ -798,9 +789,14 @@ struct LaunchFusedConv2DOp<GPUDevice, T> {
|
||||
|
||||
se::dnn::AlgorithmConfig algorithm_config;
|
||||
if (cudnn_use_autotune) {
|
||||
OP_REQUIRES_OK(context, FindBestConvolveAlgorithm<T>(
|
||||
conv_parameters, launch, context, stream,
|
||||
&algorithm_config));
|
||||
std::vector<tensorflow::AutotuneResult> results;
|
||||
auto status =
|
||||
FindBestConvolveAlgorithm<T>(conv_parameters, launch, context, stream,
|
||||
&algorithm_config, &results);
|
||||
LogFusedConvAutotuneResults(context->op_kernel().def(), input,
|
||||
transformed_filter, transformed_output, bias,
|
||||
nullptr, stream->parent(), results);
|
||||
OP_REQUIRES_OK(context, status);
|
||||
}
|
||||
|
||||
DnnScratchAllocator scratch_allocator(ConvolveScratchSize(), context);
|
||||
|
153
tensorflow/core/kernels/gpu_utils.cc
Normal file
153
tensorflow/core/kernels/gpu_utils.cc
Normal file
@ -0,0 +1,153 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/kernels/gpu_utils.h"
|
||||
|
||||
#if GOOGLE_CUDA
|
||||
|
||||
#include "google/protobuf/any.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/platform/logger.h"
|
||||
#include "tensorflow/core/protobuf/autotuning.pb.h"
|
||||
#include "tensorflow/core/protobuf/conv_autotuning.pb.h"
|
||||
#include "tensorflow/core/util/proto/proto_utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
|
||||
tensorflow::CudnnVersion cudnn_version;
|
||||
if (auto* dnn = stream_executor->AsDnn()) {
|
||||
se::port::StatusOr<se::dnn::VersionInfo> version_or = dnn->GetVersion();
|
||||
if (version_or.ok()) {
|
||||
const auto& version = version_or.ValueOrDie();
|
||||
cudnn_version.set_major(version.major_version());
|
||||
cudnn_version.set_minor(version.minor_version());
|
||||
cudnn_version.set_patch(version.patch());
|
||||
}
|
||||
}
|
||||
return cudnn_version;
|
||||
}
|
||||
|
||||
tensorflow::ComputeCapability GetComputeCapability(
|
||||
se::StreamExecutor* stream_executor) {
|
||||
tensorflow::ComputeCapability cc;
|
||||
int cc_major, cc_minor;
|
||||
stream_executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
|
||||
&cc_minor);
|
||||
cc.set_major(cc_major);
|
||||
cc.set_minor(cc_minor);
|
||||
return cc;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
void LogConvAutotuneResults(const NodeDef& node, const Tensor& input,
|
||||
const Tensor& filter, const Tensor& output,
|
||||
se::StreamExecutor* stream_exec,
|
||||
absl::Span<const AutotuneResult> results) {
|
||||
AutotuningLog log;
|
||||
ConvNodeDef instr;
|
||||
*instr.mutable_conv() = node;
|
||||
input.shape().AsProto(instr.mutable_input()->mutable_tensor_shape());
|
||||
instr.mutable_input()->set_dtype(input.dtype());
|
||||
filter.shape().AsProto(instr.mutable_filter()->mutable_tensor_shape());
|
||||
instr.mutable_filter()->set_dtype(filter.dtype());
|
||||
output.shape().AsProto(instr.mutable_output()->mutable_tensor_shape());
|
||||
instr.mutable_output()->set_dtype(output.dtype());
|
||||
log.mutable_instr()->PackFrom(std::move(instr));
|
||||
*log.mutable_cudnn_version() = GetCudnnVersion(stream_exec);
|
||||
*log.mutable_compute_capability() = GetComputeCapability(stream_exec);
|
||||
for (const auto& result : results) {
|
||||
*log.add_results() = result;
|
||||
}
|
||||
Logger::Singleton()->LogProto(log);
|
||||
}
|
||||
|
||||
void LogFusedConvAutotuneResults(const NodeDef& node, const Tensor& input,
|
||||
const Tensor& filter, const Tensor& output,
|
||||
const Tensor& bias, const Tensor* side_input,
|
||||
se::StreamExecutor* stream_exec,
|
||||
absl::Span<const AutotuneResult> results) {
|
||||
AutotuningLog log;
|
||||
ConvNodeDef instr;
|
||||
*instr.mutable_conv() = node;
|
||||
input.shape().AsProto(instr.mutable_input()->mutable_tensor_shape());
|
||||
instr.mutable_input()->set_dtype(input.dtype());
|
||||
filter.shape().AsProto(instr.mutable_filter()->mutable_tensor_shape());
|
||||
instr.mutable_filter()->set_dtype(filter.dtype());
|
||||
output.shape().AsProto(instr.mutable_output()->mutable_tensor_shape());
|
||||
instr.mutable_output()->set_dtype(output.dtype());
|
||||
bias.shape().AsProto(instr.mutable_bias()->mutable_tensor_shape());
|
||||
instr.mutable_bias()->set_dtype(bias.dtype());
|
||||
if (side_input) {
|
||||
side_input->shape().AsProto(
|
||||
instr.mutable_side_input()->mutable_tensor_shape());
|
||||
instr.mutable_side_input()->set_dtype(side_input->dtype());
|
||||
}
|
||||
log.mutable_instr()->PackFrom(std::move(instr));
|
||||
*log.mutable_cudnn_version() = GetCudnnVersion(stream_exec);
|
||||
*log.mutable_compute_capability() = GetComputeCapability(stream_exec);
|
||||
for (const auto& result : results) {
|
||||
*log.add_results() = result;
|
||||
}
|
||||
Logger::Singleton()->LogProto(log);
|
||||
}
|
||||
|
||||
Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
|
||||
se::dnn::AlgorithmConfig* algo) {
|
||||
// For the "!xhs.has_success()" below, this is because we want successful ones
|
||||
// to order first, therefore they need a smaller key per "min_element".
|
||||
const AutotuneResult* best_result = std::min_element(
|
||||
results.begin(), results.end(),
|
||||
[](const AutotuneResult& lhs, const AutotuneResult& rhs) {
|
||||
return std::make_tuple(
|
||||
!lhs.has_success(),
|
||||
proto_utils::FromDurationProto(lhs.success().run_time())) <
|
||||
std::make_tuple(
|
||||
!rhs.has_success(),
|
||||
proto_utils::FromDurationProto(rhs.success().run_time()));
|
||||
});
|
||||
|
||||
const AutotuneResult* best_result_no_scratch = std::min_element(
|
||||
results.begin(), results.end(),
|
||||
[](const AutotuneResult& lhs, const AutotuneResult& rhs) {
|
||||
return std::make_tuple(
|
||||
!lhs.has_success(), lhs.success().scratch_bytes(),
|
||||
proto_utils::FromDurationProto(lhs.success().run_time())) <
|
||||
std::make_tuple(
|
||||
!rhs.has_success(), rhs.success().scratch_bytes(),
|
||||
proto_utils::FromDurationProto(rhs.success().run_time()));
|
||||
});
|
||||
|
||||
if (best_result == results.end() || !best_result->has_success()) {
|
||||
return errors::NotFound("No algorithm worked!");
|
||||
}
|
||||
algo->set_algorithm({best_result->conv().algorithm(),
|
||||
best_result->conv().tensor_ops_enabled()});
|
||||
if (best_result_no_scratch != results.end() &&
|
||||
best_result_no_scratch->has_success() &&
|
||||
best_result_no_scratch->success().scratch_bytes() == 0) {
|
||||
algo->set_algorithm_no_scratch(
|
||||
{best_result_no_scratch->conv().algorithm(),
|
||||
best_result_no_scratch->conv().tensor_ops_enabled()});
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
@ -20,6 +20,9 @@ limitations under the License.
|
||||
|
||||
#include <unordered_map>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/lib/strings/str_util.h"
|
||||
#include "tensorflow/core/lib/strings/strcat.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
@ -28,6 +31,9 @@ limitations under the License.
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
class NodeDef;
|
||||
class AutotuneResult;
|
||||
|
||||
template <typename T>
|
||||
inline se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory, uint64 size) {
|
||||
se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory), size * sizeof(T));
|
||||
@ -156,6 +162,25 @@ class AutoTuneSingleton {
|
||||
}
|
||||
};
|
||||
|
||||
// Logs convolution results to customized back-storage.
|
||||
void LogConvAutotuneResults(const NodeDef& node, const Tensor& input,
|
||||
const Tensor& filter, const Tensor& output,
|
||||
se::StreamExecutor* stream_exec,
|
||||
absl::Span<const AutotuneResult> results);
|
||||
|
||||
// Logs fused convolution results to customized back-storage.
|
||||
void LogFusedConvAutotuneResults(const NodeDef& node, const Tensor& input,
|
||||
const Tensor& filter, const Tensor& output,
|
||||
const Tensor& bias, const Tensor* side_input,
|
||||
se::StreamExecutor* stream_exec,
|
||||
absl::Span<const AutotuneResult> results);
|
||||
|
||||
// Returns the best algorithms for the config, one is the fastest, the other is
|
||||
// other is fastest with 0 scracth space. Unsuccessful autotuning results are
|
||||
// allowed and ignored.
|
||||
Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
|
||||
se::dnn::AlgorithmConfig* algo);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // GOOGLE_CUDA
|
||||
|
@ -548,6 +548,7 @@ def tf_additional_all_protos():
|
||||
def tf_protos_all_impl():
|
||||
return [
|
||||
"//tensorflow/core:autotuning_proto_cc_impl",
|
||||
"//tensorflow/core:conv_autotuning_proto_cc_impl",
|
||||
"//tensorflow/core:protos_all_cc_impl",
|
||||
]
|
||||
|
||||
|
19
tensorflow/core/protobuf/conv_autotuning.proto
Normal file
19
tensorflow/core/protobuf/conv_autotuning.proto
Normal file
@ -0,0 +1,19 @@
|
||||
// This is used for convolution logging. Also see
|
||||
// tensorflow/core/protobuf/autotuing.h
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
|
||||
import "tensorflow/core/framework/node_def.proto";
|
||||
import "tensorflow/core/framework/tensor.proto";
|
||||
|
||||
message ConvNodeDef {
|
||||
NodeDef conv = 1;
|
||||
TensorProto input = 2;
|
||||
TensorProto filter = 3;
|
||||
TensorProto output = 4;
|
||||
TensorProto bias = 5;
|
||||
oneof side_input_oneof {
|
||||
TensorProto side_input = 6;
|
||||
}
|
||||
}
|
@ -464,15 +464,6 @@ tf_proto_library(
|
||||
provide_cc_alias = True,
|
||||
)
|
||||
|
||||
tf_proto_library(
|
||||
name = "logging_proto",
|
||||
srcs = ["logging.proto"],
|
||||
cc_api_version = 2,
|
||||
protodeps = [":dnn_proto"],
|
||||
provide_cc_alias = True,
|
||||
visibility = [":friends"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "dnn",
|
||||
srcs = ["dnn.cc"],
|
||||
|
@ -285,10 +285,8 @@ cc_library(
|
||||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:logger",
|
||||
"//tensorflow/stream_executor:dnn",
|
||||
"//tensorflow/stream_executor:event",
|
||||
"//tensorflow/stream_executor:logging_proto_cc",
|
||||
"//tensorflow/stream_executor:plugin_registry",
|
||||
"//tensorflow/stream_executor:scratch_allocator",
|
||||
"//tensorflow/stream_executor:stream_executor_pimpl_header",
|
||||
|
@ -23,7 +23,6 @@ limitations under the License.
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/strings/stringprintf.h"
|
||||
#include "tensorflow/core/platform/logger.h"
|
||||
#include "tensorflow/core/util/env_var.h"
|
||||
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
|
||||
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
|
||||
@ -39,7 +38,6 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/lib/initialize.h"
|
||||
#include "tensorflow/stream_executor/lib/mathutil.h"
|
||||
#include "tensorflow/stream_executor/lib/threadpool.h"
|
||||
#include "tensorflow/stream_executor/logging.pb.h"
|
||||
#include "tensorflow/stream_executor/platform/logging.h"
|
||||
#include "tensorflow/stream_executor/plugin_registry.h"
|
||||
#include "tensorflow/stream_executor/scratch_allocator.h"
|
||||
@ -2621,63 +2619,6 @@ bool ShouldIncludeWinogradNonfusedAlgo(
|
||||
}
|
||||
#endif
|
||||
|
||||
dnn::ConvolutionProto GenerateConvProto(
|
||||
dnn::ConvolutionKind kind, dnn::DataType element_type,
|
||||
const dnn::BatchDescriptor& input_descriptor,
|
||||
const dnn::FilterDescriptor& filter_descriptor,
|
||||
const dnn::BatchDescriptor& output_descriptor, dnn::AlgorithmDesc algorithm,
|
||||
const dnn::ConvolutionDescriptor& convolution_descriptor, double conv_scale,
|
||||
double side_value_scale, dnn::DataType acc_type,
|
||||
dnn::ActivationMode activation) {
|
||||
dnn::ConvolutionProto conv_config;
|
||||
conv_config.set_kind(kind);
|
||||
*conv_config.mutable_input() = input_descriptor.ToProto(element_type);
|
||||
*conv_config.mutable_filter() = filter_descriptor.ToProto(element_type);
|
||||
*conv_config.mutable_output() = output_descriptor.ToProto(element_type);
|
||||
*conv_config.mutable_algorithm() = algorithm.ToProto();
|
||||
*conv_config.mutable_conv_desc() = convolution_descriptor.ToProto();
|
||||
conv_config.mutable_conv_desc()->set_compute_mode(acc_type);
|
||||
conv_config.set_conv_scale(conv_scale);
|
||||
conv_config.set_side_value_scale(side_value_scale);
|
||||
conv_config.set_activation(activation);
|
||||
return conv_config;
|
||||
}
|
||||
|
||||
void LogCudaProto(const dnn::ConvolutionProto& conv, float profile_time_ms,
|
||||
StreamExecutor* stream_executor) {
|
||||
{
|
||||
// For rolling-out, temporarily cap the number of logs per process.
|
||||
// TODO(timshen): remove it.
|
||||
static int count_down = 200;
|
||||
if (count_down == 0) {
|
||||
return;
|
||||
}
|
||||
count_down--;
|
||||
}
|
||||
|
||||
ConvLogEntry conv_log;
|
||||
*conv_log.mutable_convolution() = conv;
|
||||
conv_log.set_profile_time_ms(profile_time_ms);
|
||||
|
||||
auto info = conv_log.mutable_cuda_info();
|
||||
int cc_major, cc_minor;
|
||||
stream_executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
|
||||
&cc_minor);
|
||||
info->mutable_compute_capability()->set_major(cc_major);
|
||||
info->mutable_compute_capability()->set_minor(cc_minor);
|
||||
|
||||
if (auto* dnn = stream_executor->AsDnn()) {
|
||||
port::StatusOr<dnn::VersionInfo> version_or = dnn->GetVersion();
|
||||
if (version_or.ok()) {
|
||||
const auto& version = version_or.ValueOrDie();
|
||||
info->mutable_cudnn_version()->set_major(version.major_version());
|
||||
info->mutable_cudnn_version()->set_minor(version.minor_version());
|
||||
info->mutable_cudnn_version()->set_patch(version.patch());
|
||||
}
|
||||
}
|
||||
tensorflow::Logger::Singleton()->LogProto(conv_log);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
port::Status CudnnSupport::DoPrepareForConvolution(
|
||||
@ -2971,13 +2912,6 @@ port::Status CudnnSupport::DoConvolve(
|
||||
output_profile_result->set_elapsed_time_in_ms(
|
||||
timer->GetElapsedMilliseconds());
|
||||
output_profile_result->set_scratch_size(scratch_memory.size());
|
||||
|
||||
LogCudaProto(
|
||||
GenerateConvProto(kind, element_type, input_descriptor,
|
||||
filter_descriptor, output_descriptor, algorithm_desc,
|
||||
convolution_descriptor, dalpha, dbeta,
|
||||
accumulator_type, dnn::ActivationMode::kNone),
|
||||
output_profile_result->elapsed_time_in_ms(), stream->parent());
|
||||
}
|
||||
|
||||
return port::Status::OK();
|
||||
@ -3095,14 +3029,6 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
|
||||
output_profile_result->set_elapsed_time_in_ms(
|
||||
timer->GetElapsedMilliseconds());
|
||||
output_profile_result->set_scratch_size(scratch.size());
|
||||
|
||||
LogCudaProto(
|
||||
GenerateConvProto(
|
||||
dnn::ConvolutionKind::FORWARD, dnn::ToDataType<ElementType>::value,
|
||||
conv_input_descriptor, filter_descriptor, output_descriptor,
|
||||
algo_desc, convolution_descriptor, conv_input_scale,
|
||||
side_input_scale, accumulator_type, activation_mode),
|
||||
output_profile_result->elapsed_time_in_ms(), stream->parent());
|
||||
}
|
||||
|
||||
return port::Status::OK();
|
||||
|
@ -108,22 +108,3 @@ message ConvolutionDescriptorProto {
|
||||
int32 group_count = 5;
|
||||
ConvolutionMode convolution_mode = 6;
|
||||
}
|
||||
|
||||
// A convolution. Currently it's only used for logging. In the future, we may
|
||||
// want to use it in the API as well.
|
||||
message ConvolutionProto {
|
||||
ConvolutionKind kind = 1;
|
||||
TensorDescriptorProto input = 2;
|
||||
TensorDescriptorProto filter = 3;
|
||||
TensorDescriptorProto output = 4;
|
||||
AlgorithmProto algorithm = 5;
|
||||
ConvolutionDescriptorProto conv_desc = 6;
|
||||
|
||||
// result = conv_scale * conv(...) + side_value_scale * side_value.
|
||||
// side_value is an arbitrary buffer if activation is not none. Otherwise, it
|
||||
// has to be the result buffer (using its old values).
|
||||
double conv_scale = 7;
|
||||
double side_value_scale = 8;
|
||||
|
||||
ActivationMode activation = 9;
|
||||
}
|
||||
|
@ -1,41 +0,0 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package stream_executor;
|
||||
|
||||
import "tensorflow/stream_executor/dnn.proto";
|
||||
|
||||
message CudnnVersion {
|
||||
int32 major = 1;
|
||||
int32 minor = 2;
|
||||
int32 patch = 3;
|
||||
};
|
||||
|
||||
message ComputeCapability {
|
||||
int32 major = 1;
|
||||
int32 minor = 2;
|
||||
}
|
||||
|
||||
// NOTE: this proto is temporarily duplicated in other places, outside of
|
||||
// stream_executor. The plan is to move all custom logging (tensorflow::Logger
|
||||
// related) behavior out of StreamExecutor. There are two reasons:
|
||||
// * Technical: stream_executor is part of libtensorflow_framework.so. It's
|
||||
// extremely hard to have a single definition of the protos in the .so, and
|
||||
// let the callers call into those definitions. The complication lives in
|
||||
// cc_proto_library where we have a header-only version and impl version.
|
||||
// * Functional: we want to log autotuning stats from the callers. The
|
||||
// autotuning stats are not available in SE.
|
||||
//
|
||||
// TODO(timshen): remove this proto once both XLA and TF log autotuning
|
||||
// results.
|
||||
message CudaInfo {
|
||||
CudnnVersion cudnn_version = 1;
|
||||
ComputeCapability compute_capability = 2;
|
||||
}
|
||||
|
||||
message ConvLogEntry {
|
||||
CudaInfo cuda_info = 1;
|
||||
dnn.ConvolutionProto convolution = 2;
|
||||
|
||||
// Profiled time in ms. 0.0 if the convolution is not profiled.
|
||||
float profile_time_ms = 3;
|
||||
}
|
Loading…
Reference in New Issue
Block a user