Make StartMonitoring return status.

PiperOrigin-RevId: 250987885
This commit is contained in:
A. Unique TensorFlower 2019-05-31 17:36:06 -07:00 committed by TensorFlower Gardener
parent 86826c8ae8
commit 889bf5a55c
3 changed files with 17 additions and 13 deletions

View File

@ -105,18 +105,19 @@ int main(int argc, char** argv) {
<< FLAGS_service_addr << " for " << duration_ms
<< "ms and show metrics for " << num_queries << " time(s)."
<< std::endl;
tensorflow::profiler::client::StartMonitoring(
status = tensorflow::profiler::client::StartMonitoring(
FLAGS_service_addr, duration_ms, FLAGS_monitoring_level,
FLAGS_timestamp, num_queries);
} else {
status = tensorflow::profiler::client::StartTracing(
FLAGS_service_addr, FLAGS_logdir, FLAGS_workers_list,
FLAGS_include_dataset_ops, duration_ms, num_tracing_attempts);
if (!status.ok() && status.code() != tensorflow::error::Code::UNAVAILABLE) {
std::cout << status.error_message() << std::endl;
std::cout << usage.c_str() << std::endl;
return 2;
}
}
if (!status.ok() && status.code() != tensorflow::error::Code::UNAVAILABLE) {
std::cout << status.error_message() << std::endl;
std::cout << usage.c_str() << std::endl;
return 2;
}
return 0;
}

View File

@ -206,7 +206,7 @@ Status StartTracing(const tensorflow::string& service_addr,
opts.set_include_dataset_ops(include_dataset_ops);
while (true) {
std::cout << "Starting to profile TPU traces for " << duration_ms << " ms. "
<< "Remaining attempt(s): " << remaining_attempts-- << std::endl;
<< "Remaining attempt(s): " << --remaining_attempts << std::endl;
if (hostnames.empty()) {
status = Profile(service_addr, logdir, duration_ms, repository_root,
session_id, opts);
@ -216,7 +216,8 @@ Status StartTracing(const tensorflow::string& service_addr,
session_id, opts);
}
if (remaining_attempts <= 0 || status.ok() ||
status.code() != tensorflow::error::Code::UNAVAILABLE)
status.code() != tensorflow::error::Code::UNAVAILABLE ||
status.code() != tensorflow::error::Code::ALREADY_EXISTS)
break;
std::cout << "No trace event is collected. Automatically retrying."
<< std::endl
@ -243,8 +244,8 @@ MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level,
return request;
}
void StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
int monitoring_level, bool timestamp, int num_queries) {
Status StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
int monitoring_level, bool timestamp, int num_queries) {
for (int query = 0; query < num_queries; ++query) {
MonitorRequest request =
PopulateMonitorRequest(duration_ms, monitoring_level, timestamp);
@ -258,12 +259,14 @@ void StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
"dns:///" + service_addr, ::grpc::InsecureChannelCredentials(),
channel_args));
MonitorResponse response;
TF_QCHECK_OK(FromGrpcStatus(stub->Monitor(&context, request, &response)));
TF_RETURN_IF_ERROR(
FromGrpcStatus(stub->Monitor(&context, request, &response)));
std::cout << "Cloud TPU Monitoring Results (Sample " << query + 1
<< "):\n\n"
<< response.data() << std::flush;
}
return Status::OK();
}
} // namespace client

View File

@ -28,8 +28,8 @@ Status ValidateHostPortPair(const string& host_port);
// Repeatedly collects profiles and shows user-friendly metrics for
// 'num_queries' time(s). If timestamp flag is true, timestamp will be
// displayed in "%H:%M:%S" format.
void StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
int monitoring_level, bool timestamp, int num_queries);
Status StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
int monitoring_level, bool timestamp, int num_queries);
// Starts tracing on a single or multiple hosts and saves the result in the
// given logdir. If no trace was collected, retries tracing for