Log convolutions during TensorFlow GPU conv autotuning. Also remove the same functionality from StreamExecutor.
We decided to move the logging from SE to TF and XLA for several reasons:
* Proto formats suitable for logging already exist in TF and XLA; there is no need to create a third one.
* In the TF and XLA autotuning stages we also do (or plan to do) correctness checking, and we want to log the checking results.
* We are considering simplifying SE, so we prefer to keep SE simple for now.

PiperOrigin-RevId: 236889526
This commit is contained in:
parent
69ba9fbc3f
commit
5aefc4e922
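The new flow, end to end: each GPU conv kernel profiles its candidate algorithms, records one AutotuneResult per attempt, logs the whole batch through tensorflow::Logger, and only then picks the best algorithm. A minimal sketch of the recording step, using only types and calls introduced in this change (AutotuneResult, proto_utils::ToDurationProto, LogConvAutotuneResults); the elapsed_ms and scratch_bytes values are placeholders standing in for real profiler output:

// Sketch only: assumes the TF-internal headers added in this change;
// elapsed_ms and scratch_bytes are hypothetical profiler outputs.
std::vector<tensorflow::AutotuneResult> results;
results.emplace_back();
auto& result = results.back();
result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
result.mutable_conv()->set_tensor_ops_enabled(
    profile_algorithm.tensor_ops_enabled());
result.mutable_success()->set_scratch_bytes(scratch_bytes);
*result.mutable_success()->mutable_run_time() =
    proto_utils::ToDurationProto(absl::Milliseconds(elapsed_ms));
// LogConvAutotuneResults (new in gpu_utils.{h,cc}) wraps the results in an
// AutotuningLog, attaches cuDNN version and compute capability, and emits it
// via Logger::Singleton()->LogProto(log).
LogConvAutotuneResults(ctx->op_kernel().def(), input, filter, output,
                       stream->parent(), results);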
@ -68,6 +68,8 @@ tf_kernel_library(
    prefix = "fused_conv2d_bias_activation_op",
    visibility = ["//visibility:public"],
    deps = [
        "//tensorflow/core:autotuning_proto_cc",
        "//tensorflow/core:conv_autotuning_proto_cc",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_proto_parsing",
@ -92,6 +94,8 @@ tf_custom_op_library(
        "ops/fused_conv2d_bias_activation_op.cc",
    ],
    deps = [
        "//tensorflow/core:autotuning_proto_cc",
        "//tensorflow/core:conv_autotuning_proto_cc",
        "//tensorflow/core:lib_proto_parsing",
        "//tensorflow/core/kernels:bounds_check_lib",
        "//tensorflow/core/kernels:conv_2d_hdrs",
@ -34,9 +34,16 @@ limitations under the License.
#include "tensorflow/core/util/use_cudnn.h"

#if GOOGLE_CUDA
#include "google/protobuf/duration.pb.h"
#include "absl/time/time.h"
#include "cuda/include/cudnn.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/logger.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/protobuf/conv_autotuning.pb.h"
#include "tensorflow/core/util/activation_mode.h"
#endif  // GOOGLE_CUDA

@ -252,6 +259,131 @@ class FusedConv2DBiasActivationOp : public OpKernel {
#if GOOGLE_CUDA
namespace dnn = se::dnn;

// Several functions are copied over from tensorflow/core/kernels/gpu_utils,
// since this file may be compiled down to a tf_custom_op_library .so file,
// which can't depend on basic dependencies like tensorflow/core:lib. Instead,
// the code has to depend on whatever is the same in libtensorflow_framework.so.
//
// In theory, we could lift the dependencies of gpu_utils by turning it into a
// template library that provides duck typing, but duplication is the lesser
// of two evils.
namespace internal {
namespace {

tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
  tensorflow::CudnnVersion cudnn_version;
  if (auto* dnn = stream_executor->AsDnn()) {
    se::port::StatusOr<se::dnn::VersionInfo> version_or = dnn->GetVersion();
    if (version_or.ok()) {
      const auto& version = version_or.ValueOrDie();
      cudnn_version.set_major(version.major_version());
      cudnn_version.set_minor(version.minor_version());
      cudnn_version.set_patch(version.patch());
    }
  }
  return cudnn_version;
}

// Converts an absl::Duration to a google::protobuf::Duration.
inline google::protobuf::Duration ToDurationProto(absl::Duration duration) {
  google::protobuf::Duration proto;
  proto.set_seconds(absl::IDivDuration(duration, absl::Seconds(1), &duration));
  proto.set_nanos(
      absl::IDivDuration(duration, absl::Nanoseconds(1), &duration));
  return proto;
}

// Converts a google::protobuf::Duration to an absl::Duration.
inline absl::Duration FromDurationProto(google::protobuf::Duration proto) {
  return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos());
}

tensorflow::ComputeCapability GetComputeCapability(
    se::StreamExecutor* stream_executor) {
  tensorflow::ComputeCapability cc;
  int cc_major, cc_minor;
  stream_executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
                                                                  &cc_minor);
  cc.set_major(cc_major);
  cc.set_minor(cc_minor);
  return cc;
}

void LogFusedConvAutotuneResults(const NodeDef& node, const Tensor& input,
                                 const Tensor& filter, const Tensor& output,
                                 const Tensor& bias, const Tensor* side_input,
                                 se::StreamExecutor* stream_exec,
                                 absl::Span<const AutotuneResult> results) {
  AutotuningLog log;
  ConvNodeDef instr;
  *instr.mutable_conv() = node;
  input.shape().AsProto(instr.mutable_input()->mutable_tensor_shape());
  instr.mutable_input()->set_dtype(input.dtype());
  filter.shape().AsProto(instr.mutable_filter()->mutable_tensor_shape());
  instr.mutable_filter()->set_dtype(filter.dtype());
  output.shape().AsProto(instr.mutable_output()->mutable_tensor_shape());
  instr.mutable_output()->set_dtype(output.dtype());
  bias.shape().AsProto(instr.mutable_bias()->mutable_tensor_shape());
  instr.mutable_bias()->set_dtype(bias.dtype());
  if (side_input) {
    side_input->shape().AsProto(
        instr.mutable_side_input()->mutable_tensor_shape());
    instr.mutable_side_input()->set_dtype(side_input->dtype());
  }
  log.mutable_instr()->PackFrom(std::move(instr));
  *log.mutable_cudnn_version() = internal::GetCudnnVersion(stream_exec);
  *log.mutable_compute_capability() =
      internal::GetComputeCapability(stream_exec);
  for (const auto& result : results) {
    *log.add_results() = result;
  }
  Logger::Singleton()->LogProto(log);
}

Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
                              se::dnn::AlgorithmConfig* algo) {
  // The "!lhs.has_success()" / "!rhs.has_success()" terms below make
  // successful results order first: they need a smaller key per "min_element".
  const AutotuneResult* best_result = std::min_element(
      results.begin(), results.end(),
      [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
        return std::make_tuple(
                   !lhs.has_success(),
                   internal::FromDurationProto(lhs.success().run_time())) <
               std::make_tuple(
                   !rhs.has_success(),
                   internal::FromDurationProto(rhs.success().run_time()));
      });

  const AutotuneResult* best_result_no_scratch = std::min_element(
      results.begin(), results.end(),
      [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
        return std::make_tuple(
                   !lhs.has_success(), lhs.success().scratch_bytes(),
                   internal::FromDurationProto(lhs.success().run_time())) <
               std::make_tuple(
                   !rhs.has_success(), rhs.success().scratch_bytes(),
                   internal::FromDurationProto(rhs.success().run_time()));
      });

  if (best_result == results.end() || !best_result->has_success()) {
    return errors::NotFound("No algorithm worked!");
  }
  algo->set_algorithm({best_result->conv().algorithm(),
                       best_result->conv().tensor_ops_enabled()});
  if (best_result_no_scratch != results.end() &&
      best_result_no_scratch->has_success() &&
      best_result_no_scratch->success().scratch_bytes() == 0) {
    algo->set_algorithm_no_scratch(
        {best_result_no_scratch->conv().algorithm(),
         best_result_no_scratch->conv().tensor_ops_enabled()});
  }
  return Status::OK();
}

}  // namespace
}  // namespace internal

// A dummy type to group forward convolution autotune results together.
struct ConvBiasActivationAutoTuneGroup {
  static string name() { return "ConvBiasActivation"; }
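As a quick check of the ToDurationProto/FromDurationProto pair above: the conversion peels off whole seconds with absl::IDivDuration, then stores the remainder as nanoseconds. A standalone sketch of the same arithmetic, using a plain struct in place of the generated protobuf Duration so it runs with only absl:

#include <cassert>
#include <cstdint>

#include "absl/time/time.h"

// Stand-in for google::protobuf::Duration so the round trip is checkable
// without generated protobuf code.
struct DurationParts {
  int64_t seconds;
  int64_t nanos;
};

DurationParts ToParts(absl::Duration d) {
  DurationParts p;
  // IDivDuration returns the whole quotient and leaves the remainder in d.
  p.seconds = absl::IDivDuration(d, absl::Seconds(1), &d);
  p.nanos = absl::IDivDuration(d, absl::Nanoseconds(1), &d);
  return p;
}

int main() {
  // 1500 ms splits into 1 s + 500,000,000 ns and converts back exactly.
  DurationParts p = ToParts(absl::Milliseconds(1500));
  assert(p.seconds == 1 && p.nanos == 500000000);
  absl::Duration back = absl::Seconds(p.seconds) + absl::Nanoseconds(p.nanos);
  assert(back == absl::Milliseconds(1500));
  return 0;
}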
@ -579,8 +711,7 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
                       }),
        algorithms.end());
    }
    dnn::ProfileResult best_result;
    dnn::ProfileResult best_result_no_scratch;
    std::vector<tensorflow::AutotuneResult> results;
    for (auto profile_algorithm : algorithms) {
      // TODO(zhengxq): profile each algorithm multiple times to get better
      // accuracy.
@ -597,28 +728,24 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
              .ok();
      if (cudnn_launch_status) {
        if (profile_result.is_valid()) {
          if (profile_result.elapsed_time_in_ms() <
              best_result.elapsed_time_in_ms()) {
            best_result = profile_result;
          }
          if (scratch_allocator.TotalByteSize() == 0 &&
              profile_result.elapsed_time_in_ms() <
                  best_result_no_scratch.elapsed_time_in_ms()) {
            best_result_no_scratch = profile_result;
          results.emplace_back();
          auto& result = results.back();
          result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
          result.mutable_conv()->set_tensor_ops_enabled(
              profile_algorithm.tensor_ops_enabled());
          result.mutable_success()->set_scratch_bytes(
              scratch_allocator.TotalByteSize());
          *result.mutable_success()->mutable_run_time() =
              internal::ToDurationProto(
                  absl::Milliseconds(profile_result.elapsed_time_in_ms()));
        }
      }
    }
    OP_REQUIRES(ctx,
                best_result.is_valid() || best_result_no_scratch.is_valid(),
                errors::NotFound("No algorithm worked!"));
    if (best_result.is_valid()) {
      algorithm_config.set_algorithm(best_result.algorithm());
    }
    if (best_result_no_scratch.is_valid()) {
      algorithm_config.set_algorithm_no_scratch(
          best_result_no_scratch.algorithm());
    }
    internal::LogFusedConvAutotuneResults(ctx->op_kernel().def(), *conv_input,
                                          *filter, *output, bias, side_input,
                                          stream->parent(), results);
    OP_REQUIRES_OK(
        ctx, internal::BestCudnnConvAlgorithm(results, &algorithm_config));
    AutoTuneConvBiasActivation::GetInstance()->Insert(fused_conv_parameters,
                                                      algorithm_config);
  }
@ -2177,6 +2177,18 @@ tf_proto_library(
    ],
)

tf_proto_library(
    name = "conv_autotuning_proto",
    srcs = ["protobuf/conv_autotuning.proto"],
    cc_api_version = 2,
    default_header = True,
    protodeps = tf_additional_all_protos(),
    provide_cc_alias = True,
    visibility = [
        "//tensorflow:internal",
    ],
)

tf_proto_library_cc(
    name = "worker_proto",
    srcs = ["protobuf/worker.proto"],
@ -268,7 +268,7 @@ tf_kernel_library(
    deps = [
        ":eigen_helpers",
        ":fill_functor",
        ":gpu_util_hdrs",
        ":gpu_utils",
        ":image_resizer_state",
        ":ops_util",
        "//tensorflow/core:core_cpu",
@ -462,11 +462,31 @@ cc_library(
    hdrs = ["conv_ops_gpu.h"],
)

# We keep this target only because some contrib/ targets depend on it. The
# reason why the contrib/ targets can't depend on gpu_utils is that some of
# them are tf_custom_op_library targets, and tf_custom_op_library forbids the
# dependency on tensorflow/core:lib, which gpu_utils certainly depends on.
cc_library(
    name = "gpu_util_hdrs",
    hdrs = ["gpu_utils.h"],
)

tf_cuda_library(
    name = "gpu_utils",
    srcs = ["gpu_utils.cc"],
    hdrs = ["gpu_utils.h"],
    deps = [
        ":gpu_util_hdrs",
        "//tensorflow/core:autotuning_proto_cc",
        "//tensorflow/core:conv_autotuning_proto_cc",
        "//tensorflow/core:lib",
        "//tensorflow/core:logger",
        "//tensorflow/core:stream_executor",
        "//tensorflow/core/util/proto:proto_utils",
        "@com_google_absl//absl/types:span",
    ],
)

tf_cc_test(
    name = "ops_util_test",
    size = "small",
@ -1253,7 +1273,7 @@ tf_kernel_library(
    visibility = ["//visibility:public"],
    deps = [
        ":bounds_check_lib",
        ":gpu_util_hdrs",
        ":gpu_utils",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
@ -3339,7 +3359,7 @@ tf_kernel_library(
    }),
    deps = MATH_DEPS + [
        ":eigen_contraction_kernel",
        ":gpu_util_hdrs",
        ":gpu_utils",
    ] + select({
        ":xsmm": ["@libxsmm_archive//:xsmm_avx"],
        "//conditions:default": [],
@ -3761,6 +3781,7 @@ tf_kernel_library(
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core/util/proto:proto_utils",
    ] + select({
        ":xsmm_convolutions": [
            "@libxsmm_archive//:xsmm_avx",
@ -51,6 +51,8 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#endif  // GOOGLE_CUDA

namespace {
@ -841,8 +843,7 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
    CHECK(stream->parent()->GetConvolveBackwardFilterAlgorithms(
        conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),
        &algorithms));
    ProfileResult best_result;
    ProfileResult best_result_no_scratch;
    std::vector<tensorflow::AutotuneResult> results;
    for (auto profile_algorithm : algorithms) {
      // TODO(zhengxq): profile each algorithm multiple times to get better
      // accuracy.
@ -859,28 +860,23 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
              .ok();
      if (cudnn_launch_status) {
        if (profile_result.is_valid()) {
          if (profile_result.elapsed_time_in_ms() <
              best_result.elapsed_time_in_ms()) {
            best_result = profile_result;
          }
          if (scratch_allocator.TotalByteSize() == 0 &&
              profile_result.elapsed_time_in_ms() <
                  best_result_no_scratch.elapsed_time_in_ms()) {
            best_result_no_scratch = profile_result;
          results.emplace_back();
          auto& result = results.back();
          result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
          result.mutable_conv()->set_tensor_ops_enabled(
              profile_algorithm.tensor_ops_enabled());
          result.mutable_success()->set_scratch_bytes(
              scratch_allocator.TotalByteSize());
          *result.mutable_success()->mutable_run_time() =
              proto_utils::ToDurationProto(
                  absl::Milliseconds(profile_result.elapsed_time_in_ms()));
        }
      }
    }
    OP_REQUIRES(ctx,
                best_result.is_valid() || best_result_no_scratch.is_valid(),
                errors::NotFound("No algorithm worked!"));
    if (best_result.is_valid()) {
      algorithm_config.set_algorithm(best_result.algorithm());
    }
    if (best_result_no_scratch.is_valid()) {
      algorithm_config.set_algorithm_no_scratch(
          best_result_no_scratch.algorithm());
    }
    LogConvAutotuneResults(ctx->op_kernel().def(), transformed_input,
                           pre_transformed_filter_backprop,
                           transformed_out_backprop, stream->parent(), results);
    OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
    AutoTuneConvBwdFilter::GetInstance()->Insert(conv_parameters,
                                                 algorithm_config);
  }
@ -51,6 +51,8 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#endif  // GOOGLE_CUDA

namespace {
@ -953,8 +955,7 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
    CHECK(stream->parent()->GetConvolveBackwardDataAlgorithms(
        conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),
        &algorithms));
    ProfileResult best_result;
    ProfileResult best_result_no_scratch;
    std::vector<tensorflow::AutotuneResult> results;
    for (auto profile_algorithm : algorithms) {
      // TODO(zhengxq): profile each algorithm multiple times to get better
      // accuracy.
@ -970,28 +971,23 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
              .ok();
      if (cudnn_launch_status) {
        if (profile_result.is_valid()) {
          if (profile_result.elapsed_time_in_ms() <
              best_result.elapsed_time_in_ms()) {
            best_result = profile_result;
          }
          if (scratch_allocator.TotalByteSize() == 0 &&
              profile_result.elapsed_time_in_ms() <
                  best_result_no_scratch.elapsed_time_in_ms()) {
            best_result_no_scratch = profile_result;
          results.emplace_back();
          auto& result = results.back();
          result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
          result.mutable_conv()->set_tensor_ops_enabled(
              profile_algorithm.tensor_ops_enabled());
          result.mutable_success()->set_scratch_bytes(
              scratch_allocator.TotalByteSize());
          *result.mutable_success()->mutable_run_time() =
              proto_utils::ToDurationProto(
                  absl::Milliseconds(profile_result.elapsed_time_in_ms()));
        }
      }
    }
    OP_REQUIRES(ctx,
                best_result.is_valid() || best_result_no_scratch.is_valid(),
                errors::NotFound("No algorithm worked!"));
    if (best_result.is_valid()) {
      algorithm_config.set_algorithm(best_result.algorithm());
    }
    if (best_result_no_scratch.is_valid()) {
      algorithm_config.set_algorithm_no_scratch(
          best_result_no_scratch.algorithm());
    }
    LogConvAutotuneResults(ctx->op_kernel().def(), pre_transformed_in_backprop,
                           transformed_filter, transformed_out_backprop,
                           stream->parent(), results);
    OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
    AutoTuneConvBwdData::GetInstance()->Insert(conv_parameters,
                                               algorithm_config);
  }
@ -55,6 +55,8 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {
@ -855,8 +857,7 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
        errors::Unknown("Failed to get convolution algorithm. This is probably "
                        "because cuDNN failed to initialize, so try looking to "
                        "see if a warning log message was printed above."));
    ProfileResult best_result;
    ProfileResult best_result_no_scratch;
    std::vector<tensorflow::AutotuneResult> results;
    for (auto profile_algorithm : algorithms) {
      // TODO(zhengxq): profile each algorithm multiple times to get better
      // accuracy.
@ -871,30 +872,22 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
              .ok();
      if (cudnn_launch_status) {
        if (profile_result.is_valid()) {
          if (profile_result.elapsed_time_in_ms() <
              best_result.elapsed_time_in_ms()) {
            best_result = profile_result;
          }
          if (scratch_allocator.TotalByteSize() == 0 &&
              profile_result.elapsed_time_in_ms() <
                  best_result_no_scratch.elapsed_time_in_ms()) {
            best_result_no_scratch = profile_result;
          results.emplace_back();
          auto& result = results.back();
          result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
          result.mutable_conv()->set_tensor_ops_enabled(
              profile_algorithm.tensor_ops_enabled());
          result.mutable_success()->set_scratch_bytes(
              scratch_allocator.TotalByteSize());
          *result.mutable_success()->mutable_run_time() =
              proto_utils::ToDurationProto(
                  absl::Milliseconds(profile_result.elapsed_time_in_ms()));
        }
      }
    }
    // TODO(yangzihao): refactor the profile result checking code into a common
    // utility function.
    OP_REQUIRES(ctx,
                best_result.is_valid() || best_result_no_scratch.is_valid(),
                errors::NotFound("No algorithm worked!"));
    if (best_result.is_valid()) {
      algorithm_config.set_algorithm(best_result.algorithm());
    }
    if (best_result_no_scratch.is_valid()) {
      algorithm_config.set_algorithm_no_scratch(
          best_result_no_scratch.algorithm());
    }
    LogConvAutotuneResults(ctx->op_kernel().def(), input, transformed_filter,
                           transformed_output, stream->parent(), results);
    OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
    AutoTuneConv::GetInstance()->Insert(conv_parameters, algorithm_config);
  }

@ -34,6 +34,8 @@ limitations under the License.

#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/util/proto/proto_utils.h"
using stream_executor::dnn::DimIndex;
#endif

@ -445,8 +447,7 @@ struct LaunchConvOp<GPUDevice, T> {
                        "because cuDNN failed to initialize, so try looking to "
                        "see if a warning log message was printed above."));

    ProfileResult best_result;
    ProfileResult best_result_no_scratch;
    std::vector<tensorflow::AutotuneResult> results;
    for (auto profile_algorithm : algorithms) {
      // TODO(zhengxq): profile each algorithm multiple times to get better
      // accuracy.
@ -461,28 +462,22 @@ struct LaunchConvOp<GPUDevice, T> {
              .ok();
      if (cudnn_launch_status) {
        if (profile_result.is_valid()) {
          if (profile_result.elapsed_time_in_ms() <
              best_result.elapsed_time_in_ms()) {
            best_result = profile_result;
          }
          if (scratch_allocator.TotalByteSize() == 0 &&
              profile_result.elapsed_time_in_ms() <
                  best_result_no_scratch.elapsed_time_in_ms()) {
            best_result_no_scratch = profile_result;
          results.emplace_back();
          auto& result = results.back();
          result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
          result.mutable_conv()->set_tensor_ops_enabled(
              profile_algorithm.tensor_ops_enabled());
          result.mutable_success()->set_scratch_bytes(
              scratch_allocator.TotalByteSize());
          *result.mutable_success()->mutable_run_time() =
              proto_utils::ToDurationProto(
                  absl::Milliseconds(profile_result.elapsed_time_in_ms()));
        }
      }
    }
    OP_REQUIRES(ctx,
                best_result.is_valid() || best_result_no_scratch.is_valid(),
                errors::NotFound("No algorithm worked!"));
    if (best_result.is_valid()) {
      algorithm_config.set_algorithm(best_result.algorithm());
    }
    if (best_result_no_scratch.is_valid()) {
      algorithm_config.set_algorithm_no_scratch(
          best_result_no_scratch.algorithm());
    }
    LogConvAutotuneResults(ctx->op_kernel().def(), input, filter, *output,
                           stream->parent(), results);
    OP_REQUIRES_OK(ctx, BestCudnnConvAlgorithm(results, &algorithm_config));
    AutoTuneConv3d::GetInstance()->Insert(conv_parameters, algorithm_config);
  }

@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This include can't be in the conv_ops_fused_impl.h headers. See b/62899350.
#if GOOGLE_CUDA
#include "tensorflow/core/protobuf/autotuning.pb.h"
#endif  // GOOGLE_CUDA
#include "tensorflow/core/kernels/conv_ops_fused_impl.h"

namespace tensorflow {
@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This include can't be in the conv_ops_fused_impl.h headers. See b/62899350.
#if GOOGLE_CUDA
#include "tensorflow/core/protobuf/autotuning.pb.h"
#endif  // GOOGLE_CUDA
#include "tensorflow/core/kernels/conv_ops_fused_impl.h"

namespace tensorflow {
@ -13,6 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This include can't be in the conv_ops_fused_impl.h headers. See b/62899350.
#if GOOGLE_CUDA
#include "tensorflow/core/protobuf/autotuning.pb.h"
#endif  // GOOGLE_CUDA
#include "tensorflow/core/kernels/conv_ops_fused_impl.h"

namespace tensorflow {
@ -59,10 +59,13 @@ limitations under the License.
#include "cuda/include/cudnn.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

class AutotuneResult;

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

@ -497,10 +500,11 @@ inline int64 ConvolveScratchSize() {
// algorithms and measuring execution time.
// TODO(ezhulenev): Move it to conv_ops_gpu.h and share with conv_ops.cc.
template <typename T, typename ConvLaunch>
Status FindBestConvolveAlgorithm(const FusedConvParameters& params,
                                 const ConvLaunch launch,
Status FindBestConvolveAlgorithm(
    const FusedConvParameters& params, const ConvLaunch launch,
    OpKernelContext* context, se::Stream* stream,
    se::dnn::AlgorithmConfig* algorithm_config) {
    se::dnn::AlgorithmConfig* algorithm_config,
    std::vector<tensorflow::AutotuneResult>* results) {
  // Check if we already have an algorithm selected for the given parameters.
  if (AutoTuneFusedConv::GetInstance()->Find(params, algorithm_config)) {
    return Status::OK();
@ -517,9 +521,6 @@ Status FindBestConvolveAlgorithm(const FusedConvParameters& params,
        "see if a warning log message was printed above.");
  }

  se::dnn::ProfileResult best_result;
  se::dnn::ProfileResult best_result_no_scratch;

  for (auto profile_algorithm : algorithms) {
    DnnScratchAllocator scratch_allocator(ConvolveScratchSize(), context);
    se::dnn::ProfileResult profile_result;
@ -529,29 +530,19 @@ Status FindBestConvolveAlgorithm(const FusedConvParameters& params,
                                &profile_result);

    if (cudnn_launch_status && profile_result.is_valid()) {
      if (profile_result.elapsed_time_in_ms() <
          best_result.elapsed_time_in_ms()) {
        best_result = profile_result;
      }
      if (scratch_allocator.TotalByteSize() == 0 &&
          profile_result.elapsed_time_in_ms() <
              best_result_no_scratch.elapsed_time_in_ms()) {
        best_result_no_scratch = profile_result;
      results->emplace_back();
      auto& result = results->back();
      result.mutable_conv()->set_algorithm(profile_algorithm.algo_id());
      result.mutable_conv()->set_tensor_ops_enabled(
          profile_algorithm.tensor_ops_enabled());
      result.mutable_success()->set_scratch_bytes(
          scratch_allocator.TotalByteSize());
      *result.mutable_success()->mutable_run_time() =
          proto_utils::ToDurationProto(
              absl::Milliseconds(profile_result.elapsed_time_in_ms()));
    }
  }

  if (!best_result.is_valid() && !best_result_no_scratch.is_valid()) {
    return errors::NotFound("No algorithm worked!");
  }
  if (best_result.is_valid()) {
    algorithm_config->set_algorithm(best_result.algorithm());
  }
  if (best_result_no_scratch.is_valid()) {
    algorithm_config->set_algorithm_no_scratch(
        best_result_no_scratch.algorithm());
  }

  TF_RETURN_IF_ERROR(BestCudnnConvAlgorithm(*results, algorithm_config));
  AutoTuneFusedConv::GetInstance()->Insert(params, *algorithm_config);
  return Status::OK();
}
@ -798,9 +789,14 @@ struct LaunchFusedConv2DOp<GPUDevice, T> {

  se::dnn::AlgorithmConfig algorithm_config;
  if (cudnn_use_autotune) {
    OP_REQUIRES_OK(context, FindBestConvolveAlgorithm<T>(
                                conv_parameters, launch, context, stream,
                                &algorithm_config));
    std::vector<tensorflow::AutotuneResult> results;
    auto status =
        FindBestConvolveAlgorithm<T>(conv_parameters, launch, context, stream,
                                     &algorithm_config, &results);
    LogFusedConvAutotuneResults(context->op_kernel().def(), input,
                                transformed_filter, transformed_output, bias,
                                nullptr, stream->parent(), results);
    OP_REQUIRES_OK(context, status);
  }

  DnnScratchAllocator scratch_allocator(ConvolveScratchSize(), context);

153 tensorflow/core/kernels/gpu_utils.cc Normal file
@ -0,0 +1,153 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/gpu_utils.h"

#if GOOGLE_CUDA

#include "google/protobuf/any.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/platform/logger.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/protobuf/conv_autotuning.pb.h"
#include "tensorflow/core/util/proto/proto_utils.h"

namespace tensorflow {
namespace {

tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) {
  tensorflow::CudnnVersion cudnn_version;
  if (auto* dnn = stream_executor->AsDnn()) {
    se::port::StatusOr<se::dnn::VersionInfo> version_or = dnn->GetVersion();
    if (version_or.ok()) {
      const auto& version = version_or.ValueOrDie();
      cudnn_version.set_major(version.major_version());
      cudnn_version.set_minor(version.minor_version());
      cudnn_version.set_patch(version.patch());
    }
  }
  return cudnn_version;
}

tensorflow::ComputeCapability GetComputeCapability(
    se::StreamExecutor* stream_executor) {
  tensorflow::ComputeCapability cc;
  int cc_major, cc_minor;
  stream_executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
                                                                  &cc_minor);
  cc.set_major(cc_major);
  cc.set_minor(cc_minor);
  return cc;
}

}  // namespace

void LogConvAutotuneResults(const NodeDef& node, const Tensor& input,
                            const Tensor& filter, const Tensor& output,
                            se::StreamExecutor* stream_exec,
                            absl::Span<const AutotuneResult> results) {
  AutotuningLog log;
  ConvNodeDef instr;
  *instr.mutable_conv() = node;
  input.shape().AsProto(instr.mutable_input()->mutable_tensor_shape());
  instr.mutable_input()->set_dtype(input.dtype());
  filter.shape().AsProto(instr.mutable_filter()->mutable_tensor_shape());
  instr.mutable_filter()->set_dtype(filter.dtype());
  output.shape().AsProto(instr.mutable_output()->mutable_tensor_shape());
  instr.mutable_output()->set_dtype(output.dtype());
  log.mutable_instr()->PackFrom(std::move(instr));
  *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec);
  *log.mutable_compute_capability() = GetComputeCapability(stream_exec);
  for (const auto& result : results) {
    *log.add_results() = result;
  }
  Logger::Singleton()->LogProto(log);
}

void LogFusedConvAutotuneResults(const NodeDef& node, const Tensor& input,
                                 const Tensor& filter, const Tensor& output,
                                 const Tensor& bias, const Tensor* side_input,
                                 se::StreamExecutor* stream_exec,
                                 absl::Span<const AutotuneResult> results) {
  AutotuningLog log;
  ConvNodeDef instr;
  *instr.mutable_conv() = node;
  input.shape().AsProto(instr.mutable_input()->mutable_tensor_shape());
  instr.mutable_input()->set_dtype(input.dtype());
  filter.shape().AsProto(instr.mutable_filter()->mutable_tensor_shape());
  instr.mutable_filter()->set_dtype(filter.dtype());
  output.shape().AsProto(instr.mutable_output()->mutable_tensor_shape());
  instr.mutable_output()->set_dtype(output.dtype());
  bias.shape().AsProto(instr.mutable_bias()->mutable_tensor_shape());
  instr.mutable_bias()->set_dtype(bias.dtype());
  if (side_input) {
    side_input->shape().AsProto(
        instr.mutable_side_input()->mutable_tensor_shape());
    instr.mutable_side_input()->set_dtype(side_input->dtype());
  }
  log.mutable_instr()->PackFrom(std::move(instr));
  *log.mutable_cudnn_version() = GetCudnnVersion(stream_exec);
  *log.mutable_compute_capability() = GetComputeCapability(stream_exec);
  for (const auto& result : results) {
    *log.add_results() = result;
  }
  Logger::Singleton()->LogProto(log);
}

Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
                              se::dnn::AlgorithmConfig* algo) {
  // The "!lhs.has_success()" / "!rhs.has_success()" terms below make
  // successful results order first: they need a smaller key per "min_element".
  const AutotuneResult* best_result = std::min_element(
      results.begin(), results.end(),
      [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
        return std::make_tuple(
                   !lhs.has_success(),
                   proto_utils::FromDurationProto(lhs.success().run_time())) <
               std::make_tuple(
                   !rhs.has_success(),
                   proto_utils::FromDurationProto(rhs.success().run_time()));
      });

  const AutotuneResult* best_result_no_scratch = std::min_element(
      results.begin(), results.end(),
      [](const AutotuneResult& lhs, const AutotuneResult& rhs) {
        return std::make_tuple(
                   !lhs.has_success(), lhs.success().scratch_bytes(),
                   proto_utils::FromDurationProto(lhs.success().run_time())) <
               std::make_tuple(
                   !rhs.has_success(), rhs.success().scratch_bytes(),
                   proto_utils::FromDurationProto(rhs.success().run_time()));
      });

  if (best_result == results.end() || !best_result->has_success()) {
    return errors::NotFound("No algorithm worked!");
  }
  algo->set_algorithm({best_result->conv().algorithm(),
                       best_result->conv().tensor_ops_enabled()});
  if (best_result_no_scratch != results.end() &&
      best_result_no_scratch->has_success() &&
      best_result_no_scratch->success().scratch_bytes() == 0) {
    algo->set_algorithm_no_scratch(
        {best_result_no_scratch->conv().algorithm(),
         best_result_no_scratch->conv().tensor_ops_enabled()});
  }
  return Status::OK();
}

}  // namespace tensorflow

#endif  // GOOGLE_CUDA
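The two std::min_element comparators in BestCudnnConvAlgorithm lean on std::tuple's lexicographic ordering: !has_success() is false (0) for successful results, so every success compares smaller than every failure, and run time only breaks ties within the same success class. A self-contained illustration of the trick; the Result struct is a stand-in for AutotuneResult:

#include <algorithm>
#include <cassert>
#include <tuple>
#include <vector>

// Stand-in for AutotuneResult: has_success == false means the algorithm
// failed, and its run_time is meaningless.
struct Result {
  bool has_success;
  double run_time_ms;
};

int main() {
  std::vector<Result> results = {
      {false, 0.0},  // failed run
      {true, 9.0},   // success, slower
      {true, 4.0},   // success, fastest
  };
  // Failures get key (true, ...) and successes (false, ...); tuples compare
  // lexicographically, so every success orders before every failure.
  auto best = std::min_element(
      results.begin(), results.end(), [](const Result& l, const Result& r) {
        return std::make_tuple(!l.has_success, l.run_time_ms) <
               std::make_tuple(!r.has_success, r.run_time_ms);
      });
  assert(best->has_success && best->run_time_ms == 4.0);
  return 0;
}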
@ -20,6 +20,9 @@ limitations under the License.

#include <unordered_map>

#include "absl/types/span.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
@ -28,6 +31,9 @@ limitations under the License.

namespace tensorflow {

class NodeDef;
class AutotuneResult;

template <typename T>
inline se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory, uint64 size) {
  se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory), size * sizeof(T));
@ -156,6 +162,25 @@ class AutoTuneSingleton {
  }
};

// Logs convolution results to customized back-storage.
void LogConvAutotuneResults(const NodeDef& node, const Tensor& input,
                            const Tensor& filter, const Tensor& output,
                            se::StreamExecutor* stream_exec,
                            absl::Span<const AutotuneResult> results);

// Logs fused convolution results to customized back-storage.
void LogFusedConvAutotuneResults(const NodeDef& node, const Tensor& input,
                                 const Tensor& filter, const Tensor& output,
                                 const Tensor& bias, const Tensor* side_input,
                                 se::StreamExecutor* stream_exec,
                                 absl::Span<const AutotuneResult> results);

// Returns the best algorithms for the config: one is the fastest overall, the
// other is the fastest with zero scratch space. Unsuccessful autotuning
// results are allowed and ignored.
Status BestCudnnConvAlgorithm(absl::Span<const AutotuneResult> results,
                              se::dnn::AlgorithmConfig* algo);

}  // namespace tensorflow

#endif  // GOOGLE_CUDA
@ -548,6 +548,7 @@ def tf_additional_all_protos():
def tf_protos_all_impl():
    return [
        "//tensorflow/core:autotuning_proto_cc_impl",
        "//tensorflow/core:conv_autotuning_proto_cc_impl",
        "//tensorflow/core:protos_all_cc_impl",
    ]

19 tensorflow/core/protobuf/conv_autotuning.proto Normal file
@ -0,0 +1,19 @@
|
||||
// This is used for convolution logging. Also see
|
||||
// tensorflow/core/protobuf/autotuing.h
|
||||
syntax = "proto3";
|
||||
|
||||
package tensorflow;
|
||||
|
||||
import "tensorflow/core/framework/node_def.proto";
|
||||
import "tensorflow/core/framework/tensor.proto";
|
||||
|
||||
message ConvNodeDef {
|
||||
NodeDef conv = 1;
|
||||
TensorProto input = 2;
|
||||
TensorProto filter = 3;
|
||||
TensorProto output = 4;
|
||||
TensorProto bias = 5;
|
||||
oneof side_input_oneof {
|
||||
TensorProto side_input = 6;
|
||||
}
|
||||
}
|
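The ConvNodeDef above is not logged on its own: gpu_utils.cc packs it into the AutotuningLog.instr field, which is a google.protobuf.Any (the diff calls PackFrom on it). A hedged sketch of the consumer side, assuming a reader with the generated headers; Is() and UnpackTo() are the standard Any API, and GetConvInstr is a hypothetical helper name:

#include "google/protobuf/any.pb.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/protobuf/conv_autotuning.pb.h"

// Sketch: recover the ConvNodeDef from a logged AutotuningLog. The log's
// `instr` field is an Any, so the consumer must check the type and unpack.
bool GetConvInstr(const tensorflow::AutotuningLog& log,
                  tensorflow::ConvNodeDef* instr) {
  if (!log.instr().Is<tensorflow::ConvNodeDef>()) return false;
  return log.instr().UnpackTo(instr);
}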
@ -464,15 +464,6 @@ tf_proto_library(
    provide_cc_alias = True,
)

tf_proto_library(
    name = "logging_proto",
    srcs = ["logging.proto"],
    cc_api_version = 2,
    protodeps = [":dnn_proto"],
    provide_cc_alias = True,
    visibility = [":friends"],
)

cc_library(
    name = "dnn",
    srcs = ["dnn.cc"],

@ -285,10 +285,8 @@ cc_library(
        "@local_config_cuda//cuda:cuda_headers",
        "//tensorflow/core:lib",
        "//tensorflow/core:lib_internal",
        "//tensorflow/core:logger",
        "//tensorflow/stream_executor:dnn",
        "//tensorflow/stream_executor:event",
        "//tensorflow/stream_executor:logging_proto_cc",
        "//tensorflow/stream_executor:plugin_registry",
        "//tensorflow/stream_executor:scratch_allocator",
        "//tensorflow/stream_executor:stream_executor_pimpl_header",
@ -23,7 +23,6 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logger.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
@ -39,7 +38,6 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/mathutil.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/logging.pb.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/plugin_registry.h"
#include "tensorflow/stream_executor/scratch_allocator.h"
@ -2621,63 +2619,6 @@ bool ShouldIncludeWinogradNonfusedAlgo(
}
#endif

dnn::ConvolutionProto GenerateConvProto(
    dnn::ConvolutionKind kind, dnn::DataType element_type,
    const dnn::BatchDescriptor& input_descriptor,
    const dnn::FilterDescriptor& filter_descriptor,
    const dnn::BatchDescriptor& output_descriptor, dnn::AlgorithmDesc algorithm,
    const dnn::ConvolutionDescriptor& convolution_descriptor, double conv_scale,
    double side_value_scale, dnn::DataType acc_type,
    dnn::ActivationMode activation) {
  dnn::ConvolutionProto conv_config;
  conv_config.set_kind(kind);
  *conv_config.mutable_input() = input_descriptor.ToProto(element_type);
  *conv_config.mutable_filter() = filter_descriptor.ToProto(element_type);
  *conv_config.mutable_output() = output_descriptor.ToProto(element_type);
  *conv_config.mutable_algorithm() = algorithm.ToProto();
  *conv_config.mutable_conv_desc() = convolution_descriptor.ToProto();
  conv_config.mutable_conv_desc()->set_compute_mode(acc_type);
  conv_config.set_conv_scale(conv_scale);
  conv_config.set_side_value_scale(side_value_scale);
  conv_config.set_activation(activation);
  return conv_config;
}

void LogCudaProto(const dnn::ConvolutionProto& conv, float profile_time_ms,
                  StreamExecutor* stream_executor) {
  {
    // For the rollout, temporarily cap the number of logs per process.
    // TODO(timshen): remove it.
    static int count_down = 200;
    if (count_down == 0) {
      return;
    }
    count_down--;
  }

  ConvLogEntry conv_log;
  *conv_log.mutable_convolution() = conv;
  conv_log.set_profile_time_ms(profile_time_ms);

  auto info = conv_log.mutable_cuda_info();
  int cc_major, cc_minor;
  stream_executor->GetDeviceDescription().cuda_compute_capability(&cc_major,
                                                                  &cc_minor);
  info->mutable_compute_capability()->set_major(cc_major);
  info->mutable_compute_capability()->set_minor(cc_minor);

  if (auto* dnn = stream_executor->AsDnn()) {
    port::StatusOr<dnn::VersionInfo> version_or = dnn->GetVersion();
    if (version_or.ok()) {
      const auto& version = version_or.ValueOrDie();
      info->mutable_cudnn_version()->set_major(version.major_version());
      info->mutable_cudnn_version()->set_minor(version.minor_version());
      info->mutable_cudnn_version()->set_patch(version.patch());
    }
  }
  tensorflow::Logger::Singleton()->LogProto(conv_log);
}

}  // namespace

port::Status CudnnSupport::DoPrepareForConvolution(
@ -2971,13 +2912,6 @@ port::Status CudnnSupport::DoConvolve(
    output_profile_result->set_elapsed_time_in_ms(
        timer->GetElapsedMilliseconds());
    output_profile_result->set_scratch_size(scratch_memory.size());

    LogCudaProto(
        GenerateConvProto(kind, element_type, input_descriptor,
                          filter_descriptor, output_descriptor, algorithm_desc,
                          convolution_descriptor, dalpha, dbeta,
                          accumulator_type, dnn::ActivationMode::kNone),
        output_profile_result->elapsed_time_in_ms(), stream->parent());
  }

  return port::Status::OK();
@ -3095,14 +3029,6 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
    output_profile_result->set_elapsed_time_in_ms(
        timer->GetElapsedMilliseconds());
    output_profile_result->set_scratch_size(scratch.size());

    LogCudaProto(
        GenerateConvProto(
            dnn::ConvolutionKind::FORWARD, dnn::ToDataType<ElementType>::value,
            conv_input_descriptor, filter_descriptor, output_descriptor,
            algo_desc, convolution_descriptor, conv_input_scale,
            side_input_scale, accumulator_type, activation_mode),
        output_profile_result->elapsed_time_in_ms(), stream->parent());
  }

  return port::Status::OK();
@ -108,22 +108,3 @@ message ConvolutionDescriptorProto {
  int32 group_count = 5;
  ConvolutionMode convolution_mode = 6;
}

// A convolution. Currently it's only used for logging. In the future, we may
// want to use it in the API as well.
message ConvolutionProto {
  ConvolutionKind kind = 1;
  TensorDescriptorProto input = 2;
  TensorDescriptorProto filter = 3;
  TensorDescriptorProto output = 4;
  AlgorithmProto algorithm = 5;
  ConvolutionDescriptorProto conv_desc = 6;

  // result = conv_scale * conv(...) + side_value_scale * side_value.
  // side_value is an arbitrary buffer if activation is not none. Otherwise, it
  // has to be the result buffer (using its old values).
  double conv_scale = 7;
  double side_value_scale = 8;

  ActivationMode activation = 9;
}
@ -1,41 +0,0 @@
syntax = "proto3";

package stream_executor;

import "tensorflow/stream_executor/dnn.proto";

message CudnnVersion {
  int32 major = 1;
  int32 minor = 2;
  int32 patch = 3;
};

message ComputeCapability {
  int32 major = 1;
  int32 minor = 2;
}

// NOTE: this proto is temporarily duplicated in other places, outside of
// stream_executor. The plan is to move all custom logging (tensorflow::Logger
// related) behavior out of StreamExecutor. There are two reasons:
// * Technical: stream_executor is part of libtensorflow_framework.so. It's
//   extremely hard to have a single definition of the protos in the .so, and
//   let the callers call into those definitions. The complication lives in
//   cc_proto_library, where we have a header-only version and an impl version.
// * Functional: we want to log autotuning stats from the callers. The
//   autotuning stats are not available in SE.
//
// TODO(timshen): remove this proto once both XLA and TF log autotuning
// results.
message CudaInfo {
  CudnnVersion cudnn_version = 1;
  ComputeCapability compute_capability = 2;
}

message ConvLogEntry {
  CudaInfo cuda_info = 1;
  dnn.ConvolutionProto convolution = 2;

  // Profiled time in ms. 0.0 if the convolution is not profiled.
  float profile_time_ms = 3;
}