Added a timestamp flag to the TensorFlow TPU profiler for monitoring.

PiperOrigin-RevId: 243116546
This commit is contained in:
A. Unique TensorFlower 2019-04-11 12:12:38 -07:00 committed by TensorFlower Gardener
parent ceeea9c916
commit 9a4141c306
3 changed files with 11 additions and 6 deletions

View File

@ -31,6 +31,7 @@ int main(int argc, char** argv) {
int FLAGS_num_tracing_attempts = 3; int FLAGS_num_tracing_attempts = 3;
bool FLAGS_include_dataset_ops = true; bool FLAGS_include_dataset_ops = true;
int FLAGS_monitoring_level = 0; int FLAGS_monitoring_level = 0;
bool FLAGS_timestamp = false;
int FLAGS_num_queries = 100; int FLAGS_num_queries = 100;
std::vector<tensorflow::Flag> flag_list = { std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("service_addr", &FLAGS_service_addr, tensorflow::Flag("service_addr", &FLAGS_service_addr,
@ -54,6 +55,9 @@ int main(int argc, char** argv) {
"Choose a monitoring level between 1 and 2 to monitor " "Choose a monitoring level between 1 and 2 to monitor "
"your TPU job continuously. Level 2 is more verbose " "your TPU job continuously. Level 2 is more verbose "
"than level 1 and shows more metrics."), "than level 1 and shows more metrics."),
tensorflow::Flag("timestamp", &FLAGS_timestamp,
"Set to true to display timestamp in monitoring "
"results."),
tensorflow::Flag("num_queries", &FLAGS_num_queries, tensorflow::Flag("num_queries", &FLAGS_num_queries,
"This script will run monitoring for num_queries before " "This script will run monitoring for num_queries before "
"it stops.")}; "it stops.")};
@ -102,7 +106,8 @@ int main(int argc, char** argv) {
<< "ms and show metrics for " << num_queries << " time(s)." << "ms and show metrics for " << num_queries << " time(s)."
<< std::endl; << std::endl;
tensorflow::profiler::client::StartMonitoring( tensorflow::profiler::client::StartMonitoring(
FLAGS_service_addr, duration_ms, FLAGS_monitoring_level, num_queries); FLAGS_service_addr, duration_ms, FLAGS_monitoring_level,
FLAGS_timestamp, num_queries);
} else { } else {
status = tensorflow::profiler::client::StartTracing( status = tensorflow::profiler::client::StartTracing(
FLAGS_service_addr, FLAGS_logdir, FLAGS_workers_list, FLAGS_service_addr, FLAGS_logdir, FLAGS_workers_list,

View File

@ -243,11 +243,10 @@ MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level,
} }
void StartMonitoring(const tensorflow::string& service_addr, int duration_ms, void StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
int monitoring_level, int num_queries) { int monitoring_level, bool timestamp, int num_queries) {
for (int query = 0; query < num_queries; ++query) { for (int query = 0; query < num_queries; ++query) {
MonitorRequest request = MonitorRequest request =
PopulateMonitorRequest(duration_ms, monitoring_level, PopulateMonitorRequest(duration_ms, monitoring_level, timestamp);
/*timestamp=*/false);
::grpc::ClientContext context; ::grpc::ClientContext context;
::grpc::ChannelArguments channel_args; ::grpc::ChannelArguments channel_args;

View File

@ -26,9 +26,10 @@ namespace client {
Status ValidateHostPortPair(const string& host_port); Status ValidateHostPortPair(const string& host_port);
// Repeatedly collects profiles and shows user-friendly metrics for // Repeatedly collects profiles and shows user-friendly metrics for
// 'num_queries' time(s). // 'num_queries' time(s). If timestamp flag is true, timestamp will be
// displayed in "%H:%M:%S" format.
void StartMonitoring(const tensorflow::string& service_addr, int duration_ms, void StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
int monitoring_level, int num_queries); int monitoring_level, bool timestamp, int num_queries);
// Starts tracing on a single or multiple hosts and saves the result in the // Starts tracing on a single or multiple hosts and saves the result in the
// given logdir. If no trace was collected, retries tracing for // given logdir. If no trace was collected, retries tracing for