Refactor StartMonitoring to stream back result.

PiperOrigin-RevId: 252503898
This commit is contained in:
A. Unique TensorFlower 2019-06-10 15:58:55 -07:00 committed by TensorFlower Gardener
parent d4391586b5
commit 19292fe67f
4 changed files with 42 additions and 28 deletions

View File

@ -3,7 +3,6 @@ package(
)
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
cc_library(
name = "version",

View File

@ -19,10 +19,31 @@ limitations under the License.
// receives and dumps the profile data to a tensorboard log directory.
#include "tensorflow/contrib/tpu/profiler/version.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace tensorflow {
namespace {
Status MonitoringHelper(const string& service_addr, int duration_ms,
int monitoring_level, bool timestamp, int num_queries) {
for (int query = 0; query < num_queries; ++query) {
string result;
TF_RETURN_IF_ERROR(tensorflow::profiler::client::StartMonitoring(
service_addr, duration_ms, monitoring_level, timestamp, &result));
std::cout << "Cloud TPU Monitoring Results (Sample " << query + 1
<< "):\n\n"
<< result << std::flush;
}
return Status::OK();
}
} // namespace
} // namespace tensorflow
int main(int argc, char** argv) {
tensorflow::string FLAGS_service_addr;
tensorflow::string FLAGS_logdir;
@ -105,9 +126,9 @@ int main(int argc, char** argv) {
<< FLAGS_service_addr << " for " << duration_ms
<< "ms and show metrics for " << num_queries << " time(s)."
<< std::endl;
status = tensorflow::profiler::client::StartMonitoring(
FLAGS_service_addr, duration_ms, FLAGS_monitoring_level,
FLAGS_timestamp, num_queries);
status = tensorflow::MonitoringHelper(FLAGS_service_addr, duration_ms,
FLAGS_monitoring_level,
FLAGS_timestamp, num_queries);
} else {
status = tensorflow::profiler::client::StartTracing(
FLAGS_service_addr, FLAGS_logdir, FLAGS_workers_list,

View File

@ -245,27 +245,22 @@ MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level,
}
Status StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
int monitoring_level, bool timestamp, int num_queries) {
for (int query = 0; query < num_queries; ++query) {
MonitorRequest request =
PopulateMonitorRequest(duration_ms, monitoring_level, timestamp);
int monitoring_level, bool timestamp, string* result) {
MonitorRequest request =
PopulateMonitorRequest(duration_ms, monitoring_level, timestamp);
::grpc::ClientContext context;
::grpc::ChannelArguments channel_args;
channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH,
std::numeric_limits<int32>::max());
std::unique_ptr<grpc::ProfilerService::Stub> stub =
grpc::ProfilerService::NewStub(::grpc::CreateCustomChannel(
"dns:///" + service_addr, ::grpc::InsecureChannelCredentials(),
channel_args));
MonitorResponse response;
TF_RETURN_IF_ERROR(
FromGrpcStatus(stub->Monitor(&context, request, &response)));
std::cout << "Cloud TPU Monitoring Results (Sample " << query + 1
<< "):\n\n"
<< response.data() << std::flush;
}
::grpc::ClientContext context;
::grpc::ChannelArguments channel_args;
channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH,
std::numeric_limits<int32>::max());
std::unique_ptr<grpc::ProfilerService::Stub> stub =
grpc::ProfilerService::NewStub(::grpc::CreateCustomChannel(
"dns:///" + service_addr, ::grpc::InsecureChannelCredentials(),
channel_args));
MonitorResponse response;
TF_RETURN_IF_ERROR(
FromGrpcStatus(stub->Monitor(&context, request, &response)));
*result = response.data();
return Status::OK();
}

View File

@ -25,11 +25,10 @@ namespace client {
Status ValidateHostPortPair(const string& host_port);
// Repeatedly collects profiles and shows user-friendly metrics for
// 'num_queries' time(s). If timestamp flag is true, timestamp will be
// displayed in "%H:%M:%S" format.
// Collects one sample of monitoring profile and shows user-friendly metrics.
// If timestamp flag is true, timestamp will be displayed in "%H:%M:%S" format.
Status StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
int monitoring_level, bool timestamp, int num_queries);
int monitoring_level, bool timestamp, string* output);
// Starts tracing on a single or multiple hosts and saves the result in the
// given logdir. If no trace was collected, retries tracing for