Add a StartMonitoring Python API.
PiperOrigin-RevId: 253079116
This commit is contained in:
parent
b6ccdad4d4
commit
514004a234
tensorflow
c/eager
contrib/tpu/profiler
core/profiler/rpc/client
python
@ -102,6 +102,28 @@ bool TFE_ProfilerClientStartTracing(const char* service_addr,
|
||||
return s.ok();
|
||||
}
|
||||
|
||||
void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
|
||||
int monitoring_level, bool display_timestamp,
|
||||
TF_Buffer* result, TF_Status* status) {
|
||||
tensorflow::Status s =
|
||||
tensorflow::profiler::client::ValidateHostPortPair(service_addr);
|
||||
if (!s.ok()) {
|
||||
Set_TF_Status_from_Status(status, s);
|
||||
return;
|
||||
}
|
||||
string content;
|
||||
s = tensorflow::profiler::client::Monitor(
|
||||
service_addr, duration_ms, monitoring_level, display_timestamp, &content);
|
||||
void* data = tensorflow::port::Malloc(content.length());
|
||||
content.copy(static_cast<char*>(data), content.length(), 0);
|
||||
result->data = data;
|
||||
result->length = content.length();
|
||||
result->data_deallocator = [](void* data, size_t length) {
|
||||
tensorflow::port::Free(data);
|
||||
};
|
||||
tensorflow::Set_TF_Status_from_Status(status, s);
|
||||
}
|
||||
|
||||
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
|
||||
int64_t value) {
|
||||
cell->cell.IncrementBy(value);
|
||||
|
@ -88,6 +88,16 @@ TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing(
|
||||
bool include_dataset_ops, int duration_ms, int num_tracing_attempts,
|
||||
TF_Status* status);
|
||||
|
||||
// Send a grpc request to profiler server (service_addr) to perform on-demand
|
||||
// monitoring and return the result in a string. It will block the
|
||||
// caller thread until receiving the monitoring result.
|
||||
// This API is designed for TensorBoard. End users should instead use
|
||||
// tensorflow/contrib/tpu/profiler/capture_tpu_profile, following
|
||||
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
|
||||
TF_CAPI_EXPORT extern void TFE_ProfilerClientMonitor(
|
||||
const char* service_addr, int duration_ms, int monitoring_level,
|
||||
bool display_timestamp, TF_Buffer* result, TF_Status* status);
|
||||
|
||||
// TODO(fishx): Move these monitoring APIs into a separate file.
|
||||
// -----------------------------------------------------------------------------
|
||||
// Monitoring Counter APIs.
|
||||
|
@ -29,11 +29,13 @@ namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
Status MonitoringHelper(const string& service_addr, int duration_ms,
|
||||
int monitoring_level, bool timestamp, int num_queries) {
|
||||
int monitoring_level, bool display_timestamp,
|
||||
int num_queries) {
|
||||
for (int query = 0; query < num_queries; ++query) {
|
||||
string result;
|
||||
TF_RETURN_IF_ERROR(tensorflow::profiler::client::StartMonitoring(
|
||||
service_addr, duration_ms, monitoring_level, timestamp, &result));
|
||||
TF_RETURN_IF_ERROR(tensorflow::profiler::client::Monitor(
|
||||
service_addr, duration_ms, monitoring_level, display_timestamp,
|
||||
&result));
|
||||
std::cout << "Cloud TPU Monitoring Results (Sample " << query + 1
|
||||
<< "):\n\n"
|
||||
<< result << std::flush;
|
||||
@ -52,7 +54,7 @@ int main(int argc, char** argv) {
|
||||
int FLAGS_num_tracing_attempts = 3;
|
||||
bool FLAGS_include_dataset_ops = true;
|
||||
int FLAGS_monitoring_level = 0;
|
||||
bool FLAGS_timestamp = false;
|
||||
bool FLAGS_display_timestamp = false;
|
||||
int FLAGS_num_queries = 100;
|
||||
std::vector<tensorflow::Flag> flag_list = {
|
||||
tensorflow::Flag("service_addr", &FLAGS_service_addr,
|
||||
@ -76,7 +78,7 @@ int main(int argc, char** argv) {
|
||||
"Choose a monitoring level between 1 and 2 to monitor "
|
||||
"your TPU job continuously. Level 2 is more verbose "
|
||||
"than level 1 and shows more metrics."),
|
||||
tensorflow::Flag("timestamp", &FLAGS_timestamp,
|
||||
tensorflow::Flag("display_timestamp", &FLAGS_display_timestamp,
|
||||
"Set to true to display timestamp in monitoring "
|
||||
"results."),
|
||||
tensorflow::Flag("num_queries", &FLAGS_num_queries,
|
||||
@ -128,7 +130,7 @@ int main(int argc, char** argv) {
|
||||
<< std::endl;
|
||||
status = tensorflow::MonitoringHelper(FLAGS_service_addr, duration_ms,
|
||||
FLAGS_monitoring_level,
|
||||
FLAGS_timestamp, num_queries);
|
||||
FLAGS_display_timestamp, num_queries);
|
||||
} else {
|
||||
status = tensorflow::profiler::client::StartTracing(
|
||||
FLAGS_service_addr, FLAGS_logdir, FLAGS_workers_list,
|
||||
|
@ -244,10 +244,10 @@ MonitorRequest PopulateMonitorRequest(int duration_ms, int monitoring_level,
|
||||
return request;
|
||||
}
|
||||
|
||||
Status StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
|
||||
int monitoring_level, bool timestamp, string* result) {
|
||||
Status Monitor(const tensorflow::string& service_addr, int duration_ms,
|
||||
int monitoring_level, bool display_timestamp, string* result) {
|
||||
MonitorRequest request =
|
||||
PopulateMonitorRequest(duration_ms, monitoring_level, timestamp);
|
||||
PopulateMonitorRequest(duration_ms, monitoring_level, display_timestamp);
|
||||
|
||||
::grpc::ClientContext context;
|
||||
::grpc::ChannelArguments channel_args;
|
||||
|
@ -27,8 +27,8 @@ Status ValidateHostPortPair(const string& host_port);
|
||||
|
||||
// Collects one sample of monitoring profile and shows user-friendly metrics.
|
||||
// If display_timestamp is true, timestamps are displayed in "%H:%M:%S" format.
|
||||
Status StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
|
||||
int monitoring_level, bool timestamp, string* output);
|
||||
Status Monitor(const tensorflow::string& service_addr, int duration_ms,
|
||||
int monitoring_level, bool display_timestamp, string* result);
|
||||
|
||||
// Starts tracing on a single or multiple hosts and saves the result in the
|
||||
// given logdir. If no trace was collected, retries tracing for
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.python import pywrap_tensorflow
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python.framework import errors
|
||||
|
||||
|
||||
@ -28,9 +29,9 @@ def start_tracing(service_addr,
|
||||
worker_list='',
|
||||
include_dataset_ops=True,
|
||||
num_tracing_attempts=3):
|
||||
"""Sending grpc requests to profiler server to perform on-demand profiling.
|
||||
"""Sends grpc requests to profiler server to perform on-demand profiling.
|
||||
|
||||
Note: This method will block caller thread until receives tracing result.
|
||||
This method will block caller thread until receives tracing result.
|
||||
|
||||
Args:
|
||||
service_addr: Address of profiler service e.g. localhost:6009.
|
||||
@ -49,3 +50,28 @@ def start_tracing(service_addr,
|
||||
service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
|
||||
num_tracing_attempts):
|
||||
raise errors.UnavailableError(None, None, 'No trace event is collected.')
|
||||
|
||||
|
||||
def monitor(service_addr,
            duration_ms,
            monitoring_level=1,
            display_timestamp=False):
  """Requests on-demand monitoring from a profiler server over grpc.

  This call blocks the caller thread until the monitoring result is received.

  Args:
    service_addr: Address of profiler service e.g. localhost:6009.
    duration_ms: Duration of tracing or monitoring in ms.
    monitoring_level: Monitoring level, 1 or 2. Level 2 is more verbose than
      level 1 and shows more metrics.
    display_timestamp: Set to true to display timestamp in monitoring result.

  Returns:
    A string of monitoring output.
  """
  with c_api_util.tf_buffer() as result_buffer:
    # The C API writes the monitoring output into the buffer; read it back
    # while the buffer is still alive inside the `with` block.
    pywrap_tensorflow.TFE_ProfilerClientMonitor(
        service_addr, duration_ms, monitoring_level, display_timestamp,
        result_buffer)
    monitoring_output = pywrap_tensorflow.TF_GetBuffer(result_buffer)
  return monitoring_output
|
||||
|
@ -30,6 +30,10 @@ class ProfilerClientTest(test_util.TensorFlowTestCase):
|
||||
with self.assertRaises(errors.UnavailableError):
|
||||
profiler_client.start_tracing('localhost:6006', '/tmp/', 2000)
|
||||
|
||||
def testMonitor_ProcessInvalidAddress(self):
  # monitor() against 'localhost:6006' is expected to raise UnavailableError.
  self.assertRaises(errors.UnavailableError, profiler_client.monitor,
                    'localhost:6006', 2000)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
||||
|
@ -52,6 +52,7 @@ limitations under the License.
|
||||
%rename("%s") TFE_DeleteProfilerContext;
|
||||
%rename("%s") TFE_StartProfilerServer;
|
||||
%rename("%s") TFE_ProfilerClientStartTracing;
|
||||
%rename("%s") TFE_ProfilerClientMonitor;
|
||||
%rename("%s") TFE_OpNameGetAttrType;
|
||||
%rename("%s") TFE_Py_InitEagerTensor;
|
||||
%rename("%s") TFE_Py_SetEagerTensorProfiler;
|
||||
|
Loading…
Reference in New Issue
Block a user