Add a StartMonitoring Python API.
PiperOrigin-RevId: 253079116
This commit is contained in:
parent
b6ccdad4d4
commit
514004a234
@ -102,6 +102,28 @@ bool TFE_ProfilerClientStartTracing(const char* service_addr,
|
|||||||
return s.ok();
|
return s.ok();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
|
||||||
|
int monitoring_level, bool display_timestamp,
|
||||||
|
TF_Buffer* result, TF_Status* status) {
|
||||||
|
tensorflow::Status s =
|
||||||
|
tensorflow::profiler::client::ValidateHostPortPair(service_addr);
|
||||||
|
if (!s.ok()) {
|
||||||
|
Set_TF_Status_from_Status(status, s);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
string content;
|
||||||
|
s = tensorflow::profiler::client::Monitor(
|
||||||
|
service_addr, duration_ms, monitoring_level, display_timestamp, &content);
|
||||||
|
void* data = tensorflow::port::Malloc(content.length());
|
||||||
|
content.copy(static_cast<char*>(data), content.length(), 0);
|
||||||
|
result->data = data;
|
||||||
|
result->length = content.length();
|
||||||
|
result->data_deallocator = [](void* data, size_t length) {
|
||||||
|
tensorflow::port::Free(data);
|
||||||
|
};
|
||||||
|
tensorflow::Set_TF_Status_from_Status(status, s);
|
||||||
|
}
|
||||||
|
|
||||||
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
|
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
|
||||||
int64_t value) {
|
int64_t value) {
|
||||||
cell->cell.IncrementBy(value);
|
cell->cell.IncrementBy(value);
|
||||||
|
@ -88,6 +88,16 @@ TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing(
|
|||||||
bool include_dataset_ops, int duration_ms, int num_tracing_attempts,
|
bool include_dataset_ops, int duration_ms, int num_tracing_attempts,
|
||||||
TF_Status* status);
|
TF_Status* status);
|
||||||
|
|
||||||
|
// Send a grpc request to profiler server (service_addr) to perform on-demand
|
||||||
|
// monitoring and return the result in a string. It will block the
|
||||||
|
// caller thread until receiving the monitoring result.
|
||||||
|
// This API is designed for TensorBoard, for end user, please use
|
||||||
|
// tensorflow/contrib/tpu/profiler/capture_tpu_profile instead following
|
||||||
|
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
|
||||||
|
TF_CAPI_EXPORT extern void TFE_ProfilerClientMonitor(
|
||||||
|
const char* service_addr, int duration_ms, int monitoring_level,
|
||||||
|
bool display_timestamp, TF_Buffer* result, TF_Status* status);
|
||||||
|
|
||||||
// TODO(fishx): Move these monitoring APIs into a separate file.
|
// TODO(fishx): Move these monitoring APIs into a separate file.
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
// Monitoring Counter APIs.
|
// Monitoring Counter APIs.
|
||||||
|
@ -29,11 +29,13 @@ namespace tensorflow {
|
|||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
Status MonitoringHelper(const string& service_addr, int duration_ms,
|
Status MonitoringHelper(const string& service_addr, int duration_ms,
|
||||||
int monitoring_level, bool timestamp, int num_queries) {
|
int monitoring_level, bool display_timestamp,
|
||||||
|
int num_queries) {
|
||||||
for (int query = 0; query < num_queries; ++query) {
|
for (int query = 0; query < num_queries; ++query) {
|
||||||
string result;
|
string result;
|
||||||
TF_RETURN_IF_ERROR(tensorflow::profiler::client::StartMonitoring(
|
TF_RETURN_IF_ERROR(tensorflow::profiler::client::Monitor(
|
||||||
service_addr, duration_ms, monitoring_level, timestamp, &result));
|
service_addr, duration_ms, monitoring_level, display_timestamp,
|
||||||
|
&result));
|
||||||
std::cout << "Cloud TPU Monitoring Results (Sample " << query + 1
|
std::cout << "Cloud TPU Monitoring Results (Sample " << query + 1
|
||||||
<< "):\n\n"
|
<< "):\n\n"
|
||||||
<< result << std::flush;
|
<< result << std::flush;
|
||||||
@ -52,7 +54,7 @@ int main(int argc, char** argv) {
|
|||||||
int FLAGS_num_tracing_attempts = 3;
|
int FLAGS_num_tracing_attempts = 3;
|
||||||
bool FLAGS_include_dataset_ops = true;
|
bool FLAGS_include_dataset_ops = true;
|
||||||
int FLAGS_monitoring_level = 0;
|
int FLAGS_monitoring_level = 0;
|
||||||
bool FLAGS_timestamp = false;
|
bool FLAGS_display_timestamp = false;
|
||||||
int FLAGS_num_queries = 100;
|
int FLAGS_num_queries = 100;
|
||||||
std::vector<tensorflow::Flag> flag_list = {
|
std::vector<tensorflow::Flag> flag_list = {
|
||||||
tensorflow::Flag("service_addr", &FLAGS_service_addr,
|
tensorflow::Flag("service_addr", &FLAGS_service_addr,
|
||||||
@ -76,7 +78,7 @@ int main(int argc, char** argv) {
|
|||||||
"Choose a monitoring level between 1 and 2 to monitor "
|
"Choose a monitoring level between 1 and 2 to monitor "
|
||||||
"your TPU job continuously. Level 2 is more verbose "
|
"your TPU job continuously. Level 2 is more verbose "
|
||||||
"than level 1 and shows more metrics."),
|
"than level 1 and shows more metrics."),
|
||||||
tensorflow::Flag("timestamp", &FLAGS_timestamp,
|
tensorflow::Flag("display_timestamp", &FLAGS_display_timestamp,
|
||||||
"Set to true to display timestamp in monitoring "
|
"Set to true to display timestamp in monitoring "
|
||||||
"results."),
|
"results."),
|
||||||
tensorflow::Flag("num_queries", &FLAGS_num_queries,
|
tensorflow::Flag("num_queries", &FLAGS_num_queries,
|
||||||
@ -128,7 +130,7 @@ int main(int argc, char** argv) {
|
|||||||
<< std::endl;
|
<< std::endl;
|
||||||
status = tensorflow::MonitoringHelper(FLAGS_service_addr, duration_ms,
|
status = tensorflow::MonitoringHelper(FLAGS_service_addr, duration_ms,
|
||||||
FLAGS_monitoring_level,
|
FLAGS_monitoring_level,
|
||||||
FLAGS_timestamp, num_queries);
|
FLAGS_display_timestamp, num_queries);
|
||||||
} else {
|
} else {
|
||||||
status = tensorflow::profiler::client::StartTracing(
|
status = tensorflow::profiler::client::StartTracing(
|
||||||
FLAGS_service_addr, FLAGS_logdir, FLAGS_workers_list,
|
FLAGS_service_addr, FLAGS_logdir, FLAGS_workers_list,
|
||||||
|
@ -244,10 +244,10 @@ MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level,
|
|||||||
return request;
|
return request;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
|
Status Monitor(const tensorflow::string& service_addr, int duration_ms,
|
||||||
int monitoring_level, bool timestamp, string* result) {
|
int monitoring_level, bool display_timestamp, string* result) {
|
||||||
MonitorRequest request =
|
MonitorRequest request =
|
||||||
PopulateMonitorRequest(duration_ms, monitoring_level, timestamp);
|
PopulateMonitorRequest(duration_ms, monitoring_level, display_timestamp);
|
||||||
|
|
||||||
::grpc::ClientContext context;
|
::grpc::ClientContext context;
|
||||||
::grpc::ChannelArguments channel_args;
|
::grpc::ChannelArguments channel_args;
|
||||||
|
@ -27,8 +27,8 @@ Status ValidateHostPortPair(const string& host_port);
|
|||||||
|
|
||||||
// Collects one sample of monitoring profile and shows user-friendly metrics.
|
// Collects one sample of monitoring profile and shows user-friendly metrics.
|
||||||
// If timestamp flag is true, timestamp will be displayed in "%H:%M:%S" format.
|
// If timestamp flag is true, timestamp will be displayed in "%H:%M:%S" format.
|
||||||
Status StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
|
Status Monitor(const tensorflow::string& service_addr, int duration_ms,
|
||||||
int monitoring_level, bool timestamp, string* output);
|
int monitoring_level, bool display_timestamp, string* result);
|
||||||
|
|
||||||
// Starts tracing on a single or multiple hosts and saves the result in the
|
// Starts tracing on a single or multiple hosts and saves the result in the
|
||||||
// given logdir. If no trace was collected, retries tracing for
|
// given logdir. If no trace was collected, retries tracing for
|
||||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.python import pywrap_tensorflow
|
from tensorflow.python import pywrap_tensorflow
|
||||||
|
from tensorflow.python.framework import c_api_util
|
||||||
from tensorflow.python.framework import errors
|
from tensorflow.python.framework import errors
|
||||||
|
|
||||||
|
|
||||||
@ -28,9 +29,9 @@ def start_tracing(service_addr,
|
|||||||
worker_list='',
|
worker_list='',
|
||||||
include_dataset_ops=True,
|
include_dataset_ops=True,
|
||||||
num_tracing_attempts=3):
|
num_tracing_attempts=3):
|
||||||
"""Sending grpc requests to profiler server to perform on-demand profiling.
|
"""Sends grpc requests to profiler server to perform on-demand profiling.
|
||||||
|
|
||||||
Note: This method will block caller thread until receives tracing result.
|
This method will block caller thread until receives tracing result.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
service_addr: Address of profiler service e.g. localhost:6009.
|
service_addr: Address of profiler service e.g. localhost:6009.
|
||||||
@ -49,3 +50,28 @@ def start_tracing(service_addr,
|
|||||||
service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
|
service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
|
||||||
num_tracing_attempts):
|
num_tracing_attempts):
|
||||||
raise errors.UnavailableError(None, None, 'No trace event is collected.')
|
raise errors.UnavailableError(None, None, 'No trace event is collected.')
|
||||||
|
|
||||||
|
|
||||||
|
def monitor(service_addr,
|
||||||
|
duration_ms,
|
||||||
|
monitoring_level=1,
|
||||||
|
display_timestamp=False):
|
||||||
|
"""Sends grpc requests to profiler server to perform on-demand monitoring.
|
||||||
|
|
||||||
|
This method will block caller thread until receives monitoring result.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
service_addr: Address of profiler service e.g. localhost:6009.
|
||||||
|
duration_ms: Duration of tracing or monitoring in ms.
|
||||||
|
monitoring_level: Choose a monitoring level between 1 and 2 to monitor your
|
||||||
|
job. Level 2 is more verbose than level 1 and shows more metrics.
|
||||||
|
display_timestamp: Set to true to display timestamp in monitoring result.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A string of monitoring output.
|
||||||
|
"""
|
||||||
|
with c_api_util.tf_buffer() as buffer_:
|
||||||
|
pywrap_tensorflow.TFE_ProfilerClientMonitor(service_addr, duration_ms,
|
||||||
|
monitoring_level,
|
||||||
|
display_timestamp, buffer_)
|
||||||
|
return pywrap_tensorflow.TF_GetBuffer(buffer_)
|
||||||
|
@ -30,6 +30,10 @@ class ProfilerClientTest(test_util.TensorFlowTestCase):
|
|||||||
with self.assertRaises(errors.UnavailableError):
|
with self.assertRaises(errors.UnavailableError):
|
||||||
profiler_client.start_tracing('localhost:6006', '/tmp/', 2000)
|
profiler_client.start_tracing('localhost:6006', '/tmp/', 2000)
|
||||||
|
|
||||||
|
def testMonitor_ProcessInvalidAddress(self):
|
||||||
|
with self.assertRaises(errors.UnavailableError):
|
||||||
|
profiler_client.monitor('localhost:6006', 2000)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -52,6 +52,7 @@ limitations under the License.
|
|||||||
%rename("%s") TFE_DeleteProfilerContext;
|
%rename("%s") TFE_DeleteProfilerContext;
|
||||||
%rename("%s") TFE_StartProfilerServer;
|
%rename("%s") TFE_StartProfilerServer;
|
||||||
%rename("%s") TFE_ProfilerClientStartTracing;
|
%rename("%s") TFE_ProfilerClientStartTracing;
|
||||||
|
%rename("%s") TFE_ProfilerClientMonitor;
|
||||||
%rename("%s") TFE_OpNameGetAttrType;
|
%rename("%s") TFE_OpNameGetAttrType;
|
||||||
%rename("%s") TFE_Py_InitEagerTensor;
|
%rename("%s") TFE_Py_InitEagerTensor;
|
||||||
%rename("%s") TFE_Py_SetEagerTensorProfiler;
|
%rename("%s") TFE_Py_SetEagerTensorProfiler;
|
||||||
|
Loading…
Reference in New Issue
Block a user