From fc33a70c90e055a9fc5a81c1885f492688556ce0 Mon Sep 17 00:00:00 2001 From: Tim Shen Date: Fri, 1 Mar 2019 17:14:14 -0800 Subject: [PATCH] Automated rollback of commit b562be27f705034d46e9920ebbcdd293485a7305 PiperOrigin-RevId: 236405963 --- tensorflow/compiler/xla/protobuf_util.h | 15 +++++ tensorflow/compiler/xla/service/gpu/BUILD | 14 ++--- .../xla/service/gpu}/autotuning.proto | 24 +++++--- .../gpu/cudnn_conv_algorithm_picker.cc | 60 +++++++++---------- .../service/gpu/cudnn_conv_algorithm_picker.h | 4 +- .../xla/service/gpu/gpu_autotuning.proto | 13 ---- tensorflow/core/BUILD | 1 - tensorflow/core/util/proto/BUILD | 2 - tensorflow/core/util/proto/proto_utils.h | 16 ----- 9 files changed, 66 insertions(+), 83 deletions(-) rename tensorflow/{core/protobuf => compiler/xla/service/gpu}/autotuning.proto (80%) delete mode 100644 tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h index e20a7e95a63..4a88a48f285 100644 --- a/tensorflow/compiler/xla/protobuf_util.h +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_PROTOBUF_UTIL_H_ +#include "google/protobuf/duration.pb.h" #include "absl/time/time.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -44,6 +45,20 @@ Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, // dirpath along as-is. void RegisterDirectoryExpander(const std::function& expander); +// Converts an absl::Duration to a google::protobuf::Duration. +inline google::protobuf::Duration ToDurationProto(absl::Duration duration) { + google::protobuf::Duration proto; + proto.set_seconds(absl::IDivDuration(duration, absl::Seconds(1), &duration)); + proto.set_nanos( + absl::IDivDuration(duration, absl::Nanoseconds(1), &duration)); + return proto; +} + +// Converts a google::protobuf::Duration to an absl::Duration. +inline absl::Duration FromDurationProto(google::protobuf::Duration proto) { + return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos()); +} + } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 3e4aefa55d3..3bc0daf9e70 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -440,14 +440,15 @@ cc_library( srcs = ["cudnn_conv_algorithm_picker.cc"], hdrs = ["cudnn_conv_algorithm_picker.h"], deps = [ + ":autotuning_proto", ":backend_configs", ":buffer_comparator", ":cudnn_conv_runner", - ":gpu_autotuning_proto", ":gpu_executable", ":ir_emission_utils", ":scratch_allocator", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:device_memory_allocator", "//tensorflow/compiler/xla/service:hlo", @@ -455,9 +456,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/core:lib", "//tensorflow/core:logger", - "//tensorflow/core:protos_all_cc", "//tensorflow/core:stream_executor_no_cuda", - "//tensorflow/core/util/proto:proto_utils", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/time", @@ -777,7 +776,6 @@ cc_library( hdrs = ["gpu_transfer_manager.h"], deps = [ ":gpu_compiler", - ":infeed_manager", ":outfeed_manager", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:literal_util", @@ -790,6 +788,7 @@ cc_library( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:generic_transfer_manager", "//tensorflow/compiler/xla/service:transfer_manager", + "//tensorflow/compiler/xla/service/gpu:infeed_manager", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@com_google_absl//absl/memory", @@ -1138,8 +1137,8 @@ tf_cc_test( srcs = ["cudnn_fused_conv_rewriter_test.cc"], tags = tf_cuda_tests_tags(), deps = [ - ":ir_emission_utils", "//tensorflow/compiler/xla/service:hlo_parser", + "//tensorflow/compiler/xla/service/gpu:ir_emission_utils", "//tensorflow/compiler/xla/service/gpu/tests:gpu_codegen_test", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/core:test", @@ -1184,11 +1183,10 @@ tf_cc_test( ) xla_proto_library( - name = "gpu_autotuning_proto", - srcs = ["gpu_autotuning.proto"], + name = "autotuning_proto", + srcs = ["autotuning.proto"], deps = [ "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo_proto", - "//tensorflow/core:protos_all_cc", ], ) diff --git a/tensorflow/core/protobuf/autotuning.proto b/tensorflow/compiler/xla/service/gpu/autotuning.proto similarity index 80% rename from tensorflow/core/protobuf/autotuning.proto rename to tensorflow/compiler/xla/service/gpu/autotuning.proto index 29e4d00a85f..b4a08963b4f 100644 --- a/tensorflow/core/protobuf/autotuning.proto +++ b/tensorflow/compiler/xla/service/gpu/autotuning.proto @@ -1,14 +1,15 @@ -// This file defines protos that store the results of autotuning various +// This file defines protos that store the results of autotuning XLA:GPU // operations. // // They are in proto format because we want to log them structured. They offer // tremendous statistical, testing, and debugging value. syntax = "proto3"; -package tensorflow; +package xla.gpu; -import "google/protobuf/any.proto"; import "google/protobuf/duration.proto"; +import "tensorflow/compiler/xla/xla_data.proto"; +import "tensorflow/compiler/xla/service/hlo.proto"; message CudnnVersion { int32 major = 1; @@ -62,12 +63,19 @@ message AutotuneResult { } } -message AutotuningLog { - google.protobuf.Any instr = 1; +message AutotuneLog { + message Instruction { + xla.HloInstructionProto instruction = 1; + repeated xla.ShapeProto operand_shapes = 2; + } + + oneof instr_oneof { + Instruction instr = 1; + } // Records all auto-tuning results per algorithm. - repeated AutotuneResult results = 2; + repeated AutotuneResult results = 3; - CudnnVersion cudnn_version = 3; - ComputeCapability compute_capability = 4; + CudnnVersion cudnn_version = 4; + ComputeCapability compute_capability = 5; } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc index 02eb191cf58..0c4980f6549 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc @@ -14,23 +14,21 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h" -#include "google/protobuf/any.pb.h" #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "absl/time/time.h" #include "absl/types/optional.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h" #include "tensorflow/compiler/xla/service/gpu/buffer_comparator.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" -#include "tensorflow/compiler/xla/service/gpu/gpu_autotuning.pb.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/gpu/scratch_allocator.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logger.h" #include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/util/proto/proto_utils.h" namespace xla { namespace gpu { @@ -39,7 +37,6 @@ namespace { using absl::optional; using se::DeviceMemoryBase; using se::dnn::AlgorithmDesc; -using tensorflow::AutotuneResult; std::vector GetAlgorithms(CudnnConvKind kind, se::StreamExecutor* stream_exec) { @@ -97,8 +94,8 @@ tensorflow::mutex_lock LockGpu(const se::StreamExecutor* stream_exec) { return tensorflow::mutex_lock{it->second}; } -tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) { - tensorflow::CudnnVersion cudnn_version; +xla::gpu::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) { + xla::gpu::CudnnVersion cudnn_version; if (auto* dnn = stream_executor->AsDnn()) { StatusOr version_or = dnn->GetVersion(); if (version_or.ok()) { @@ -111,9 +108,9 @@ tensorflow::CudnnVersion GetCudnnVersion(se::StreamExecutor* stream_executor) { return cudnn_version; } -tensorflow::ComputeCapability GetComputeCapability( +xla::gpu::ComputeCapability GetComputeCapability( se::StreamExecutor* stream_executor) { - tensorflow::ComputeCapability cc; + xla::gpu::ComputeCapability cc; int cc_major, cc_minor; stream_executor->GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor); @@ -246,23 +243,25 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithm( RunCudnnConv(instr, absl::MakeSpan(operand_buffers), result_buffer, &scratch_allocator, &stream, options); - if (!launch_status.ok()) { - continue; - } - - if (!profile_result.is_valid()) { - continue; - } - profile_results.emplace_back(); AutotuneResult& result = profile_results.back(); result.mutable_conv()->set_algorithm(alg.algo_id()); result.mutable_conv()->set_tensor_ops_enabled(alg.tensor_ops_enabled()); + if (!launch_status.ok()) { + result.set_error_string(launch_status.error_message()); + continue; + } + + if (!profile_result.is_valid()) { + result.set_error_string("Invalid profile result"); + continue; + } + int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); result.mutable_success()->set_scratch_bytes(scratch_bytes_used); *result.mutable_success()->mutable_run_time() = - tensorflow::proto_utils::ToDurationProto( + protobuf_util::ToDurationProto( absl::Milliseconds(profile_result.elapsed_time_in_ms())); const bool crash_on_checking_failure = @@ -309,14 +308,10 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithm( // Log the autotuning result. { - tensorflow::AutotuningLog log; - { - ConvInstructionLog instr_log; - *instr_log.mutable_instruction() = instr->ToProto(); - for (const auto* op : instr->operands()) { - *instr_log.add_operand_shapes() = op->shape().ToProto(); - } - log.mutable_instr()->PackFrom(instr_log); + AutotuneLog log; + *log.mutable_instr()->mutable_instruction() = instr->ToProto(); + for (const auto* op : instr->operands()) { + *log.mutable_instr()->add_operand_shapes() = op->shape().ToProto(); } for (const auto& profile : profile_results) { *log.add_results() = profile; @@ -335,14 +330,13 @@ StatusOr CudnnConvAlgorithmPicker::PickBestAlgorithm( // The successful one should have a smaller key, since we are doing // min_element. If they are both unsuccessful, keep the earlier one in // the vector by comparing pointers. - return std::make_tuple(!lhs.has_success(), - tensorflow::proto_utils::FromDurationProto( - lhs.success().run_time()), - &lhs) < - std::make_tuple(!rhs.has_success(), - tensorflow::proto_utils::FromDurationProto( - rhs.success().run_time()), - &rhs); + return std::make_tuple( + !lhs.has_success(), + protobuf_util::FromDurationProto(lhs.success().run_time()), + &lhs) < std::make_tuple(!rhs.has_success(), + protobuf_util::FromDurationProto( + rhs.success().run_time()), + &rhs); }); if (best_result != profile_results_end && best_result->has_success()) { diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h index 6ab9c7a9ece..2e34ba96723 100644 --- a/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h +++ b/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.h @@ -20,12 +20,12 @@ limitations under the License. #include "absl/types/optional.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/gpu/autotuning.pb.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_conv_runner.h" #include "tensorflow/compiler/xla/service/hlo_instructions.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" -#include "tensorflow/core/protobuf/autotuning.pb.h" namespace xla { namespace gpu { @@ -50,7 +50,7 @@ class CudnnConvAlgorithmPicker : public HloModulePass { private: StatusOr RunOnComputation(HloComputation* computation); StatusOr RunOnInstruction(HloInstruction* instr); - StatusOr PickBestAlgorithm( + StatusOr PickBestAlgorithm( const HloCustomCallInstruction* instr); se::StreamExecutor* stream_exec_; // never null diff --git a/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto b/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto deleted file mode 100644 index ec4f6e9c913..00000000000 --- a/tensorflow/compiler/xla/service/gpu/gpu_autotuning.proto +++ /dev/null @@ -1,13 +0,0 @@ -// This is used for convolution logging. Also see -// tensorflow/core/protobuf/autotuing.h -syntax = "proto3"; - -package xla.gpu; - -import "tensorflow/compiler/xla/service/hlo.proto"; -import "tensorflow/compiler/xla/xla_data.proto"; - -message ConvInstructionLog { - xla.HloInstructionProto instruction = 1; - repeated xla.ShapeProto operand_shapes = 2; -} diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 777cc3d8185..0e1ed0ec7b7 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -233,7 +233,6 @@ CORE_PROTO_SRCS = COMMON_PROTO_SRCS + ERROR_CODES_PROTO_SRCS ADDITIONAL_CORE_PROTO_SRCS = [ "example/example_parser_configuration.proto", "protobuf/trackable_object_graph.proto", - "protobuf/autotuning.proto", "protobuf/control_flow.proto", # TODO(ebrevdo): Re-enable once CriticalSection is in core. # "protobuf/critical_section.proto", diff --git a/tensorflow/core/util/proto/BUILD b/tensorflow/core/util/proto/BUILD index 890bd837025..b990f0a7491 100644 --- a/tensorflow/core/util/proto/BUILD +++ b/tensorflow/core/util/proto/BUILD @@ -70,8 +70,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:platform_base", "@com_google_absl//absl/strings", - "@com_google_absl//absl/time", - "@protobuf_archive//:protobuf_headers", ], ) diff --git a/tensorflow/core/util/proto/proto_utils.h b/tensorflow/core/util/proto/proto_utils.h index ba45f8a5b0e..9451e317a13 100644 --- a/tensorflow/core/util/proto/proto_utils.h +++ b/tensorflow/core/util/proto/proto_utils.h @@ -16,9 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_UTIL_PROTO_PROTO_UTILS_H_ #define TENSORFLOW_CORE_UTIL_PROTO_PROTO_UTILS_H_ -#include "google/protobuf/duration.pb.h" #include "absl/strings/string_view.h" -#include "absl/time/time.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/protobuf.h" @@ -60,20 +58,6 @@ class StringErrorCollector : public protobuf::io::ErrorCollector { const int index_offset_; }; -// Converts an absl::Duration to a google::protobuf::Duration. -inline google::protobuf::Duration ToDurationProto(absl::Duration duration) { - google::protobuf::Duration proto; - proto.set_seconds(absl::IDivDuration(duration, absl::Seconds(1), &duration)); - proto.set_nanos( - absl::IDivDuration(duration, absl::Nanoseconds(1), &duration)); - return proto; -} - -// Converts a google::protobuf::Duration to an absl::Duration. -inline absl::Duration FromDurationProto(google::protobuf::Duration proto) { - return absl::Seconds(proto.seconds()) + absl::Nanoseconds(proto.nanos()); -} - } // namespace proto_utils } // namespace tensorflow