Use python tracer to control TraceMe in python launguage. This should have better performance than go through pybind11.

PiperOrigin-RevId: 315547199
Change-Id: I64c4d9f5dce6a23fbeed7fcde10c7a8e839494a4
This commit is contained in:
A. Unique TensorFlower 2020-06-09 13:22:07 -07:00 committed by TensorFlower Gardener
parent b4b83222d4
commit 6eff291a05
6 changed files with 57 additions and 26 deletions
tensorflow

View File

@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/core/profiler/profiler_options.pb.h"
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/python/profiler/internal/python_hooks.h"
namespace tensorflow {
@ -34,7 +33,8 @@ namespace {
// the events to TraceMeRecorder.
class PythonTracer : public ProfilerInterface {
public:
explicit PythonTracer() = default;
explicit PythonTracer(const PythonHooksOptions& options)
: options_(options) {}
~PythonTracer() override;
// Starts recording TraceMes.
@ -51,6 +51,7 @@ class PythonTracer : public ProfilerInterface {
private:
bool recording_ = false;
const PythonHooksOptions options_;
TF_DISALLOW_COPY_AND_ASSIGN(PythonTracer);
};
@ -66,7 +67,7 @@ Status PythonTracer::Start() {
}
VLOG(1) << __FUNCTION__;
recording_ = true;
PythonHooks::GetSingleton()->Start();
PythonHooks::GetSingleton()->Start(options_);
return Status::OK();
}
@ -75,7 +76,7 @@ Status PythonTracer::Stop() {
return errors::Internal("TraceMeRecorder not started");
}
VLOG(1) << __FUNCTION__;
PythonHooks::GetSingleton()->Stop();
PythonHooks::GetSingleton()->Stop(options_);
recording_ = false;
return Status::OK();
}
@ -105,18 +106,15 @@ Status PythonTracer::CollectData(XSpace* space) {
// Not in anonymous namespace for testing purposes.
std::unique_ptr<ProfilerInterface> CreatePythonTracer(
const ProfileOptions& options) {
if (options.python_tracer_level() == 0) return nullptr;
// This ProfilerInterface rely on TraceMeRecorder to be active.
if (options.host_tracer_level() == 0) return nullptr;
return absl::make_unique<PythonTracer>();
PythonHooksOptions pyhooks_options;
pyhooks_options.enable_trace_python_function =
options.python_tracer_level() && options.host_tracer_level();
pyhooks_options.enable_python_traceme = options.host_tracer_level() != 0;
return absl::make_unique<PythonTracer>(pyhooks_options);
}
auto register_python_tracer_factory = [] {
bool enable;
TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_OSS_PYTHON_TRACER", true, &enable));
if (enable) {
RegisterProfilerFactory(&CreatePythonTracer);
}
RegisterProfilerFactory(&CreatePythonTracer);
return 0;
}();

View File

@ -48,6 +48,8 @@ py_library(
srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"],
deps = [
"//tensorflow/python:errors",
"//tensorflow/python:platform",
"//tensorflow/python:util",
"//tensorflow/python/profiler/internal:_pywrap_profiler",
],

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/python/profiler/internal/python_hooks.h"
#include "absl/strings/string_view.h"
#include "absl/strings/strip.h"
#include "tensorflow/core/platform/path.h"
@ -44,16 +45,30 @@ PythonHooks* PythonHooks::GetSingleton() {
return singleton;
}
void PythonHooks::Start() {
PyGILState_STATE gil_state = PyGILState_Ensure();
SetProfilerInAllThreads();
PyGILState_Release(gil_state);
void PythonHooks::Start(const PythonHooksOptions& option) {
if (option.enable_python_traceme || option.enable_trace_python_function) {
PyGILState_STATE gil_state = PyGILState_Ensure();
if (option.enable_trace_python_function) {
SetProfilerInAllThreads();
}
if (option.enable_python_traceme) {
EnableTraceMe(true);
}
PyGILState_Release(gil_state);
}
}
void PythonHooks::Stop() {
PyGILState_STATE gil_state = PyGILState_Ensure();
ClearProfilerInAllThreads();
PyGILState_Release(gil_state);
void PythonHooks::Stop(const PythonHooksOptions& option) {
if (option.enable_python_traceme || option.enable_trace_python_function) {
PyGILState_STATE gil_state = PyGILState_Ensure();
if (option.enable_trace_python_function) {
ClearProfilerInAllThreads();
}
if (option.enable_python_traceme) {
EnableTraceMe(false);
}
PyGILState_Release(gil_state);
}
}
void PythonHooks::Finalize() { tracemes_.clear(); }
@ -180,5 +195,12 @@ void PythonHooks::ClearProfilerInAllThreads() {
ThreadingSetProfile(py::none());
}
void PythonHooks::EnableTraceMe(bool enable) {
const char* kModuleName =
"tensorflow.python.profiler.internal._pywrap_traceme";
auto trace_module = py::module::import(kModuleName);
trace_module.attr("enabled") = enable;
}
} // namespace profiler
} // namespace tensorflow

View File

@ -30,19 +30,26 @@ namespace profiler {
namespace py = ::pybind11;
struct PythonHooksOptions {
bool enable_trace_python_function = false;
bool enable_python_traceme = true;
};
// Singleton for tracing python function calls.
class PythonHooks {
public:
static PythonHooks* GetSingleton();
void Start();
void Stop();
void Start(const PythonHooksOptions& option);
void Stop(const PythonHooksOptions& option);
void Finalize();
void ProfileSlow(const py::object& frame, const string& event,
const py::object& arg);
void ProfileFast(PyFrameObject* frame, int what, PyObject* arg);
private:
void EnableTraceMe(bool enable);
void SetProfilerInAllThreads();
void ClearProfilerInAllThreads();

View File

@ -23,8 +23,10 @@ namespace py = ::pybind11;
using ::tensorflow::profiler::TraceMeWrapper;
PYBIND11_MODULE(_pywrap_traceme, m) {
// This variable will be modified by PythonHooks::Start/Stop(). such
// arrangement will reduce the number of calls through pybind11.
m.attr("enabled") = py::bool_(false);
py::class_<TraceMeWrapper>(m, "TraceMe", py::module_local())
.def(py::init<const py::str&, const py::kwargs&>())
.def("SetMetadata", &TraceMeWrapper::SetMetadata)
.def_static("IsEnabled", &TraceMeWrapper::IsEnabled);
.def("SetMetadata", &TraceMeWrapper::SetMetadata);
};

View File

@ -72,7 +72,7 @@ class Trace(object):
The example above uses the keyword argument "step_num" to specify the
training step being traced.
"""
if _pywrap_traceme.TraceMe.IsEnabled():
if _pywrap_traceme.enabled:
# Creating _pywrap_traceme.TraceMe starts the clock.
self._traceme = _pywrap_traceme.TraceMe(name, **kwargs)
else: