Decouple ProfilerSession wrapper from pywrap_tfe

PiperOrigin-RevId: 293882403
Change-Id: I947e32807447460b6fc7ca1b19bf9ca276c3e994
This commit is contained in:
Jose Baiocchi 2020-02-07 13:28:53 -08:00 committed by TensorFlower Gardener
parent e78c47f2a7
commit 21bb9be2c1
16 changed files with 137 additions and 185 deletions

View File

@ -589,6 +589,7 @@ tf_cc_shared_object(
"//tensorflow/core:gpu_runtime_impl",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
"//tensorflow/core:lib_internal_impl",
"//tensorflow/core/profiler/lib:profiler_session_impl",
"//tensorflow/stream_executor:stream_executor_impl",
"//tensorflow:tf_framework_version_script.lds",
] + tf_additional_binary_deps(),

View File

@ -82,8 +82,6 @@ tf_cuda_library(
"//tensorflow/core/distributed_runtime:remote_device",
"//tensorflow/core/distributed_runtime:server_lib",
"//tensorflow/core/distributed_runtime:worker_env",
"//tensorflow/core/profiler/lib:profiler_lib",
"//tensorflow/core/profiler/lib:profiler_session",
"//tensorflow/core:gpu_runtime",
],
alwayslink = 1,
@ -130,7 +128,6 @@ tf_cuda_library(
"//tensorflow/core/common_runtime/eager:eager_operation",
"//tensorflow/core/common_runtime/eager:kernel_and_device",
"//tensorflow/core/common_runtime/eager:tensor_handle",
"//tensorflow/core/profiler/lib:profiler_session",
],
)

View File

@ -47,27 +47,6 @@ void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
->Handle());
}
// Heap-allocates a TFE_Profiler and returns the raw pointer to the caller.
TFE_Profiler* TFE_NewProfiler() {
  TFE_Profiler* created = new TFE_Profiler();
  return created;
}
// Reports whether the ProfilerSession held by `profiler` has an OK status.
bool TFE_ProfilerIsOk(TFE_Profiler* profiler) {
  const auto session_status = profiler->profiler->Status();
  return session_status.ok();
}
// Destroys a profiler previously returned by TFE_NewProfiler.
void TFE_DeleteProfiler(TFE_Profiler* profiler) {
  delete profiler;
}
// Serializes the profiler's session into a freshly allocated byte array and
// hands that array to `buf`, together with a deallocator that frees it.
// The serialization result is stored in `status`; the buffer is populated
// regardless of that result.
void TFE_ProfilerSerializeToString(TFE_Profiler* profiler, TF_Buffer* buf,
                                   TF_Status* status) {
  string serialized;
  status->status = profiler->profiler->SerializeToString(&serialized);

  // Copy the bytes out of the local string so they outlive this call.
  const size_t num_bytes = serialized.length();
  char* bytes = static_cast<char*>(tensorflow::port::Malloc(num_bytes));
  serialized.copy(bytes, num_bytes, 0);

  buf->data = bytes;
  buf->length = num_bytes;
  buf->data_deallocator = [](void* data, size_t length) {
    tensorflow::port::Free(data);
  };
}
void TFE_StartProfilerServer(int port) {
// Release child thread intentionally. The child thread can be terminated by
// terminating the main thread.

View File

@ -37,23 +37,6 @@ TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
TF_Status* status);
// A profiler which will start profiling when creating the object and will stop
// when the object is destroyed. It will profile all operations run under the
// given TFE_Context. Multiple instances of it can be created, but at most one
// of them will profile for each TFE_Context.
// Thread-safety: TFE_Profiler is thread-safe.
typedef struct TFE_Profiler TFE_Profiler;
// Creates a new profiler. Release it with TFE_DeleteProfiler.
TF_CAPI_EXPORT extern TFE_Profiler* TFE_NewProfiler();
// Returns true if the profiler's underlying session reports an OK status.
TF_CAPI_EXPORT extern bool TFE_ProfilerIsOk(TFE_Profiler* profiler);
// Destroys a profiler created by TFE_NewProfiler.
TF_CAPI_EXPORT extern void TFE_DeleteProfiler(TFE_Profiler* profiler);
// Serializes the collected profile into `buf` and reports the outcome in
// `status`.
// The output string is a binary string of tensorflow.tpu.Trace. Users can
// write the string to a file for offline analysis by TensorBoard.
TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Profiler* profiler,
TF_Buffer* buf,
TF_Status* status);
// Start a profiler grpc server which listens to specified port. It will start
// the server on its own thread. It can be shutdown by terminating tensorflow.
// It can be used in both Eager mode and graph mode. Creating multiple profiler

