Decouple ProfilerSession wrapper from pywrap_tfe
PiperOrigin-RevId: 293882403 Change-Id: I947e32807447460b6fc7ca1b19bf9ca276c3e994
This commit is contained in:
parent
e78c47f2a7
commit
21bb9be2c1
@ -589,6 +589,7 @@ tf_cc_shared_object(
|
||||
"//tensorflow/core:gpu_runtime_impl",
|
||||
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry_impl",
|
||||
"//tensorflow/core:lib_internal_impl",
|
||||
"//tensorflow/core/profiler/lib:profiler_session_impl",
|
||||
"//tensorflow/stream_executor:stream_executor_impl",
|
||||
"//tensorflow:tf_framework_version_script.lds",
|
||||
] + tf_additional_binary_deps(),
|
||||
|
@ -82,8 +82,6 @@ tf_cuda_library(
|
||||
"//tensorflow/core/distributed_runtime:remote_device",
|
||||
"//tensorflow/core/distributed_runtime:server_lib",
|
||||
"//tensorflow/core/distributed_runtime:worker_env",
|
||||
"//tensorflow/core/profiler/lib:profiler_lib",
|
||||
"//tensorflow/core/profiler/lib:profiler_session",
|
||||
"//tensorflow/core:gpu_runtime",
|
||||
],
|
||||
alwayslink = 1,
|
||||
@ -130,7 +128,6 @@ tf_cuda_library(
|
||||
"//tensorflow/core/common_runtime/eager:eager_operation",
|
||||
"//tensorflow/core/common_runtime/eager:kernel_and_device",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
"//tensorflow/core/profiler/lib:profiler_session",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -47,27 +47,6 @@ void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
|
||||
->Handle());
|
||||
}
|
||||
|
||||
TFE_Profiler* TFE_NewProfiler() { return new TFE_Profiler(); }
|
||||
|
||||
bool TFE_ProfilerIsOk(TFE_Profiler* profiler) {
|
||||
return profiler->profiler->Status().ok();
|
||||
}
|
||||
|
||||
void TFE_DeleteProfiler(TFE_Profiler* profiler) { delete profiler; }
|
||||
|
||||
void TFE_ProfilerSerializeToString(TFE_Profiler* profiler, TF_Buffer* buf,
|
||||
TF_Status* status) {
|
||||
string content;
|
||||
status->status = profiler->profiler->SerializeToString(&content);
|
||||
void* data = tensorflow::port::Malloc(content.length());
|
||||
content.copy(static_cast<char*>(data), content.length(), 0);
|
||||
buf->data = data;
|
||||
buf->length = content.length();
|
||||
buf->data_deallocator = [](void* data, size_t length) {
|
||||
tensorflow::port::Free(data);
|
||||
};
|
||||
}
|
||||
|
||||
void TFE_StartProfilerServer(int port) {
|
||||
// Release child thread intentionally. The child thread can be terminated by
|
||||
// terminating the main thread.
|
||||
|
@ -37,23 +37,6 @@ TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
|
||||
TF_CAPI_EXPORT extern void TFE_OpConsumeInput(TFE_Op* op, TFE_TensorHandle* h,
|
||||
TF_Status* status);
|
||||
|
||||
// A profiler which will start profiling when creating the object and will stop
|
||||
// when the object is destroyed. It will profile all operations run under the
|
||||
// given TFE_Context. Multiple instance of it can be created, but at most one
|
||||
// of them will profile for each TFE_Context.
|
||||
// Thread-safety: TFE_Profiler is thread-safe.
|
||||
typedef struct TFE_Profiler TFE_Profiler;
|
||||
|
||||
TF_CAPI_EXPORT extern TFE_Profiler* TFE_NewProfiler();
|
||||
TF_CAPI_EXPORT extern bool TFE_ProfilerIsOk(TFE_Profiler* profiler);
|
||||
TF_CAPI_EXPORT extern void TFE_DeleteProfiler(TFE_Profiler* profiler);
|
||||
|
||||
// The output string is a binary string of tensorflow.tpu.Trace. User can write
|
||||
// the string to file for offline analysis by tensorboard.
|
||||
TF_CAPI_EXPORT extern void TFE_ProfilerSerializeToString(TFE_Profiler* profiler,
|
||||
TF_Buffer* buf,
|
||||
TF_Status* status);
|
||||
|
||||
// Start a profiler grpc server which listens to specified port. It will start
|
||||
// the server on its own thread. It can be shutdown by terminating tensorflow.
|
||||
// It can be used in both Eager mode and graph mode. Creating multiple profiler
|
||||
|
@ -26,7 +26,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
#include "tensorflow/core/protobuf/trace_events.pb.h"
|
||||
|
||||
using tensorflow::string;
|
||||
|
||||
@ -39,88 +38,6 @@ static bool HasSubstr(absl::string_view base, absl::string_view substr) {
|
||||
return ok;
|
||||
}
|
||||
|
||||
void ExecuteWithProfiling(bool async) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetAsync(opts, static_cast<unsigned char>(async));
|
||||
TFE_Context* ctx = TFE_NewContext(opts, status);
|
||||
TFE_Profiler* profiler = TFE_NewProfiler();
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteContextOptions(opts);
|
||||
|
||||
TFE_TensorHandle* m = TestMatrixTensorHandle();
|
||||
TFE_Op* matmul = MatMulOp(ctx, m, m);
|
||||
TFE_TensorHandle* retvals[1] = {nullptr};
|
||||
int num_retvals = 1;
|
||||
|
||||
// Run op on GPU if it is present.
|
||||
string gpu_device_name;
|
||||
if (GetDeviceName(ctx, &gpu_device_name, "GPU")) {
|
||||
TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status);
|
||||
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status);
|
||||
}
|
||||
|
||||
TFE_Execute(matmul, &retvals[0], &num_retvals, status);
|
||||
EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteOp(matmul);
|
||||
TFE_DeleteTensorHandle(m);
|
||||
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
ASSERT_EQ(1, num_retvals);
|
||||
TF_Buffer* profiler_result = TF_NewBuffer();
|
||||
if (async) {
|
||||
TFE_Executor* executor = TFE_ContextGetExecutorForThread(ctx);
|
||||
TFE_ExecutorWaitForAllPendingNodes(executor, status);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TFE_DeleteExecutor(executor);
|
||||
}
|
||||
TFE_ProfilerSerializeToString(profiler, profiler_result, status);
|
||||
TFE_DeleteProfiler(profiler);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
profiler::Trace profile_proto;
|
||||
EXPECT_TRUE(profile_proto.ParseFromString(
|
||||
{reinterpret_cast<const char*>(profiler_result->data),
|
||||
profiler_result->length}));
|
||||
string profile_proto_str = profile_proto.DebugString();
|
||||
#ifndef TENSORFLOW_USE_ROCM
|
||||
// TODO(rocm): enable once GPU profiling is supported in ROCm mode
|
||||
if (!gpu_device_name.empty()) {
|
||||
EXPECT_TRUE(HasSubstr(profile_proto_str, "/device:GPU:0"));
|
||||
}
|
||||
#endif
|
||||
// "/host:CPU" is collected by TraceMe
|
||||
EXPECT_TRUE(HasSubstr(profile_proto_str, "/host:CPU"));
|
||||
EXPECT_TRUE(HasSubstr(profile_proto_str, "MatMul"));
|
||||
TF_DeleteBuffer(profiler_result);
|
||||
|
||||
TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status);
|
||||
TFE_DeleteTensorHandle(retvals[0]);
|
||||
TFE_DeleteContext(ctx);
|
||||
ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
float product[4] = {0};
|
||||
EXPECT_EQ(sizeof(product), TF_TensorByteSize(t));
|
||||
memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t));
|
||||
TF_DeleteTensor(t);
|
||||
EXPECT_EQ(7, product[0]);
|
||||
EXPECT_EQ(10, product[1]);
|
||||
EXPECT_EQ(15, product[2]);
|
||||
EXPECT_EQ(22, product[3]);
|
||||
TF_DeleteStatus(status);
|
||||
}
|
||||
TEST(CAPI, ExecuteWithTracing) { ExecuteWithProfiling(false); }
|
||||
TEST(CAPI, ExecuteWithTracingAsync) { ExecuteWithProfiling(true); }
|
||||
|
||||
TEST(CAPI, MultipleProfilerSession) {
|
||||
TFE_Profiler* profiler1 = TFE_NewProfiler();
|
||||
EXPECT_TRUE(TFE_ProfilerIsOk(profiler1));
|
||||
|
||||
TFE_Profiler* profiler2 = TFE_NewProfiler();
|
||||
EXPECT_FALSE(TFE_ProfilerIsOk(profiler2));
|
||||
|
||||
TFE_DeleteProfiler(profiler1);
|
||||
TFE_DeleteProfiler(profiler2);
|
||||
}
|
||||
|
||||
TEST(CAPI, MonitoringCounter0) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
auto* counter =
|
||||
|
@ -48,7 +48,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/lib/monitoring/sampler.h"
|
||||
#include "tensorflow/core/platform/mutex.h"
|
||||
#include "tensorflow/core/platform/thread_annotations.h"
|
||||
#include "tensorflow/core/profiler/lib/profiler_session.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
struct TFE_ContextOptions {
|
||||
@ -93,12 +92,6 @@ struct TFE_Op {
|
||||
tensorflow::EagerOperation operation;
|
||||
};
|
||||
|
||||
struct TFE_Profiler {
|
||||
explicit TFE_Profiler() { profiler = tensorflow::ProfilerSession::Create(); }
|
||||
|
||||
std::unique_ptr<tensorflow::ProfilerSession> profiler;
|
||||
};
|
||||
|
||||
struct TFE_MonitoringCounterCell {
|
||||
tensorflow::monitoring::CounterCell cell;
|
||||
};
|
||||
|
@ -2293,6 +2293,7 @@ tf_cuda_library(
|
||||
"//tensorflow/core/platform/default/build_config:platformlib",
|
||||
"//tensorflow/core/profiler/internal:annotation_stack_impl",
|
||||
"//tensorflow/core/profiler/internal:traceme_recorder_impl",
|
||||
"//tensorflow/core/profiler/internal:profiler_factory_impl",
|
||||
"//tensorflow/core/profiler/lib:annotated_traceme",
|
||||
"//tensorflow/core/profiler/lib:traceme",
|
||||
"//tensorflow/core/util:einsum_op_util",
|
||||
|
@ -441,14 +441,31 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "profiler_factory",
|
||||
srcs = ["profiler_factory.cc"],
|
||||
hdrs = ["profiler_factory.h"],
|
||||
deps = [
|
||||
":profiler_interface",
|
||||
] + if_static([
|
||||
":profiler_factory_impl",
|
||||
]),
|
||||
)
|
||||
|
||||
# Linked into libtensorflow_framework.so via :framework_internal_impl.
|
||||
cc_library(
|
||||
name = "profiler_factory_impl",
|
||||
srcs = [
|
||||
"profiler_factory.cc",
|
||||
"profiler_factory.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":profiler_interface",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
|
||||
],
|
||||
alwayslink = True,
|
||||
)
|
||||
|
||||
filegroup(
|
||||
@ -532,14 +549,3 @@ tf_cc_test(
|
||||
"//tensorflow/core:testlib",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_required_hdrs",
|
||||
srcs = [
|
||||
"profiler_interface.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
@ -1,3 +1,4 @@
|
||||
load("//tensorflow/core/platform:build_config_root.bzl", "if_static")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
|
||||
|
||||
package(
|
||||
@ -10,9 +11,31 @@ package(
|
||||
|
||||
cc_library(
|
||||
name = "profiler_session",
|
||||
srcs = ["profiler_session.cc"],
|
||||
hdrs = ["profiler_session.h"],
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/core/profiler/internal:profiler_interface",
|
||||
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
] + if_static([
|
||||
":profiler_session_impl",
|
||||
]),
|
||||
)
|
||||
|
||||
# Linked directly into ":tensorflow_framework".
|
||||
# Unlike other "_impl" targets in the profiler, this one depends on "core:framework" so linking it
|
||||
# to "core:framework_internal_impl" causes a circular dependency.
|
||||
cc_library(
|
||||
name = "profiler_session_impl",
|
||||
srcs = [
|
||||
"profiler_session.cc",
|
||||
"profiler_session.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:__pkg__",
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
deps = [
|
||||
":profiler_utils",
|
||||
"//tensorflow/core:lib",
|
||||
@ -33,6 +56,7 @@ cc_library(
|
||||
"//tensorflow/core/profiler/utils:xplane_utils",
|
||||
],
|
||||
}),
|
||||
alwayslink = True,
|
||||
)
|
||||
|
||||
tf_cuda_library(
|
||||
@ -47,17 +71,6 @@ tf_cuda_library(
|
||||
alwayslink = True,
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "pywrap_required_hdrs",
|
||||
srcs = [
|
||||
"profiler_session.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow/core:__pkg__",
|
||||
"//tensorflow/python:__pkg__",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "traceme",
|
||||
hdrs = ["traceme.h"],
|
||||
|
@ -5770,7 +5770,7 @@ filegroup(
|
||||
"//tensorflow/core/profiler/internal:annotation_stack_impl", # profiler
|
||||
"//tensorflow/core/profiler/internal:print_model_analysis", # tfprof
|
||||
"//tensorflow/core/profiler/internal:traceme_recorder_impl", # profiler
|
||||
"//tensorflow/core/profiler/lib:profiler_session", # tfe
|
||||
"//tensorflow/core/profiler/lib:profiler_session_impl", # profiler
|
||||
"//tensorflow/core/util:port", # util_port
|
||||
"//tensorflow/core/util/tensor_bundle", # checkpoint_reader
|
||||
"//tensorflow/lite/toco/python:toco_python_api", # toco
|
||||
@ -7950,8 +7950,6 @@ tf_python_pybind_extension(
|
||||
"//tensorflow/core/distributed_runtime:pywrap_required_hdrs",
|
||||
"//tensorflow/core/distributed_runtime/eager:pywrap_required_hdrs",
|
||||
"//tensorflow/core/framework:pywrap_required_hdrs",
|
||||
"//tensorflow/core/profiler/internal:pywrap_required_hdrs",
|
||||
"//tensorflow/core/profiler/lib:pywrap_required_hdrs",
|
||||
"//tensorflow/python/eager:pywrap_required_hdrs",
|
||||
],
|
||||
module_name = "_pywrap_tfe",
|
||||
@ -7971,7 +7969,6 @@ tf_python_pybind_extension(
|
||||
"//tensorflow/core:lib_internal",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:platform",
|
||||
"//tensorflow/core/profiler/protobuf:xplane_proto_cc",
|
||||
] + if_static(
|
||||
extra_deps = [
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
|
@ -204,10 +204,9 @@ py_library(
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
":context",
|
||||
"//tensorflow/python:c_api_util",
|
||||
"//tensorflow/python:pywrap_tf_session",
|
||||
"//tensorflow/python:pywrap_tfe",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/profiler/internal:_pywrap_profiler_session",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -40,11 +40,11 @@ import threading
|
||||
|
||||
from tensorflow.python import _pywrap_events_writer
|
||||
from tensorflow.python import pywrap_tfe
|
||||
from tensorflow.python.client import pywrap_tf_session
|
||||
from tensorflow.python.eager import context
|
||||
from tensorflow.python.framework import c_api_util
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.profiler.internal import _pywrap_profiler_session
|
||||
from tensorflow.python.util import compat
|
||||
|
||||
_profiler = None
|
||||
@ -75,11 +75,14 @@ def start():
|
||||
raise ProfilerAlreadyRunningError('Another profiler is running.')
|
||||
if context.default_execution_mode == context.EAGER_MODE:
|
||||
context.ensure_initialized()
|
||||
_profiler = pywrap_tfe.TFE_NewProfiler()
|
||||
if not pywrap_tfe.TFE_ProfilerIsOk(_profiler):
|
||||
_profiler = _pywrap_profiler_session.ProfilerSession()
|
||||
try:
|
||||
_profiler.start()
|
||||
except errors.AlreadyExistsError:
|
||||
logging.warning('Another profiler session is running which is probably '
|
||||
'created by profiler server. Please avoid using profiler '
|
||||
'server and profiler APIs at the same time.')
|
||||
raise ProfilerAlreadyRunningError('Another profiler is running.')
|
||||
|
||||
|
||||
def stop():
|
||||
@ -100,10 +103,7 @@ def stop():
|
||||
'Cannot stop profiling. No profiler is running.')
|
||||
if context.default_execution_mode == context.EAGER_MODE:
|
||||
context.context().executor.wait()
|
||||
with c_api_util.tf_buffer() as buffer_:
|
||||
pywrap_tfe.TFE_ProfilerSerializeToString(_profiler, buffer_)
|
||||
result = pywrap_tf_session.TF_GetBuffer(buffer_)
|
||||
pywrap_tfe.TFE_DeleteProfiler(_profiler)
|
||||
result = _profiler.stop()
|
||||
_profiler = None
|
||||
_run_num += 1
|
||||
return result
|
||||
|
@ -119,3 +119,22 @@ tf_python_pybind_extension(
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
tf_pybind_cc_library_wrapper(
|
||||
name = "profiler_session_headers",
|
||||
deps = ["//tensorflow/core/profiler/lib:profiler_session"],
|
||||
)
|
||||
|
||||
tf_python_pybind_extension(
|
||||
name = "_pywrap_profiler_session",
|
||||
srcs = ["profiler_session_wrapper.cc"],
|
||||
features = ["-layering_check"],
|
||||
module_name = "_pywrap_profiler_session",
|
||||
visibility = ["//tensorflow/python/eager:__pkg__"],
|
||||
deps = [
|
||||
":profiler_session_headers",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/python:pybind11_status",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
@ -0,0 +1,57 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/lib/profiler_session.h"
|
||||
#include "tensorflow/python/lib/core/pybind11_status.h"
|
||||
|
||||
namespace py = ::pybind11;
|
||||
|
||||
namespace {
|
||||
|
||||
class ProfilerSessionWrapper {
|
||||
public:
|
||||
void Start() {
|
||||
session_ = tensorflow::ProfilerSession::Create();
|
||||
tensorflow::MaybeRaiseRegisteredFromStatus(session_->Status());
|
||||
}
|
||||
|
||||
py::bytes Stop() {
|
||||
tensorflow::string content;
|
||||
if (session_ != nullptr) {
|
||||
tensorflow::Status status = session_->SerializeToString(&content);
|
||||
session_.reset();
|
||||
tensorflow::MaybeRaiseRegisteredFromStatus(status);
|
||||
}
|
||||
// The content is not valid UTF-8, so it must be converted to bytes.
|
||||
return py::bytes(content);
|
||||
}
|
||||
|
||||
private:
|
||||
std::unique_ptr<tensorflow::ProfilerSession> session_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
PYBIND11_MODULE(_pywrap_profiler_session, m) {
|
||||
py::class_<ProfilerSessionWrapper> profiler_session_class(m,
|
||||
"ProfilerSession");
|
||||
profiler_session_class.def(py::init<>())
|
||||
.def("start", &ProfilerSessionWrapper::Start)
|
||||
.def("stop", &ProfilerSessionWrapper::Stop);
|
||||
};
|
@ -42,7 +42,6 @@ namespace py = pybind11;
|
||||
PYBIND11_MAKE_OPAQUE(TFE_Executor);
|
||||
PYBIND11_MAKE_OPAQUE(TFE_ContextOptions);
|
||||
PYBIND11_MAKE_OPAQUE(TFE_CancellationManager);
|
||||
PYBIND11_MAKE_OPAQUE(TFE_Profiler);
|
||||
|
||||
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter0);
|
||||
PYBIND11_MAKE_OPAQUE(TFE_MonitoringCounter1);
|
||||
@ -318,7 +317,6 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
m, "TFE_MonitoringSampler2");
|
||||
py::class_<TFE_CancellationManager> TFE_CancellationManager_class(
|
||||
m, "TFE_CancellationManager");
|
||||
py::class_<TFE_Profiler> TFE_Profiler_class(m, "TFE_Profiler");
|
||||
|
||||
py::class_<TF_DeviceList> TF_DeviceList_class(m, "TF_DeviceList");
|
||||
py::class_<TF_Function> TF_Function_class(m, "TF_Function");
|
||||
@ -504,17 +502,6 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
|
||||
py::return_value_policy::reference);
|
||||
|
||||
// Profiler Logic
|
||||
m.def("TFE_NewProfiler", &TFE_NewProfiler,
|
||||
py::return_value_policy::reference);
|
||||
m.def("TFE_ProfilerIsOk", &TFE_ProfilerIsOk);
|
||||
m.def("TFE_DeleteProfiler", &TFE_DeleteProfiler);
|
||||
m.def("TFE_ProfilerSerializeToString",
|
||||
[](TFE_Profiler& profiler, TF_Buffer& buf) {
|
||||
tensorflow::Safe_TF_StatusPtr status =
|
||||
tensorflow::make_safe(TF_NewStatus());
|
||||
TFE_ProfilerSerializeToString(&profiler, &buf, status.get());
|
||||
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
|
||||
});
|
||||
m.def("TFE_StartProfilerServer", &TFE_StartProfilerServer);
|
||||
m.def(
|
||||
"TFE_ProfilerClientStartTracing",
|
||||
|
@ -182,9 +182,6 @@ TFE_Py_SetEagerContext
|
||||
tensorflow::EagerExecutor::~EagerExecutor
|
||||
tensorflow::EagerContext::WaitForAndCloseRemoteContexts
|
||||
|
||||
[profiler_session] # tfe
|
||||
tensorflow::ProfilerSession::~ProfilerSession
|
||||
|
||||
[tf_status_helper] # tfe
|
||||
tensorflow::Set_TF_Status_from_Status
|
||||
|
||||
@ -304,3 +301,9 @@ tensorflow::profiler::TraceMeRecorder::Record
|
||||
|
||||
[annotation_stack_impl] # profiler
|
||||
tensorflow::profiler::AnnotationStack::ThreadAnnotationStack
|
||||
|
||||
[profiler_session_impl] # profiler
|
||||
tensorflow::ProfilerSession::Create
|
||||
tensorflow::ProfilerSession::SerializeToString
|
||||
tensorflow::ProfilerSession::Status
|
||||
tensorflow::ProfilerSession::~ProfilerSession
|
||||
|
Loading…
Reference in New Issue
Block a user