From 19292fe67fd9f4748f6d586529a15402db6ad388 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 10 Jun 2019 15:58:55 -0700 Subject: [PATCH] Refactor StartMonitoring to stream back result. PiperOrigin-RevId: 252503898 --- tensorflow/contrib/tpu/profiler/BUILD | 1 - .../tpu/profiler/capture_tpu_profile.cc | 27 ++++++++++++-- .../profiler/rpc/client/capture_profile.cc | 35 ++++++++----------- .../profiler/rpc/client/capture_profile.h | 7 ++-- 4 files changed, 42 insertions(+), 28 deletions(-) diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index 461f9856b0d..c7a8380b837 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -3,7 +3,6 @@ package( ) load("//tensorflow:tensorflow.bzl", "tf_cc_binary") -load("//tensorflow:tensorflow.bzl", "tf_cc_test") cc_library( name = "version", diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc index 5087d2534cd..c9e73c1a5d5 100644 --- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc +++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc @@ -19,10 +19,31 @@ limitations under the License. // receives and dumps the profile data to a tensorboard log directory. 
#include "tensorflow/contrib/tpu/profiler/version.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/init_main.h" #include "tensorflow/core/profiler/rpc/client/capture_profile.h" #include "tensorflow/core/util/command_line_flags.h" +namespace tensorflow { +namespace { + +Status MonitoringHelper(const string& service_addr, int duration_ms, + int monitoring_level, bool timestamp, int num_queries) { + for (int query = 0; query < num_queries; ++query) { + string result; + TF_RETURN_IF_ERROR(tensorflow::profiler::client::StartMonitoring( + service_addr, duration_ms, monitoring_level, timestamp, &result)); + std::cout << "Cloud TPU Monitoring Results (Sample " << query + 1 + << "):\n\n" + << result << std::flush; + } + return Status::OK(); +} + +} // namespace +} // namespace tensorflow + int main(int argc, char** argv) { tensorflow::string FLAGS_service_addr; tensorflow::string FLAGS_logdir; @@ -105,9 +126,9 @@ int main(int argc, char** argv) { << FLAGS_service_addr << " for " << duration_ms << "ms and show metrics for " << num_queries << " time(s)." 
<< std::endl; - status = tensorflow::profiler::client::StartMonitoring( - FLAGS_service_addr, duration_ms, FLAGS_monitoring_level, - FLAGS_timestamp, num_queries); + status = tensorflow::MonitoringHelper(FLAGS_service_addr, duration_ms, + FLAGS_monitoring_level, + FLAGS_timestamp, num_queries); } else { status = tensorflow::profiler::client::StartTracing( FLAGS_service_addr, FLAGS_logdir, FLAGS_workers_list, diff --git a/tensorflow/core/profiler/rpc/client/capture_profile.cc b/tensorflow/core/profiler/rpc/client/capture_profile.cc index 06a97b4874e..98378ded2d8 100644 --- a/tensorflow/core/profiler/rpc/client/capture_profile.cc +++ b/tensorflow/core/profiler/rpc/client/capture_profile.cc @@ -245,27 +245,22 @@ MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level, } Status StartMonitoring(const tensorflow::string& service_addr, int duration_ms, - int monitoring_level, bool timestamp, int num_queries) { - for (int query = 0; query < num_queries; ++query) { - MonitorRequest request = - PopulateMonitorRequest(duration_ms, monitoring_level, timestamp); + int monitoring_level, bool timestamp, string* result) { + MonitorRequest request = + PopulateMonitorRequest(duration_ms, monitoring_level, timestamp); - ::grpc::ClientContext context; - ::grpc::ChannelArguments channel_args; - channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, - std::numeric_limits<int32>::max()); - std::unique_ptr<grpc::ProfilerService::Stub> stub = - grpc::ProfilerService::NewStub(::grpc::CreateCustomChannel( - "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(), - channel_args)); - MonitorResponse response; - TF_RETURN_IF_ERROR( - FromGrpcStatus(stub->Monitor(&context, request, &response))); - - std::cout << "Cloud TPU Monitoring Results (Sample " << query + 1 - << "):\n\n" - << response.data() << std::flush; - } + ::grpc::ClientContext context; + ::grpc::ChannelArguments channel_args; + channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH, + std::numeric_limits<int32>::max()); + std::unique_ptr<grpc::ProfilerService::Stub> stub = + 
grpc::ProfilerService::NewStub(::grpc::CreateCustomChannel( + "dns:///" + service_addr, ::grpc::InsecureChannelCredentials(), + channel_args)); + MonitorResponse response; + TF_RETURN_IF_ERROR( + FromGrpcStatus(stub->Monitor(&context, request, &response))); + *result = response.data(); return Status::OK(); } diff --git a/tensorflow/core/profiler/rpc/client/capture_profile.h b/tensorflow/core/profiler/rpc/client/capture_profile.h index 247e5c8c036..04c418e97f4 100644 --- a/tensorflow/core/profiler/rpc/client/capture_profile.h +++ b/tensorflow/core/profiler/rpc/client/capture_profile.h @@ -25,11 +25,10 @@ namespace client { Status ValidateHostPortPair(const string& host_port); -// Repeatedly collects profiles and shows user-friendly metrics for -// 'num_queries' time(s). If timestamp flag is true, timestamp will be -// displayed in "%H:%M:%S" format. +// Collects one sample of monitoring profile and returns user-friendly metrics +// in 'output'. If timestamp flag is true, timestamp is in "%H:%M:%S" format. Status StartMonitoring(const tensorflow::string& service_addr, int duration_ms, - int monitoring_level, bool timestamp, int num_queries); + int monitoring_level, bool timestamp, string* output); // Starts tracing on a single or multiple hosts and saves the result in the // given logdir. If no trace was collected, retries tracing for