From 1a287cbaeeda0194f6269c63ce082663fb311e97 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Feb 2020 18:06:41 -0800 Subject: [PATCH] Add profiler service to eager context so that it is started by default. PiperOrigin-RevId: 294348740 Change-Id: I99eb430c8a5c1c35ad442987a7c50af3f1f92e29 --- tensorflow/c/eager/c_api_experimental.cc | 8 ++-- tensorflow/core/common_runtime/eager/BUILD | 1 + .../core/common_runtime/eager/context.cc | 3 ++ .../core/common_runtime/eager/context.h | 5 +++ tensorflow/core/profiler/rpc/BUILD | 1 + .../core/profiler/rpc/profiler_server.cc | 37 ++++++++++++++++--- .../core/profiler/rpc/profiler_server.h | 15 +++++++- 7 files changed, 61 insertions(+), 9 deletions(-) diff --git a/tensorflow/c/eager/c_api_experimental.cc b/tensorflow/c/eager/c_api_experimental.cc index 8d3f86f9c74..71cade94e6c 100644 --- a/tensorflow/c/eager/c_api_experimental.cc +++ b/tensorflow/c/eager/c_api_experimental.cc @@ -48,9 +48,11 @@ void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { } void TFE_StartProfilerServer(int port) { - // Release child thread intentionally. The child thread can be terminated by - // terminating the main thread. - tensorflow::StartProfilerServer(port).release(); + auto profiler_server = absl::make_unique(); + profiler_server->StartProfilerServer(port); + // Release child server thread intentionally. The child thread can be + // terminated when the main program exits. 
+ profiler_server.release(); } void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index f3fd1f45eed..9a962fd06df 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -77,6 +77,7 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime:server_lib", "//tensorflow/core/distributed_runtime:worker_session", "//tensorflow/core/distributed_runtime/eager:eager_client", + "//tensorflow/core/profiler/rpc:profiler_server", ], }), ) diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 5932ed4b698..fe16ec12c70 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -45,6 +45,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h" #include "tensorflow/core/distributed_runtime/device_resolver_distributed.h" #include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h" +#include "tensorflow/core/profiler/rpc/profiler_server.h" #endif // !IS_MOBILE_PLATFORM #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/lib/core/blocking_counter.h" @@ -110,6 +111,8 @@ EagerContext::EagerContext( #if !defined(IS_MOBILE_PLATFORM) context_id_ = kInvalidContextId; + profiler_server_ = absl::make_unique(); + profiler_server_->MaybeStartProfilerServer(); #endif // IS_MOBILE_PLATFORM std::unique_ptr drl( diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 094e7fd8b49..06d401352d7 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -73,6 +73,8 @@ namespace eager { class RemoteMgr; } // namespace eager +class ProfilerServer; + // LINT.IfChange // Note: Keep in sync with exported copy of 
enum in eager/c_api.h. enum ContextDevicePlacementPolicy { @@ -599,6 +601,9 @@ class EagerContext : public core::RefCounted { std::shared_ptr worker_session_; std::unique_ptr remote_eager_workers_; + // Starts a thread for the profiling service. + std::unique_ptr profiler_server_; + mutex remote_state_mu_; uint64 context_id_ GUARDED_BY(remote_state_mu_); diff --git a/tensorflow/core/profiler/rpc/BUILD b/tensorflow/core/profiler/rpc/BUILD index 72ff3fc2306..a810f706c2e 100644 --- a/tensorflow/core/profiler/rpc/BUILD +++ b/tensorflow/core/profiler/rpc/BUILD @@ -30,6 +30,7 @@ cc_library( "//tensorflow:grpc++", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core/profiler:profiler_service_proto_cc", "@com_google_absl//absl/strings", ], diff --git a/tensorflow/core/profiler/rpc/profiler_server.cc b/tensorflow/core/profiler/rpc/profiler_server.cc index faa83b9099f..477a2490028 100644 --- a/tensorflow/core/profiler/rpc/profiler_server.cc +++ b/tensorflow/core/profiler/rpc/profiler_server.cc @@ -23,13 +23,14 @@ limitations under the License.
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/profiler/profiler_service.grpc.pb.h" #include "tensorflow/core/profiler/rpc/profiler_service_impl.h" +#include "tensorflow/core/util/env_var.h" #include "tensorflow/core/util/ptr_util.h" namespace tensorflow { -std::unique_ptr StartProfilerServer(int32 port) { +void ProfilerServer::StartProfilerServer(int32 port) { Env* env = Env::Default(); - return WrapUnique(env->StartThread({}, "profiler server", [port]() { + auto start_server = [port, this]() { string server_address = absl::StrCat("0.0.0.0:", port); std::unique_ptr service = CreateProfilerService(); @@ -37,10 +38,36 @@ std::unique_ptr StartProfilerServer(int32 port) { builder.AddListeningPort(server_address, ::grpc::InsecureServerCredentials()); builder.RegisterService(service.get()); - std::unique_ptr<::grpc::Server> server(builder.BuildAndStart()); + server_ = builder.BuildAndStart(); LOG(INFO) << "Profiling Server listening on " << server_address; - server->Wait(); - })); + server_->Wait(); + }; + server_thread_ = + WrapUnique(env->StartThread({}, "ProfilerServer", start_server)); +} + +void ProfilerServer::MaybeStartProfilerServer() { + int64 profiler_port; + // The implementation of ReadInt64FromEnvVar guarantees that the output + // argument will be set to the default value on failure. + Status s = ReadInt64FromEnvVar("TF_PROFILER_PORT", -1, &profiler_port); + if (!s.ok()) { + LOG(WARNING) << "StartProfilerServer: " << s.error_message(); + } + if (profiler_port < 1024 || profiler_port > 49151) { + // Disable the log message if profiler_port is -1 to avoid spamming the + // terminal for TF users who don't set a profiler port. + if (profiler_port == -1) return; + LOG(WARNING) + << "Profiler server not started. 
TF_PROFILER_PORT: " << profiler_port + << " is out of the valid registered port range (1024 to 49151)."; + return; + } + StartProfilerServer(profiler_port); +} + +ProfilerServer::~ProfilerServer() { + if (server_) server_->Shutdown(); } } // namespace tensorflow diff --git a/tensorflow/core/profiler/rpc/profiler_server.h b/tensorflow/core/profiler/rpc/profiler_server.h index fd516121799..26e9606e2c5 100644 --- a/tensorflow/core/profiler/rpc/profiler_server.h +++ b/tensorflow/core/profiler/rpc/profiler_server.h @@ -17,13 +17,26 @@ limitations under the License. #include +#include "grpcpp/grpcpp.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { class Thread; -std::unique_ptr StartProfilerServer(int32 port); +class ProfilerServer { + public: + ~ProfilerServer(); + // If TF_PROFILER_PORT is defined, starts a profiler server with the + // specified port. Otherwise, does not start a profiler server. + void MaybeStartProfilerServer(); + // Starts a profiler server with a given port. + void StartProfilerServer(int32 port); + + private: + std::unique_ptr<::grpc::Server> server_; + std::unique_ptr server_thread_; }; } // namespace tensorflow