Refactor StartMonitoring to stream back result.
PiperOrigin-RevId: 252503898
This commit is contained in:
parent
d4391586b5
commit
19292fe67f
@ -3,7 +3,6 @@ package(
|
||||
)
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
|
||||
|
||||
cc_library(
|
||||
name = "version",
|
||||
|
@ -19,10 +19,31 @@ limitations under the License.
|
||||
// receives and dumps the profile data to a tensorboard log directory.
|
||||
|
||||
#include "tensorflow/contrib/tpu/profiler/version.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/platform/init_main.h"
|
||||
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Status MonitoringHelper(const string& service_addr, int duration_ms,
|
||||
int monitoring_level, bool timestamp, int num_queries) {
|
||||
for (int query = 0; query < num_queries; ++query) {
|
||||
string result;
|
||||
TF_RETURN_IF_ERROR(tensorflow::profiler::client::StartMonitoring(
|
||||
service_addr, duration_ms, monitoring_level, timestamp, &result));
|
||||
std::cout << "Cloud TPU Monitoring Results (Sample " << query + 1
|
||||
<< "):\n\n"
|
||||
<< result << std::flush;
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
tensorflow::string FLAGS_service_addr;
|
||||
tensorflow::string FLAGS_logdir;
|
||||
@ -105,9 +126,9 @@ int main(int argc, char** argv) {
|
||||
<< FLAGS_service_addr << " for " << duration_ms
|
||||
<< "ms and show metrics for " << num_queries << " time(s)."
|
||||
<< std::endl;
|
||||
status = tensorflow::profiler::client::StartMonitoring(
|
||||
FLAGS_service_addr, duration_ms, FLAGS_monitoring_level,
|
||||
FLAGS_timestamp, num_queries);
|
||||
status = tensorflow::MonitoringHelper(FLAGS_service_addr, duration_ms,
|
||||
FLAGS_monitoring_level,
|
||||
FLAGS_timestamp, num_queries);
|
||||
} else {
|
||||
status = tensorflow::profiler::client::StartTracing(
|
||||
FLAGS_service_addr, FLAGS_logdir, FLAGS_workers_list,
|
||||
|
@ -245,27 +245,22 @@ MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level,
|
||||
}
|
||||
|
||||
Status StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
|
||||
int monitoring_level, bool timestamp, int num_queries) {
|
||||
for (int query = 0; query < num_queries; ++query) {
|
||||
MonitorRequest request =
|
||||
PopulateMonitorRequest(duration_ms, monitoring_level, timestamp);
|
||||
int monitoring_level, bool timestamp, string* result) {
|
||||
MonitorRequest request =
|
||||
PopulateMonitorRequest(duration_ms, monitoring_level, timestamp);
|
||||
|
||||
::grpc::ClientContext context;
|
||||
::grpc::ChannelArguments channel_args;
|
||||
channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH,
|
||||
std::numeric_limits<int32>::max());
|
||||
std::unique_ptr<grpc::ProfilerService::Stub> stub =
|
||||
grpc::ProfilerService::NewStub(::grpc::CreateCustomChannel(
|
||||
"dns:///" + service_addr, ::grpc::InsecureChannelCredentials(),
|
||||
channel_args));
|
||||
MonitorResponse response;
|
||||
TF_RETURN_IF_ERROR(
|
||||
FromGrpcStatus(stub->Monitor(&context, request, &response)));
|
||||
|
||||
std::cout << "Cloud TPU Monitoring Results (Sample " << query + 1
|
||||
<< "):\n\n"
|
||||
<< response.data() << std::flush;
|
||||
}
|
||||
::grpc::ClientContext context;
|
||||
::grpc::ChannelArguments channel_args;
|
||||
channel_args.SetInt(GRPC_ARG_MAX_MESSAGE_LENGTH,
|
||||
std::numeric_limits<int32>::max());
|
||||
std::unique_ptr<grpc::ProfilerService::Stub> stub =
|
||||
grpc::ProfilerService::NewStub(::grpc::CreateCustomChannel(
|
||||
"dns:///" + service_addr, ::grpc::InsecureChannelCredentials(),
|
||||
channel_args));
|
||||
MonitorResponse response;
|
||||
TF_RETURN_IF_ERROR(
|
||||
FromGrpcStatus(stub->Monitor(&context, request, &response)));
|
||||
*result = response.data();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -25,11 +25,10 @@ namespace client {
|
||||
|
||||
Status ValidateHostPortPair(const string& host_port);
|
||||
|
||||
// Repeatedly collects profiles and shows user-friendly metrics for
|
||||
// 'num_queries' time(s). If timestamp flag is true, timestamp will be
|
||||
// displayed in "%H:%M:%S" format.
|
||||
// Collects one sample of monitoring profile and shows user-friendly metrics.
|
||||
// If timestamp flag is true, timestamp will be displayed in "%H:%M:%S" format.
|
||||
Status StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
|
||||
int monitoring_level, bool timestamp, int num_queries);
|
||||
int monitoring_level, bool timestamp, string* output);
|
||||
|
||||
// Starts tracing on a single or multiple hosts and saves the result in the
|
||||
// given logdir. If no trace was collected, retries tracing for
|
||||
|
Loading…
x
Reference in New Issue
Block a user