View File

@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/protobuf/trace_events.pb.h"
using tensorflow::string;
@ -39,88 +38,6 @@ static bool HasSubstr(absl::string_view base, absl::string_view substr) {
return ok;
}
// Runs a single MatMul op while a TFE_Profiler is alive, then checks both the
// serialized trace (device/host names and the op name) and the numeric result.
// `async` selects whether the eager context is created in async mode.
void ExecuteWithProfiling(bool async) {
TF_Status* status = TF_NewStatus();
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
TFE_Context* ctx = TFE_NewContext(opts, status);
// The profiler is created before the op is built and executed below.
TFE_Profiler* profiler = TFE_NewProfiler();
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteContextOptions(opts);
TFE_TensorHandle* m = TestMatrixTensorHandle();
TFE_Op* matmul = MatMulOp(ctx, m, m);
TFE_TensorHandle* retvals[1] = {nullptr};
int num_retvals = 1;
// Run op on GPU if it is present.
string gpu_device_name;
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status);
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
}
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteOp(matmul);
TFE_DeleteTensorHandle(m);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
ASSERT_EQ(1, num_retvals);
TF_Buffer* profiler_result = TF_NewBuffer();
// In async mode, wait for all pending nodes on this thread's executor before
// serializing the profile.
if (async) {
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
TFE_ExecutorWaitForAllPendingNodes(executor, status);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TFE_DeleteExecutor(executor);
}
// Serialize first, then delete the profiler; the status checked afterwards is
// the one written by the serialization call.
TFE_ProfilerSerializeToString(profiler, profiler_result, status);
TFE_DeleteProfiler(profiler);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
// Parse the serialized bytes back into a Trace proto and inspect its text form.
profiler::Trace profile_proto;
EXPECT_TRUE(profile_proto.ParseFromString(
{reinterpret_cast<const char*>(profiler_result->data),
profiler_result->length}));
string profile_proto_str = profile_proto.DebugString();
#ifndef TENSORFLOW_USE_ROCM
// TODO(rocm): enable once GPU profiling is supported in ROCm mode
if (!gpu_device_name.empty()) {
EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0"));
}
#endif
// "/host:CPU" is collected by TraceMe
EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU"));
EXPECT_TRUE(HasSubstr(profile_proto_str, "MatMul"));
TF_DeleteBuffer(profiler_result);
// Resolve the op output and verify the four float values of the product.
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
TFE_DeleteTensorHandle(retvals[0]);
TFE_DeleteContext(ctx);
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
float product[4] = {0};
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
TF_DeleteTensor(t);
EXPECT_EQ(7, product[0]);
EXPECT_EQ(10, product[1]);
EXPECT_EQ(15, product[2]);
EXPECT_EQ(22, product[3]);
TF_DeleteStatus(status);
}
// Exercise ExecuteWithProfiling with the context in sync mode and in async
// mode, respectively.
TEST(CAPI, ExecuteWithTracing) { ExecuteWithProfiling(false); }
TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithProfiling(true); }
// Verifies that while one profiler exists and is OK, a second profiler created
// afterwards reports a non-OK state. Both are deleted at the end.
TEST(CAPI, MultipleProfilerSession) {
TFE_Profiler* profiler1 = TFE_NewProfiler();
EXPECT_TRUE(TFE_ProfilerIsOk(profiler1));
// Created while profiler1 is still alive, so it is expected not to be OK.
TFE_Profiler* profiler2 = TFE_NewProfiler();
EXPECT_FALSE(TFE_ProfilerIsOk(profiler2));
TFE_DeleteProfiler(profiler1);
TFE_DeleteProfiler(profiler2);
}
TEST(CAPI, MonitoringCounter0) {
TF_Status* status = TF_NewStatus();
auto* counter =

View File

@ -48,7 +48,6 @@ limitations under the License.
#include "tensorflow/core/lib/monitoring/sampler.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"
#include "tensorflow/core/public/version.h"
struct TFE_ContextOptions {
@ -93,12 +92,6 @@ struct TFE_Op {
tensorflow::EagerOperation operation;
};
// Backing object for the TFE_Profiler C handle. Constructing it creates the
// ProfilerSession it owns; destroying it releases that session.
struct TFE_Profiler {
  explicit TFE_Profiler() : profiler(tensorflow::ProfilerSession::Create()) {}

  std::unique_ptr<tensorflow::ProfilerSession> profiler;
};
// C API handle type that holds a tensorflow::monitoring::CounterCell by value.
struct TFE_MonitoringCounterCell {
tensorflow::monitoring::CounterCell cell;
};

View File

@ -2293,6 +2293,7 @@ tf_cuda_library(
"//tensorflow/core/platform/default/build_config:platformlib",
"//tensorflow/core/profiler/internal:annotation_stack_impl",
"//tensorflow/core/profiler/internal:traceme_recorder_impl",
"//tensorflow/core/profiler/internal:profiler_factory_impl",
"//tensorflow/core/profiler/lib:annotated_traceme",
"//tensorflow/core/profiler/lib:traceme",
"//tensorflow/core/util:einsum_op_util",

View File

@ -441,14 +441,31 @@ cc_library(
cc_library(
name = "profiler_factory",
srcs = ["profiler_factory.cc"],
hdrs = ["profiler_factory.h"],
deps = [
":profiler_interface",
] + if_static([
":profiler_factory_impl",
]),
)
# Linked into libtensorflow_framework.so via :framework_internal_impl.
# Builds the profiler factory source and header as their own target; only
# //tensorflow/core is allowed to depend on it.
cc_library(
name = "profiler_factory_impl",
srcs = [
"profiler_factory.cc",
"profiler_factory.h",
],
visibility = [
"//tensorflow/core:__pkg__",
],
deps = [
":profiler_interface",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
],
# NOTE(review): presumably required so the factory's symbols are kept even
# when nothing references them directly at link time — confirm.
alwayslink = True,
)
filegroup(
@ -532,14 +549,3 @@ tf_cc_test(
"//tensorflow/core:testlib",
],
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"profiler_interface.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)

View File

@ -1,3 +1,4 @@
load("//tensorflow/core/platform:build_config_root.bzl", "if_static")
load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
package(
@ -10,9 +11,31 @@ package(
cc_library(
name = "profiler_session",
srcs = ["profiler_session.cc"],
hdrs = ["profiler_session.h"],
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/core/profiler/internal:profiler_interface",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
] + if_static([
":profiler_session_impl",
]),
)
# Linked directly into ":tensorflow_framework".
# Unlike other "_impl" targets in the profiler, this one depends on "core:framework" so linking it
# to "core:framework_internal_impl" causes a circular dependency.
cc_library(
name = "profiler_session_impl",
srcs = [
"profiler_session.cc",
"profiler_session.h",
],
visibility = [
"//tensorflow:__pkg__",
"//tensorflow/python:__pkg__",
],
deps = [
":profiler_utils",
"//tensorflow/core:lib",
@ -33,6 +56,7 @@ cc_library(
"//tensorflow/core/profiler/utils:xplane_utils",
],
}),
alwayslink = True,
)
tf_cuda_library(
@ -47,17 +71,6 @@ tf_cuda_library(
alwayslink = True,
)
filegroup(
name = "pywrap_required_hdrs",
srcs = [
"profiler_session.h",
],
visibility = [
"//tensorflow/core:__pkg__",
"//tensorflow/python:__pkg__",
],
)
cc_library(
name = "traceme",
hdrs = ["traceme.h"],

View File

@ -5770,7 +5770,7 @@ filegroup(
"//tensorflow/core/profiler/internal:annotation_stack_impl", # profiler
"//tensorflow/core/profiler/internal:print_model_analysis", # tfprof
"//tensorflow/core/profiler/internal:traceme_recorder_impl", # profiler
"//tensorflow/core/profiler/lib:profiler_session", # tfe
"//tensorflow/core/profiler/lib:profiler_session_impl", # profiler
"//tensorflow/core/util:port", # util_port
"//tensorflow/core/util/tensor_bundle", # checkpoint_reader
"//tensorflow/lite/toco/python:toco_python_api", # toco
@ -7950,8 +7950,6 @@ tf_python_pybind_extension(
"//tensorflow/core/distributed_runtime:pywrap_required_hdrs",
"//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs",
"//tensorflow/core/framework:pywrap_required_hdrs",
"//tensorflow/core/profiler/internal:pywrap_required_hdrs",
"//tensorflow/core/profiler/lib:pywrap_required_hdrs",
"//tensorflow/python/eager:pywrap_required_hdrs",
],
module_name = "_pywrap_tfe",
@ -7971,7 +7969,6 @@ tf_python_pybind_extension(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/platform:platform",
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
] + if_static(
extra_deps = [
"//tensorflow/compiler/jit:flags",

View File

@ -204,10 +204,9 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
":context",
"//tensorflow/python:c_api_util",
"//tensorflow/python:pywrap_tf_session",
"//tensorflow/python:pywrap_tfe",
"//tensorflow/python:util",
"//tensorflow/python/profiler/internal:_pywrap_profiler_session",
],
)

View File

@ -40,11 +40,11 @@ import threading
from tensorflow.python import _pywrap_events_writer
from tensorflow.python import pywrap_tfe
from tensorflow.python.client import pywrap_tf_session
from tensorflow.python.eager import context
from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import errors
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.profiler.internal import _pywrap_profiler_session
from tensorflow.python.util import compat
_profiler = None
@ -75,11 +75,14 @@ def start():
raise ProfilerAlreadyRunningError('Another profiler is running.')
if context.default_execution_mode == context.EAGER_MODE:
context.ensure_initialized()
_profiler = pywrap_tfe.TFE_NewProfiler()
if not pywrap_tfe.TFE_ProfilerIsOk(_profiler):
_profiler = _pywrap_profiler_session.ProfilerSession()
try:
_profiler.start()
except errors.AlreadyExistsError:
logging.warning('Another profiler session is running which is probably '
'created by profiler server. Please avoid using profiler '
'server and profiler APIs at the same time.')
raise ProfilerAlreadyRunningError('Another profiler is running.')
def stop():
@ -100,10 +103,7 @@ def stop():
'Cannot stop profiling. No profiler is running.')
if context.default_execution_mode == context.EAGER_MODE:
context.context().executor.wait()
with c_api_util.tf_buffer() as buffer_:
pywrap_tfe.TFE_ProfilerSerializeToString(_profiler, buffer_)
result = pywrap_tf_session.TF_GetBuffer(buffer_)
pywrap_tfe.TFE_DeleteProfiler(_profiler)
result = _profiler.stop()
_profiler = None
_run_num += 1
return result

View File

@ -119,3 +119,22 @@ tf_python_pybind_extension(
"@pybind11",
],
)
# Wrapper target over //tensorflow/core/profiler/lib:profiler_session, used as
# a dependency of :_pywrap_profiler_session.
# NOTE(review): presumably this forwards only the headers (per the target
# name) — confirm against tf_pybind_cc_library_wrapper.
tf_pybind_cc_library_wrapper(
name = "profiler_session_headers",
deps = ["//tensorflow/core/profiler/lib:profiler_session"],
)
# pybind11 extension module built from profiler_session_wrapper.cc. Only
# //tensorflow/python/eager may depend on it.
tf_python_pybind_extension(
name = "_pywrap_profiler_session",
srcs = ["profiler_session_wrapper.cc"],
features = ["-layering_check"],
module_name = "_pywrap_profiler_session",
visibility = ["//tensorflow/python/eager:__pkg__"],
deps = [
":profiler_session_headers",
"//tensorflow/core:lib",
"//tensorflow/python:pybind11_status",
"@pybind11",
],
)

View File

@ -0,0 +1,57 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "include/pybind11/pybind11.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/profiler_session.h"
#include "tensorflow/python/lib/core/pybind11_status.h"
namespace py = ::pybind11;
namespace {
class ProfilerSessionWrapper {
public:
void Start() {
session_ = tensorflow::ProfilerSession::Create();
tensorflow::MaybeRaiseRegisteredFromStatus(session_->Status());
}
py::bytes Stop() {
tensorflow::string content;
if (session_ != nullptr) {
tensorflow::Status status = session_->SerializeToString(&content);
session_.reset();
tensorflow::MaybeRaiseRegisteredFromStatus(status);
}
// The content is not valid UTF-8, so it must be converted to bytes.
return py::bytes(content);
}
private:
std::unique_ptr<tensorflow::ProfilerSession> session_;
};
} // namespace
// Defines the `_pywrap_profiler_session` Python module, which exposes
// ProfilerSessionWrapper as the `ProfilerSession` class with `start` and
// `stop` methods.
PYBIND11_MODULE(_pywrap_profiler_session, m) {
  py::class_<ProfilerSessionWrapper>(m, "ProfilerSession")
      .def(py::init<>())
      .def("start", &ProfilerSessionWrapper::Start)
      .def("stop", &ProfilerSessionWrapper::Stop);
}

View File

@ -42,7 +42,6 @@ namespace py = pybind11;
PYBIND11_MAKE_OPAQUE(TFE_Executor);
PYBIND11_MAKE_OPAQUE(TFE_ContextOptions);
PYBIND11_MAKE_OPAQUE(TFE_CancellationManager);
PYBIND11_MAKE_OPAQUE(TFE_Profiler);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter0);
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter1);
@ -318,7 +317,6 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
m, "TFE_MonitoringSampler2");
py::class_<TFE_CancellationManager> TFE_CancellationManager_class(
m, "TFE_CancellationManager");
py::class_<TFE_Profiler> TFE_Profiler_class(m, "TFE_Profiler");
py::class_<TF_DeviceList> TF_DeviceList_class(m, "TF_DeviceList");
py::class_<TF_Function> TF_Function_class(m, "TF_Function");
@ -504,17 +502,6 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
py::return_value_policy::reference);
// Profiler Logic
m.def("TFE_NewProfiler", &TFE_NewProfiler,
py::return_value_policy::reference);
m.def("TFE_ProfilerIsOk", &TFE_ProfilerIsOk);
m.def("TFE_DeleteProfiler", &TFE_DeleteProfiler);
m.def("TFE_ProfilerSerializeToString",
[](TFE_Profiler& profiler, TF_Buffer& buf) {
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
TFE_ProfilerSerializeToString(&profiler, &buf, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
});
m.def("TFE_StartProfilerServer", &TFE_StartProfilerServer);
m.def(
"TFE_ProfilerClientStartTracing",

View File

@ -182,9 +182,6 @@ TFE_Py_SetEagerContext
tensorflow::EagerExecutor::~EagerExecutor
tensorflow::EagerContext::WaitForAndCloseRemoteContexts
[profiler_session] # tfe
tensorflow::ProfilerSession::~ProfilerSession
[tf_status_helper] # tfe
tensorflow::Set_TF_Status_from_Status
@ -304,3 +301,9 @@ tensorflow::profiler::TraceMeRecorder::Record
[annotation_stack_impl] # profiler
tensorflow::profiler::AnnotationStack::ThreadAnnotationStack
[profiler_session_impl] # profiler
tensorflow::ProfilerSession::Create
tensorflow::ProfilerSession::SerializeToString
tensorflow::ProfilerSession::Status
tensorflow::ProfilerSession::~ProfilerSession