diff --git a/tensorflow/core/profiler/internal/cpu/python_tracer.cc b/tensorflow/core/profiler/internal/cpu/python_tracer.cc index d684cb8f768..4233c5fdd72 100644 --- a/tensorflow/core/profiler/internal/cpu/python_tracer.cc +++ b/tensorflow/core/profiler/internal/cpu/python_tracer.cc @@ -23,7 +23,6 @@ limitations under the License. #include "tensorflow/core/profiler/profiler_options.pb.h" #include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tensorflow/core/protobuf/config.pb.h" -#include "tensorflow/core/util/env_var.h" #include "tensorflow/python/profiler/internal/python_hooks.h" namespace tensorflow { @@ -34,7 +33,8 @@ namespace { // the events to TraceMeRecorder. class PythonTracer : public ProfilerInterface { public: - explicit PythonTracer() = default; + explicit PythonTracer(const PythonHooksOptions& options) + : options_(options) {} ~PythonTracer() override; // Starts recording TraceMes. @@ -51,6 +51,7 @@ class PythonTracer : public ProfilerInterface { private: bool recording_ = false; + const PythonHooksOptions options_; TF_DISALLOW_COPY_AND_ASSIGN(PythonTracer); }; @@ -66,7 +67,7 @@ Status PythonTracer::Start() { } VLOG(1) << __FUNCTION__; recording_ = true; - PythonHooks::GetSingleton()->Start(); + PythonHooks::GetSingleton()->Start(options_); return Status::OK(); } @@ -75,7 +76,7 @@ Status PythonTracer::Stop() { return errors::Internal("TraceMeRecorder not started"); } VLOG(1) << __FUNCTION__; - PythonHooks::GetSingleton()->Stop(); + PythonHooks::GetSingleton()->Stop(options_); recording_ = false; return Status::OK(); } @@ -105,18 +106,15 @@ Status PythonTracer::CollectData(XSpace* space) { // Not in anonymous namespace for testing purposes. std::unique_ptr<ProfilerInterface> CreatePythonTracer( const ProfileOptions& options) { - if (options.python_tracer_level() == 0) return nullptr; - // This ProfilerInterface rely on TraceMeRecorder to be active. 
- if (options.host_tracer_level() == 0) return nullptr; - return absl::make_unique<PythonTracer>(); + PythonHooksOptions pyhooks_options; + pyhooks_options.enable_trace_python_function = + options.python_tracer_level() && options.host_tracer_level(); + pyhooks_options.enable_python_traceme = options.host_tracer_level() != 0; + return absl::make_unique<PythonTracer>(pyhooks_options); } auto register_python_tracer_factory = [] { - bool enable; - TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_OSS_PYTHON_TRACER", true, &enable)); - if (enable) { - RegisterProfilerFactory(&CreatePythonTracer); - } + RegisterProfilerFactory(&CreatePythonTracer); return 0; }(); diff --git a/tensorflow/python/profiler/BUILD b/tensorflow/python/profiler/BUILD index ffc090a4676..7f9e4512c5a 100644 --- a/tensorflow/python/profiler/BUILD +++ b/tensorflow/python/profiler/BUILD @@ -48,6 +48,8 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ + "//tensorflow/python:errors", + "//tensorflow/python:platform", "//tensorflow/python:util", "//tensorflow/python/profiler/internal:_pywrap_profiler", ], diff --git a/tensorflow/python/profiler/internal/python_hooks.cc b/tensorflow/python/profiler/internal/python_hooks.cc index 7c25f402f74..f367372a0ed 100644 --- a/tensorflow/python/profiler/internal/python_hooks.cc +++ b/tensorflow/python/profiler/internal/python_hooks.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/python/profiler/internal/python_hooks.h" +#include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "tensorflow/core/platform/path.h" @@ -44,16 +45,30 @@ PythonHooks* PythonHooks::GetSingleton() { return singleton; } -void PythonHooks::Start() { - PyGILState_STATE gil_state = PyGILState_Ensure(); - SetProfilerInAllThreads(); - PyGILState_Release(gil_state); +void PythonHooks::Start(const PythonHooksOptions& option) { + if (option.enable_python_traceme || option.enable_trace_python_function) { + PyGILState_STATE gil_state = PyGILState_Ensure(); + if (option.enable_trace_python_function) { + SetProfilerInAllThreads(); + } + if (option.enable_python_traceme) { + EnableTraceMe(true); + } + PyGILState_Release(gil_state); + } } -void PythonHooks::Stop() { - PyGILState_STATE gil_state = PyGILState_Ensure(); - ClearProfilerInAllThreads(); - PyGILState_Release(gil_state); +void PythonHooks::Stop(const PythonHooksOptions& option) { + if (option.enable_python_traceme || option.enable_trace_python_function) { + PyGILState_STATE gil_state = PyGILState_Ensure(); + if (option.enable_trace_python_function) { + ClearProfilerInAllThreads(); + } + if (option.enable_python_traceme) { + EnableTraceMe(false); + } + PyGILState_Release(gil_state); + } } void PythonHooks::Finalize() { tracemes_.clear(); } @@ -180,5 +195,12 @@ void PythonHooks::ClearProfilerInAllThreads() { ThreadingSetProfile(py::none()); } +void PythonHooks::EnableTraceMe(bool enable) { + const char* kModuleName = + "tensorflow.python.profiler.internal._pywrap_traceme"; + auto trace_module = py::module::import(kModuleName); + trace_module.attr("enabled") = enable; +} + } // namespace profiler } // namespace tensorflow diff --git a/tensorflow/python/profiler/internal/python_hooks.h b/tensorflow/python/profiler/internal/python_hooks.h index 8a9ce645ca9..582edf4a93b 100644 --- 
a/tensorflow/python/profiler/internal/python_hooks.h +++ b/tensorflow/python/profiler/internal/python_hooks.h @@ -30,19 +30,26 @@ namespace profiler { namespace py = ::pybind11; +struct PythonHooksOptions { + bool enable_trace_python_function = false; + bool enable_python_traceme = true; +}; + // Singleton for tracing python function calls. class PythonHooks { public: static PythonHooks* GetSingleton(); - void Start(); - void Stop(); + void Start(const PythonHooksOptions& option); + void Stop(const PythonHooksOptions& option); void Finalize(); void ProfileSlow(const py::object& frame, const string& event, const py::object& arg); void ProfileFast(PyFrameObject* frame, int what, PyObject* arg); private: + void EnableTraceMe(bool enable); + void SetProfilerInAllThreads(); void ClearProfilerInAllThreads(); diff --git a/tensorflow/python/profiler/internal/traceme_wrapper.cc b/tensorflow/python/profiler/internal/traceme_wrapper.cc index 32a1f423918..bf8a9ba495a 100644 --- a/tensorflow/python/profiler/internal/traceme_wrapper.cc +++ b/tensorflow/python/profiler/internal/traceme_wrapper.cc @@ -23,8 +23,10 @@ namespace py = ::pybind11; using ::tensorflow::profiler::TraceMeWrapper; PYBIND11_MODULE(_pywrap_traceme, m) { + // This variable will be modified by PythonHooks::Start/Stop(). Such an + // arrangement reduces the number of calls through pybind11. 
+ m.attr("enabled") = py::bool_(false); py::class_<TraceMeWrapper>(m, "TraceMe", py::module_local()) .def(py::init<const py::str&, const py::kwargs&>()) - .def("SetMetadata", &TraceMeWrapper::SetMetadata) - .def_static("IsEnabled", &TraceMeWrapper::IsEnabled); + .def("SetMetadata", &TraceMeWrapper::SetMetadata); }; diff --git a/tensorflow/python/profiler/trace.py b/tensorflow/python/profiler/trace.py index ea4eb060488..1fdba2abe13 100644 --- a/tensorflow/python/profiler/trace.py +++ b/tensorflow/python/profiler/trace.py @@ -72,7 +72,7 @@ class Trace(object): The example above uses the keyword argument "step_num" to specify the training step being traced. """ - if _pywrap_traceme.TraceMe.IsEnabled(): + if _pywrap_traceme.enabled: # Creating _pywrap_traceme.TraceMe starts the clock. self._traceme = _pywrap_traceme.TraceMe(name, **kwargs) else: