Add a StartMonitoring Python API.

PiperOrigin-RevId: 253079116
This commit is contained in:
A. Unique TensorFlower 2019-06-13 12:08:28 -07:00 committed by TensorFlower Gardener
parent b6ccdad4d4
commit 514004a234
8 changed files with 78 additions and 13 deletions

View File

@ -102,6 +102,28 @@ bool TFE_ProfilerClientStartTracing(const char* service_addr,
return s.ok();
}
// Sends a grpc request to the profiler server at `service_addr` to perform
// on-demand monitoring for `duration_ms`, blocking the caller thread until
// the result arrives. On success the monitoring output is copied into
// `result` (owned by the buffer's deallocator); `status` always reflects
// the outcome.
void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
                               int monitoring_level, bool display_timestamp,
                               TF_Buffer* result, TF_Status* status) {
  tensorflow::Status s =
      tensorflow::profiler::client::ValidateHostPortPair(service_addr);
  if (!s.ok()) {
    tensorflow::Set_TF_Status_from_Status(status, s);
    return;
  }
  string content;
  s = tensorflow::profiler::client::Monitor(
      service_addr, duration_ms, monitoring_level, display_timestamp, &content);
  tensorflow::Set_TF_Status_from_Status(status, s);
  if (!s.ok()) {
    // Monitoring failed: report the error via `status` and leave `result`
    // untouched instead of allocating and installing an empty buffer.
    return;
  }
  // Hand the monitoring output to the caller in a heap buffer that the
  // TF_Buffer deallocator knows how to free.
  void* data = tensorflow::port::Malloc(content.length());
  content.copy(static_cast<char*>(data), content.length(), 0);
  result->data = data;
  result->length = content.length();
  result->data_deallocator = [](void* data, size_t length) {
    tensorflow::port::Free(data);
  };
}
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
int64_t value) {
cell->cell.IncrementBy(value);

View File

@ -88,6 +88,16 @@ TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing(
bool include_dataset_ops, int duration_ms, int num_tracing_attempts,
TF_Status* status);
// Sends a grpc request to the profiler server (service_addr) to perform
// on-demand monitoring and returns the result in a string. It blocks the
// caller thread until the monitoring result is received.
// This API is designed for TensorBoard; end users should instead use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile, following
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern void TFE_ProfilerClientMonitor(
const char* service_addr, int duration_ms, int monitoring_level,
bool display_timestamp, TF_Buffer* result, TF_Status* status);
// TODO(fishx): Move these monitoring APIs into a separate file.
// -----------------------------------------------------------------------------
// Monitoring Counter APIs.

View File

@ -29,11 +29,13 @@ namespace tensorflow {
namespace {
Status MonitoringHelper(const string& service_addr, int duration_ms,
int monitoring_level, bool timestamp, int num_queries) {
int monitoring_level, bool display_timestamp,
int num_queries) {
for (int query = 0; query < num_queries; ++query) {
string result;
TF_RETURN_IF_ERROR(tensorflow::profiler::client::StartMonitoring(
service_addr, duration_ms, monitoring_level, timestamp, &result));
TF_RETURN_IF_ERROR(tensorflow::profiler::client::Monitor(
service_addr, duration_ms, monitoring_level, display_timestamp,
&result));
std::cout << "Cloud TPU Monitoring Results (Sample " << query + 1
<< "):\n\n"
<< result << std::flush;
@ -52,7 +54,7 @@ int main(int argc, char** argv) {
int FLAGS_num_tracing_attempts = 3;
bool FLAGS_include_dataset_ops = true;
int FLAGS_monitoring_level = 0;
bool FLAGS_timestamp = false;
bool FLAGS_display_timestamp = false;
int FLAGS_num_queries = 100;
std::vector<tensorflow::Flag> flag_list = {
tensorflow::Flag("service_addr", &FLAGS_service_addr,
@ -76,7 +78,7 @@ int main(int argc, char** argv) {
"Choose a monitoring level between 1 and 2 to monitor "
"your TPU job continuously. Level 2 is more verbose "
"than level 1 and shows more metrics."),
tensorflow::Flag("timestamp", &FLAGS_timestamp,
tensorflow::Flag("display_timestamp", &FLAGS_display_timestamp,
"Set to true to display timestamp in monitoring "
"results."),
tensorflow::Flag("num_queries", &FLAGS_num_queries,
@ -128,7 +130,7 @@ int main(int argc, char** argv) {
<< std::endl;
status = tensorflow::MonitoringHelper(FLAGS_service_addr, duration_ms,
FLAGS_monitoring_level,
FLAGS_timestamp, num_queries);
FLAGS_display_timestamp, num_queries);
} else {
status = tensorflow::profiler::client::StartTracing(
FLAGS_service_addr, FLAGS_logdir, FLAGS_workers_list,

View File

@ -244,10 +244,10 @@ MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level,
return request;
}
Status StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
int monitoring_level, bool timestamp, string* result) {
Status Monitor(const tensorflow::string& service_addr, int duration_ms,
int monitoring_level, bool display_timestamp, string* result) {
MonitorRequest request =
PopulateMonitorRequest(duration_ms, monitoring_level, timestamp);
PopulateMonitorRequest(duration_ms, monitoring_level, display_timestamp);
::grpc::ClientContext context;
::grpc::ChannelArguments channel_args;

View File

@ -27,8 +27,8 @@ Status ValidateHostPortPair(const string& host_port);
// Collects one sample of monitoring profile and shows user-friendly metrics.
// If display_timestamp is true, a timestamp is displayed in "%H:%M:%S" format.
Status StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
int monitoring_level, bool timestamp, string* output);
Status Monitor(const tensorflow::string& service_addr, int duration_ms,
int monitoring_level, bool display_timestamp, string* result);
// Starts tracing on a single or multiple hosts and saves the result in the
// given logdir. If no trace was collected, retries tracing for

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import errors
@ -28,9 +29,9 @@ def start_tracing(service_addr,
worker_list='',
include_dataset_ops=True,
num_tracing_attempts=3):
"""Sending grpc requests to profiler server to perform on-demand profiling.
"""Sends grpc requests to profiler server to perform on-demand profiling.
Note: This method will block caller thread until receives tracing result.
This method will block caller thread until receives tracing result.
Args:
service_addr: Address of profiler service e.g. localhost:6009.
@ -49,3 +50,28 @@ def start_tracing(service_addr,
service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
num_tracing_attempts):
raise errors.UnavailableError(None, None, 'No trace event is collected.')
def monitor(service_addr,
            duration_ms,
            monitoring_level=1,
            display_timestamp=False):
  """Issues a blocking on-demand monitoring request to a profiler server.

  Sends a grpc request to the profiler server at `service_addr` and does not
  return until the monitoring result has been received.

  Args:
    service_addr: Address of profiler service e.g. localhost:6009.
    duration_ms: Duration of tracing or monitoring in ms.
    monitoring_level: Choose a monitoring level between 1 and 2 to monitor your
      job. Level 2 is more verbose than level 1 and shows more metrics.
    display_timestamp: Set to true to display timestamp in monitoring result.

  Returns:
    A string of monitoring output.
  """
  # The buffer is owned by the context manager, so its contents must be read
  # before the `with` block exits and the buffer is freed.
  with c_api_util.tf_buffer() as buffer_:
    pywrap_tensorflow.TFE_ProfilerClientMonitor(
        service_addr, duration_ms, monitoring_level, display_timestamp,
        buffer_)
    return pywrap_tensorflow.TF_GetBuffer(buffer_)

View File

@ -30,6 +30,10 @@ class ProfilerClientTest(test_util.TensorFlowTestCase):
with self.assertRaises(errors.UnavailableError):
profiler_client.start_tracing('localhost:6006', '/tmp/', 2000)
def testMonitor_ProcessInvalidAddress(self):
  # Monitoring an address where no profiler server is listening is expected
  # to raise UnavailableError (surfaced from the underlying grpc request).
  with self.assertRaises(errors.UnavailableError):
    profiler_client.monitor('localhost:6006', 2000)
# Run the test suite when this file is executed directly.
if __name__ == '__main__':
  test.main()

View File

@ -52,6 +52,7 @@ limitations under the License.
%rename("%s") TFE_DeleteProfilerContext;
%rename("%s") TFE_StartProfilerServer;
%rename("%s") TFE_ProfilerClientStartTracing;
%rename("%s") TFE_ProfilerClientMonitor;
%rename("%s") TFE_OpNameGetAttrType;
%rename("%s") TFE_Py_InitEagerTensor;
%rename("%s") TFE_Py_SetEagerTensorProfiler;