Move profiler API implementation to _pywrap_profiler

PiperOrigin-RevId: 295240754
Change-Id: I3664efc053696a3c521d18527c04747688cac932
Jose Baiocchi 2020-02-14 15:30:53 -08:00 committed by TensorFlower Gardener
parent 2ed75c82a2
commit 767e4d5dab
15 changed files with 99 additions and 178 deletions

View File

@@ -256,8 +256,6 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/profiler/rpc:profiler_server",
"//tensorflow/core/profiler/rpc/client:capture_profile",
"//tensorflow/core:gpu_runtime",
],
alwayslink = 1,

View File

@@ -25,8 +25,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
#include "tensorflow/core/profiler/rpc/profiler_server.h"
using tensorflow::string;
@@ -47,14 +45,6 @@ void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
->Handle());
}
void TFE_StartProfilerServer(int port) {
auto profiler_server = absl::make_unique<tensorflow::ProfilerServer>();
profiler_server->StartProfilerServer(port);
// Intentionally release the child server thread. The child thread will be
// terminated when the main program exits.
profiler_server.release();
}
void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(true);
}
@@ -63,46 +53,6 @@ void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
ctx->context->SetShouldStoreGraphs(false);
}
bool TFE_ProfilerClientStartTracing(const char* service_addr,
const char* logdir, const char* worker_list,
bool include_dataset_ops, int duration_ms,
int num_tracing_attempts,
TF_Status* status) {
tensorflow::Status s =
tensorflow::profiler::ValidateHostPortPair(service_addr);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return false;
}
s = tensorflow::profiler::Trace(service_addr, logdir, worker_list,
include_dataset_ops, duration_ms,
num_tracing_attempts);
tensorflow::Set_TF_Status_from_Status(status, s);
return s.ok();
}
void TFE_ProfilerClientMonitor(const char* service_addr, int duration_ms,
int monitoring_level, bool display_timestamp,
TF_Buffer* result, TF_Status* status) {
tensorflow::Status s =
tensorflow::profiler::ValidateHostPortPair(service_addr);
if (!s.ok()) {
Set_TF_Status_from_Status(status, s);
return;
}
string content;
s = tensorflow::profiler::Monitor(service_addr, duration_ms, monitoring_level,
display_timestamp, &content);
void* data = tensorflow::port::Malloc(content.length());
content.copy(static_cast<char*>(data), content.length(), 0);
result->data = data;
result->length = content.length();
result->data_deallocator = [](void* data, size_t length) {
tensorflow::port::Free(data);
};
tensorflow::Set_TF_Status_from_Status(status, s);
}
void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
int64_t value) {
cell->cell.IncrementBy(value);

View File

@@ -37,15 +37,6 @@ TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
TF_Status* status);
// Start a profiler gRPC server that listens on the specified port. The server
// runs on its own thread and can be shut down by terminating TensorFlow.
// It can be used in both eager mode and graph mode. Creating multiple profiler
// servers is allowed. The service is defined in
// tensorflow/contrib/tpu/profiler/tpu_profiler.proto. Please use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile to capture a trace file,
// following https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern void TFE_StartProfilerServer(int port);
// Enables only graph collection in RunMetadata on the functions executed from
// this context.
TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx);
@@ -54,29 +45,6 @@ TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx);
// this context.
TF_CAPI_EXPORT extern void TFE_ContextDisableGraphCollection(TFE_Context* ctx);
// Send a gRPC request to the profiler server (service_addr) to perform
// on-demand profiling and save the result into logdir, where it can be
// visualized by TensorBoard. worker_list is the list of worker TPUs separated
// by ','. Set include_dataset_ops to false to profile longer traces. It blocks
// the caller thread until the tracing result is received.
// This API is designed for TensorBoard; end users should instead use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile, following
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern bool TFE_ProfilerClientStartTracing(
const char* service_addr, const char* logdir, const char* worker_list,
bool include_dataset_ops, int duration_ms, int num_tracing_attempts,
TF_Status* status);
// Send a gRPC request to the profiler server (service_addr) to perform
// on-demand monitoring and return the result as a string. It blocks the
// caller thread until the monitoring result is received.
// This API is designed for TensorBoard; end users should instead use
// tensorflow/contrib/tpu/profiler/capture_tpu_profile, following
// https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_trace.
TF_CAPI_EXPORT extern void TFE_ProfilerClientMonitor(
const char* service_addr, int duration_ms, int monitoring_level,
bool display_timestamp, TF_Buffer* result, TF_Status* status);
// TODO(fishx): Move these monitoring APIs into a separate file.
// -----------------------------------------------------------------------------
// Monitoring Counter APIs.

View File

@@ -1,5 +1,6 @@
load("//tensorflow/core/platform:build_config_root.bzl", "if_static")
load("//tensorflow:tensorflow.bzl", "if_not_android", "tf_cuda_library")
load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper")
package(
default_visibility = [
@@ -9,6 +10,15 @@ package(
licenses = ["notice"], # Apache 2.0
)
tf_pybind_cc_library_wrapper(
name = "profiler_session_headers",
visibility = [
"//tensorflow/core/profiler/rpc:__pkg__",
"//tensorflow/python/profiler/internal:__pkg__",
],
deps = [":profiler_session"],
)
cc_library(
name = "profiler_session",
hdrs = ["profiler_session.h"],
@@ -65,6 +75,12 @@ tf_cuda_library(
alwayslink = True,
)
tf_pybind_cc_library_wrapper(
name = "traceme_headers",
visibility = ["//tensorflow/python/profiler/internal:__pkg__"],
deps = [":traceme"],
)
cc_library(
name = "traceme",
hdrs = ["traceme.h"],
@@ -90,6 +106,12 @@ cc_library(
],
)
tf_pybind_cc_library_wrapper(
name = "scoped_annotation_headers",
visibility = ["//tensorflow/python/profiler/internal:__pkg__"],
deps = [":scoped_annotation"],
)
cc_library(
name = "scoped_annotation",
hdrs = ["scoped_annotation.h"],

View File

@@ -6,14 +6,15 @@ cc_library(
name = "profiler_service_impl",
srcs = ["profiler_service_impl.cc"],
hdrs = ["profiler_service_impl.h"],
visibility = ["//visibility:public"],
features = ["-layering_check"],
visibility = ["//tensorflow_serving/model_servers:__pkg__"],
deps = [
"//tensorflow:grpc++",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/profiler:profiler_service_proto_cc",
"//tensorflow/core/profiler/convert:xplane_to_profile_response",
"//tensorflow/core/profiler/lib:profiler_session",
"//tensorflow/core/profiler/lib:profiler_session_headers",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
@@ -24,7 +25,10 @@ cc_library(
name = "profiler_server",
srcs = ["profiler_server.cc"],
hdrs = ["profiler_server.h"],
visibility = ["//visibility:public"],
visibility = [
"//tensorflow/compiler/xla/python:__pkg__",
"//tensorflow/python/profiler/internal:__pkg__",
],
deps = [
":profiler_service_impl",
"//tensorflow:grpc++",

View File

@@ -6,7 +6,7 @@ cc_library(
name = "capture_profile",
srcs = ["capture_profile.cc"],
hdrs = ["capture_profile.h"],
visibility = ["//visibility:public"],
visibility = ["//tensorflow/python/profiler/internal:__pkg__"],
deps = [
":save_profile",
"//tensorflow:grpc++",
@@ -22,7 +22,7 @@ cc_library(
name = "save_profile",
srcs = ["save_profile.cc"],
hdrs = ["save_profile.h"],
visibility = ["//visibility:public"],
visibility = ["//tensorflow/core/profiler:internal"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",

View File

@@ -205,9 +205,8 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
":context",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:util",
"//tensorflow/python/profiler/internal:_pywrap_profiler_session",
"//tensorflow/python/profiler/internal:_pywrap_profiler",
],
)
@@ -232,9 +231,7 @@ py_library(
"//tensorflow/core/profiler:internal",
],
deps = [
"//tensorflow/python:c_api_util",
"//tensorflow/python:pywrap_tf_session",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python/profiler/internal:_pywrap_profiler",
],
)

View File

@@ -39,12 +39,11 @@ import os
import threading
from tensorflow.python import _pywrap_events_writer
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler.internal import _pywrap_profiler_session
from tensorflow.python.profiler.internal import _pywrap_profiler
from tensorflow.python.util import compat
_profiler = None
@@ -75,7 +74,7 @@ def start():
raise ProfilerAlreadyRunningError('Another profiler is running.')
if context.default_execution_mode == context.EAGER_MODE:
context.ensure_initialized()
_profiler = _pywrap_profiler_session.ProfilerSession()
_profiler = _pywrap_profiler.ProfilerSession()
try:
_profiler.start()
except errors.AlreadyExistsError:
@@ -158,7 +157,7 @@ def start_profiler_server(port):
"""
if context.default_execution_mode == context.EAGER_MODE:
context.ensure_initialized()
pywrap_tfe.TFE_StartProfilerServer(port)
_pywrap_profiler.start_profiler_server(port)
class Profiler(object):
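For context, a minimal usage sketch of the eager profiler module touched above; callers are unaffected, since only the binding behind start_profiler_server() moves to _pywrap_profiler. The logdir path is a placeholder, and profiler.save() is assumed from elsewhere in this file:

from tensorflow.python.eager import profiler

profiler.start_profiler_server(6009)   # serve on-demand profiling requests
profiler.start()                       # begin collecting a local trace
# ... run the model under eager execution ...
result = profiler.stop()               # assumed to return serialized trace data
profiler.save('/tmp/profile_logdir', result)  # write where TensorBoard can read it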

View File

@@ -18,10 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tfe
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import errors
from tensorflow.python.profiler.internal import _pywrap_profiler
def start_tracing(service_addr,
@@ -47,10 +44,8 @@ def start_tracing(service_addr,
Raises:
UnavailableError: If no trace event is collected.
"""
if not pywrap_tfe.TFE_ProfilerClientStartTracing(
service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
num_tracing_attempts):
raise errors.UnavailableError(None, None, 'No trace event is collected.')
_pywrap_profiler.trace(service_addr, logdir, worker_list, include_dataset_ops,
duration_ms, num_tracing_attempts)
def monitor(service_addr,
@@ -71,8 +66,5 @@ def monitor(service_addr,
Returns:
A string of monitoring output.
"""
with c_api_util.tf_buffer() as buffer_:
pywrap_tfe.TFE_ProfilerClientMonitor(service_addr, duration_ms,
monitoring_level, display_timestamp,
buffer_)
return pywrap_tf_session.TF_GetBuffer(buffer_)
return _pywrap_profiler.monitor(service_addr, duration_ms, monitoring_level,
display_timestamp)
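A hedged sketch of driving this eager client module; the keyword defaults (worker_list, num_tracing_attempts, monitoring_level, display_timestamp) are assumed from the surrounding docstrings rather than shown in this hunk, and failures now surface as exceptions raised by the pybind layer instead of a manually raised UnavailableError:

from tensorflow.python.eager import profiler_client

# Ask a running profiler server to trace for 2 seconds into a TensorBoard logdir.
profiler_client.start_tracing('localhost:6009', '/tmp/profile_logdir', 2000)

# Poll basic runtime metrics for 1 second and print the formatted summary.
print(profiler_client.monitor('localhost:6009', 1000, monitoring_level=1))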

View File

@@ -26,8 +26,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:c_api_util",
"//tensorflow/python:pywrap_tf_session",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python/profiler/internal:_pywrap_profiler",
],
)

View File

@@ -1,8 +1,5 @@
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_pybind_cc_library_wrapper")
# buildifier: disable=same-origin-load
load("//tensorflow:tensorflow.bzl", "tf_python_pybind_extension")
load("//tensorflow:tensorflow.bzl", "py_test")
@@ -80,11 +77,6 @@ cuda_py_test(
],
)
tf_pybind_cc_library_wrapper(
name = "traceme_headers",
deps = ["//tensorflow/core/profiler/lib:traceme"],
)
tf_python_pybind_extension(
name = "_pywrap_traceme",
srcs = ["traceme_wrapper.cc"],
@@ -95,46 +87,42 @@ tf_python_pybind_extension(
"//tensorflow/python/profiler:__subpackages__",
],
deps = [
":traceme_headers",
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:traceme_headers",
"@com_google_absl//absl/types:optional",
"@pybind11",
],
)
tf_pybind_cc_library_wrapper(
name = "scoped_annotation_headers",
deps = ["//tensorflow/core/profiler/lib:scoped_annotation"],
)
tf_python_pybind_extension(
name = "_pywrap_scoped_annotation",
srcs = ["scoped_annotation_wrapper.cc"],
features = ["-layering_check"],
module_name = "_pywrap_scoped_annotation",
deps = [
":scoped_annotation_headers",
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:scoped_annotation_headers",
"@com_google_absl//absl/types:optional",
"@pybind11",
],
)
tf_pybind_cc_library_wrapper(
name = "profiler_session_headers",
deps = ["//tensorflow/core/profiler/lib:profiler_session"],
)
tf_python_pybind_extension(
name = "_pywrap_profiler_session",
srcs = ["profiler_session_wrapper.cc"],
name = "_pywrap_profiler",
srcs = ["profiler_wrapper.cc"],
features = ["-layering_check"],
module_name = "_pywrap_profiler_session",
visibility = ["//tensorflow/python/eager:__pkg__"],
module_name = "_pywrap_profiler",
visibility = [
"//tensorflow/python/eager:__pkg__",
"//tensorflow/python/profiler:__pkg__",
],
deps = [
":profiler_session_headers",
"//tensorflow/core:lib",
"//tensorflow/core/profiler/lib:profiler_session_headers",
"//tensorflow/core/profiler/rpc:profiler_server",
"//tensorflow/core/profiler/rpc/client:capture_profile",
"//tensorflow/python:pybind11_status",
"@com_google_absl//absl/memory",
"@pybind11",
],
)

View File

@@ -15,9 +15,12 @@ limitations under the License.
#include <memory>
#include "absl/memory/memory.h"
#include "include/pybind11/pybind11.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"
#include "tensorflow/core/profiler/rpc/client/capture_profile.h"
#include "tensorflow/core/profiler/rpc/profiler_server.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
namespace py = ::pybind11;
@@ -48,10 +51,43 @@ class ProfilerSessionWrapper {
} // namespace
PYBIND11_MODULE(_pywrap_profiler_session, m) {
PYBIND11_MODULE(_pywrap_profiler, m) {
py::class_<ProfilerSessionWrapper> profiler_session_class(m,
"ProfilerSession");
profiler_session_class.def(py::init<>())
.def("start", &ProfilerSessionWrapper::Start)
.def("stop", &ProfilerSessionWrapper::Stop);
m.def("start_profiler_server", [](int port) {
auto profiler_server = absl::make_unique<tensorflow::ProfilerServer>();
profiler_server->StartProfilerServer(port);
// Intentionally release the profiler server; ownership should eventually be
// transferred to the caller instead.
profiler_server.release();
});
m.def("trace", [](const char* service_addr, const char* logdir,
const char* worker_list, bool include_dataset_ops,
int duration_ms, int num_tracing_attempts) {
tensorflow::Status status =
tensorflow::profiler::ValidateHostPortPair(service_addr);
tensorflow::MaybeRaiseRegisteredFromStatus(status);
status = tensorflow::profiler::Trace(service_addr, logdir, worker_list,
include_dataset_ops, duration_ms,
num_tracing_attempts);
tensorflow::MaybeRaiseRegisteredFromStatus(status);
});
m.def("monitor", [](const char* service_addr, int duration_ms,
int monitoring_level, bool display_timestamp) {
tensorflow::Status status =
tensorflow::profiler::ValidateHostPortPair(service_addr);
tensorflow::MaybeRaiseRegisteredFromStatus(status);
tensorflow::string content;
status = tensorflow::profiler::Monitor(service_addr, duration_ms,
monitoring_level, display_timestamp,
&content);
tensorflow::MaybeRaiseRegisteredFromStatus(status);
return content;
});
};
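Taken together, the module above exposes a small surface to Python. A minimal sketch of calling it directly, purely to illustrate the bindings defined here (normally the eager and profiler wrapper modules go through it); stop() returning the serialized profile is an assumption based on the existing ProfilerSession wrapper, not shown in this hunk:

from tensorflow.python.profiler.internal import _pywrap_profiler

# In-process profiler gRPC server; the server object is deliberately leaked
# (see the comment in start_profiler_server above).
_pywrap_profiler.start_profiler_server(6009)

# Local tracing via the ProfilerSession binding.
session = _pywrap_profiler.ProfilerSession()
session.start()
# ... run some TensorFlow work ...
data = session.stop()  # assumed to return the serialized profile

# Remote, on-demand tracing and monitoring against a profiler server, matching
# the trace(...) and monitor(...) lambdas bound above; errors raise TensorFlow
# exceptions via MaybeRaiseRegisteredFromStatus.
_pywrap_profiler.trace('localhost:6009', '/tmp/profile_logdir', '', True, 2000, 3)
print(_pywrap_profiler.monitor('localhost:6009', 1000, 1, True))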

View File

@@ -18,10 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python import pywrap_tfe
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import errors
from tensorflow.python.profiler.internal import _pywrap_profiler
def trace(service_addr,
@@ -45,10 +42,8 @@ def trace(service_addr,
Raises:
UnavailableError: If no trace event is collected.
"""
if not pywrap_tfe.TFE_ProfilerClientStartTracing(
service_addr, logdir, worker_list, True, duration_ms,
num_tracing_attempts):
raise errors.UnavailableError(None, None, 'No trace event is collected.')
_pywrap_profiler.trace(service_addr, logdir, worker_list, True, duration_ms,
num_tracing_attempts)
def monitor(service_addr, duration_ms, level=1):
@@ -59,14 +54,10 @@ def monitor(service_addr, duration_ms, level=1):
Args:
service_addr: Address of profiler service e.g. localhost:6009.
duration_ms: Duration of monitoring in ms.
level: Choose a monitoring level between 1 and 2 to monitor your
job. Level 2 is more verbose than level 1 and shows more metrics.
level: Choose a monitoring level between 1 and 2 to monitor your job. Level
2 is more verbose than level 1 and shows more metrics.
Returns:
A string of monitoring output.
"""
with c_api_util.tf_buffer() as buffer_:
pywrap_tfe.TFE_ProfilerClientMonitor(service_addr, duration_ms, level, True,
buffer_)
return pywrap_tf_session.TF_GetBuffer(buffer_)
return _pywrap_profiler.monitor(service_addr, duration_ms, level, True)
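A short usage sketch for this public client module; address, logdir, and duration values are placeholders, and the trace() parameter order beyond what this hunk shows (e.g. worker_list and num_tracing_attempts defaults) is assumed:

from tensorflow.python.profiler import profiler_client

# Capture a 2-second trace from a remote profiler server into a logdir that
# TensorBoard can read; errors now propagate as TensorFlow exceptions.
profiler_client.trace('localhost:6009', '/tmp/profile_logdir', 2000)

# Fetch a 1-second monitoring summary; level=2 reports more metrics.
print(profiler_client.monitor('localhost:6009', 1000, level=2))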

View File

@@ -501,30 +501,6 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
},
py::return_value_policy::reference);
// Profiler Logic
m.def("TFE_StartProfilerServer", &TFE_StartProfilerServer);
m.def(
"TFE_ProfilerClientStartTracing",
[](const char* service_addr, const char* logdir, const char* worker_list,
bool include_dataset_ops, int duration_ms, int num_tracing_attempts) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
bool output = TFE_ProfilerClientStartTracing(
service_addr, logdir, worker_list, include_dataset_ops, duration_ms,
num_tracing_attempts, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
return output;
});
m.def("TFE_ProfilerClientMonitor",
[](const char* service_addr, int duration_ms, int monitoring_level,
bool display_timestamp, TF_Buffer& result) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
TFE_ProfilerClientMonitor(service_addr, duration_ms, monitoring_level,
display_timestamp, &result, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
m.def("TFE_OpNameGetAttrType",
[](py::handle& ctx, const char* op_or_function_name,
const char* attr_name) {

View File

@@ -304,6 +304,7 @@ tensorflow::profiler::AnnotationStack::ThreadAnnotationStack
[profiler_session_impl] # profiler
tensorflow::ProfilerSession::Create
tensorflow::ProfilerSession::CollectData
tensorflow::ProfilerSession::SerializeToString
tensorflow::ProfilerSession::Status
tensorflow::ProfilerSession::~ProfilerSession