Use the Python tracer to control TraceMe in the Python language. This should perform better than going through pybind11.
PiperOrigin-RevId: 315547199 Change-Id: I64c4d9f5dce6a23fbeed7fcde10c7a8e839494a4
This commit is contained in:
parent
b4b83222d4
commit
6eff291a05
tensorflow
core/profiler/internal/cpu
python/profiler
@ -23,7 +23,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/profiler/profiler_options.pb.h"
|
||||
#include "tensorflow/core/profiler/protobuf/xplane.pb.h"
|
||||
#include "tensorflow/core/protobuf/config.pb.h"
|
||||
#include "tensorflow/core/util/env_var.h"
|
||||
#include "tensorflow/python/profiler/internal/python_hooks.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -34,7 +33,8 @@ namespace {
|
||||
// the events to TraceMeRecorder.
|
||||
class PythonTracer : public ProfilerInterface {
|
||||
public:
|
||||
explicit PythonTracer() = default;
|
||||
explicit PythonTracer(const PythonHooksOptions& options)
|
||||
: options_(options) {}
|
||||
~PythonTracer() override;
|
||||
|
||||
// Starts recording TraceMes.
|
||||
@ -51,6 +51,7 @@ class PythonTracer : public ProfilerInterface {
|
||||
|
||||
private:
|
||||
bool recording_ = false;
|
||||
const PythonHooksOptions options_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(PythonTracer);
|
||||
};
|
||||
@ -66,7 +67,7 @@ Status PythonTracer::Start() {
|
||||
}
|
||||
VLOG(1) << __FUNCTION__;
|
||||
recording_ = true;
|
||||
PythonHooks::GetSingleton()->Start();
|
||||
PythonHooks::GetSingleton()->Start(options_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -75,7 +76,7 @@ Status PythonTracer::Stop() {
|
||||
return errors::Internal("TraceMeRecorder not started");
|
||||
}
|
||||
VLOG(1) << __FUNCTION__;
|
||||
PythonHooks::GetSingleton()->Stop();
|
||||
PythonHooks::GetSingleton()->Stop(options_);
|
||||
recording_ = false;
|
||||
return Status::OK();
|
||||
}
|
||||
@ -105,18 +106,15 @@ Status PythonTracer::CollectData(XSpace* space) {
|
||||
// Not in anonymous namespace for testing purposes.
|
||||
std::unique_ptr<ProfilerInterface> CreatePythonTracer(
|
||||
const ProfileOptions& options) {
|
||||
if (options.python_tracer_level() == 0) return nullptr;
|
||||
// This ProfilerInterface rely on TraceMeRecorder to be active.
|
||||
if (options.host_tracer_level() == 0) return nullptr;
|
||||
return absl::make_unique<PythonTracer>();
|
||||
PythonHooksOptions pyhooks_options;
|
||||
pyhooks_options.enable_trace_python_function =
|
||||
options.python_tracer_level() && options.host_tracer_level();
|
||||
pyhooks_options.enable_python_traceme = options.host_tracer_level() != 0;
|
||||
return absl::make_unique<PythonTracer>(pyhooks_options);
|
||||
}
|
||||
|
||||
auto register_python_tracer_factory = [] {
|
||||
bool enable;
|
||||
TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_OSS_PYTHON_TRACER", true, &enable));
|
||||
if (enable) {
|
||||
RegisterProfilerFactory(&CreatePythonTracer);
|
||||
}
|
||||
RegisterProfilerFactory(&CreatePythonTracer);
|
||||
return 0;
|
||||
}();
|
||||
|
||||
|
@ -48,6 +48,8 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/python:errors",
|
||||
"//tensorflow/python:platform",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/profiler/internal:_pywrap_profiler",
|
||||
],
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/python/profiler/internal/python_hooks.h"
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/strings/strip.h"
|
||||
#include "tensorflow/core/platform/path.h"
|
||||
|
||||
@ -44,16 +45,30 @@ PythonHooks* PythonHooks::GetSingleton() {
|
||||
return singleton;
|
||||
}
|
||||
|
||||
void PythonHooks::Start() {
|
||||
PyGILState_STATE gil_state = PyGILState_Ensure();
|
||||
SetProfilerInAllThreads();
|
||||
PyGILState_Release(gil_state);
|
||||
void PythonHooks::Start(const PythonHooksOptions& option) {
|
||||
if (option.enable_python_traceme || option.enable_trace_python_function) {
|
||||
PyGILState_STATE gil_state = PyGILState_Ensure();
|
||||
if (option.enable_trace_python_function) {
|
||||
SetProfilerInAllThreads();
|
||||
}
|
||||
if (option.enable_python_traceme) {
|
||||
EnableTraceMe(true);
|
||||
}
|
||||
PyGILState_Release(gil_state);
|
||||
}
|
||||
}
|
||||
|
||||
void PythonHooks::Stop() {
|
||||
PyGILState_STATE gil_state = PyGILState_Ensure();
|
||||
ClearProfilerInAllThreads();
|
||||
PyGILState_Release(gil_state);
|
||||
void PythonHooks::Stop(const PythonHooksOptions& option) {
|
||||
if (option.enable_python_traceme || option.enable_trace_python_function) {
|
||||
PyGILState_STATE gil_state = PyGILState_Ensure();
|
||||
if (option.enable_trace_python_function) {
|
||||
ClearProfilerInAllThreads();
|
||||
}
|
||||
if (option.enable_python_traceme) {
|
||||
EnableTraceMe(false);
|
||||
}
|
||||
PyGILState_Release(gil_state);
|
||||
}
|
||||
}
|
||||
|
||||
void PythonHooks::Finalize() { tracemes_.clear(); }
|
||||
@ -180,5 +195,12 @@ void PythonHooks::ClearProfilerInAllThreads() {
|
||||
ThreadingSetProfile(py::none());
|
||||
}
|
||||
|
||||
void PythonHooks::EnableTraceMe(bool enable) {
|
||||
const char* kModuleName =
|
||||
"tensorflow.python.profiler.internal._pywrap_traceme";
|
||||
auto trace_module = py::module::import(kModuleName);
|
||||
trace_module.attr("enabled") = enable;
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace tensorflow
|
||||
|
@ -30,19 +30,26 @@ namespace profiler {
|
||||
|
||||
namespace py = ::pybind11;
|
||||
|
||||
struct PythonHooksOptions {
|
||||
bool enable_trace_python_function = false;
|
||||
bool enable_python_traceme = true;
|
||||
};
|
||||
|
||||
// Singleton for tracing python function calls.
|
||||
class PythonHooks {
|
||||
public:
|
||||
static PythonHooks* GetSingleton();
|
||||
|
||||
void Start();
|
||||
void Stop();
|
||||
void Start(const PythonHooksOptions& option);
|
||||
void Stop(const PythonHooksOptions& option);
|
||||
void Finalize();
|
||||
void ProfileSlow(const py::object& frame, const string& event,
|
||||
const py::object& arg);
|
||||
void ProfileFast(PyFrameObject* frame, int what, PyObject* arg);
|
||||
|
||||
private:
|
||||
void EnableTraceMe(bool enable);
|
||||
|
||||
void SetProfilerInAllThreads();
|
||||
void ClearProfilerInAllThreads();
|
||||
|
||||
|
@ -23,8 +23,10 @@ namespace py = ::pybind11;
|
||||
using ::tensorflow::profiler::TraceMeWrapper;
|
||||
|
||||
PYBIND11_MODULE(_pywrap_traceme, m) {
|
||||
// This variable will be modified by PythonHooks::Start/Stop(). such
|
||||
// arrangement will reduce the number of calls through pybind11.
|
||||
m.attr("enabled") = py::bool_(false);
|
||||
py::class_<TraceMeWrapper>(m, "TraceMe", py::module_local())
|
||||
.def(py::init<const py::str&, const py::kwargs&>())
|
||||
.def("SetMetadata", &TraceMeWrapper::SetMetadata)
|
||||
.def_static("IsEnabled", &TraceMeWrapper::IsEnabled);
|
||||
.def("SetMetadata", &TraceMeWrapper::SetMetadata);
|
||||
};
|
||||
|
@ -72,7 +72,7 @@ class Trace(object):
|
||||
The example above uses the keyword argument "step_num" to specify the
|
||||
training step being traced.
|
||||
"""
|
||||
if _pywrap_traceme.TraceMe.IsEnabled():
|
||||
if _pywrap_traceme.enabled:
|
||||
# Creating _pywrap_traceme.TraceMe starts the clock.
|
||||
self._traceme = _pywrap_traceme.TraceMe(name, **kwargs)
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user