From 6eff291a056d06f8c159485f81228f685b6f719c Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 9 Jun 2020 13:22:07 -0700
Subject: [PATCH] Use python tracer to control TraceMe in python language.
 This should have better performance than going through pybind11.

PiperOrigin-RevId: 315547199
Change-Id: I64c4d9f5dce6a23fbeed7fcde10c7a8e839494a4
---
 .../profiler/internal/cpu/python_tracer.cc    | 24 ++++++------
 tensorflow/python/profiler/BUILD              |  2 +
 .../python/profiler/internal/python_hooks.cc  | 38 +++++++++++++++----
 .../python/profiler/internal/python_hooks.h   | 11 +++++-
 .../profiler/internal/traceme_wrapper.cc      |  6 ++-
 tensorflow/python/profiler/trace.py           |  2 +-
 6 files changed, 57 insertions(+), 26 deletions(-)

diff --git a/tensorflow/core/profiler/internal/cpu/python_tracer.cc b/tensorflow/core/profiler/internal/cpu/python_tracer.cc
index d684cb8f768..4233c5fdd72 100644
--- a/tensorflow/core/profiler/internal/cpu/python_tracer.cc
+++ b/tensorflow/core/profiler/internal/cpu/python_tracer.cc
@@ -23,7 +23,6 @@ limitations under the License.
 #include "tensorflow/core/profiler/profiler_options.pb.h"
 #include "tensorflow/core/profiler/protobuf/xplane.pb.h"
 #include "tensorflow/core/protobuf/config.pb.h"
-#include "tensorflow/core/util/env_var.h"
 #include "tensorflow/python/profiler/internal/python_hooks.h"
 
 namespace tensorflow {
@@ -34,7 +33,8 @@ namespace {
 // the events to TraceMeRecorder.
 class PythonTracer : public ProfilerInterface {
  public:
-  explicit PythonTracer() = default;
+  explicit PythonTracer(const PythonHooksOptions& options)
+      : options_(options) {}
   ~PythonTracer() override;
 
   // Starts recording TraceMes.
@@ -51,6 +51,7 @@ class PythonTracer : public ProfilerInterface {
 
  private:
   bool recording_ = false;
+  const PythonHooksOptions options_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(PythonTracer);
 };
@@ -66,7 +67,7 @@ Status PythonTracer::Start() {
   }
   VLOG(1) << __FUNCTION__;
   recording_ = true;
-  PythonHooks::GetSingleton()->Start();
+  PythonHooks::GetSingleton()->Start(options_);
   return Status::OK();
 }
 
@@ -75,7 +76,7 @@ Status PythonTracer::Stop() {
     return errors::Internal("TraceMeRecorder not started");
   }
   VLOG(1) << __FUNCTION__;
-  PythonHooks::GetSingleton()->Stop();
+  PythonHooks::GetSingleton()->Stop(options_);
   recording_ = false;
   return Status::OK();
 }
@@ -105,18 +106,15 @@ Status PythonTracer::CollectData(XSpace* space) {
 // Not in anonymous namespace for testing purposes.
 std::unique_ptr<ProfilerInterface> CreatePythonTracer(
     const ProfileOptions& options) {
-  if (options.python_tracer_level() == 0) return nullptr;
-  // This ProfilerInterface rely on TraceMeRecorder to be active.
-  if (options.host_tracer_level() == 0) return nullptr;
-  return absl::make_unique<PythonTracer>();
+  PythonHooksOptions pyhooks_options;
+  pyhooks_options.enable_trace_python_function =
+      options.python_tracer_level() && options.host_tracer_level();
+  pyhooks_options.enable_python_traceme = options.host_tracer_level() != 0;
+  return absl::make_unique<PythonTracer>(pyhooks_options);
 }
 
 auto register_python_tracer_factory = [] {
-  bool enable;
-  TF_CHECK_OK(ReadBoolFromEnvVar("TF_ENABLE_OSS_PYTHON_TRACER", true, &enable));
-  if (enable) {
-    RegisterProfilerFactory(&CreatePythonTracer);
-  }
+  RegisterProfilerFactory(&CreatePythonTracer);
   return 0;
 }();
 
diff --git a/tensorflow/python/profiler/BUILD b/tensorflow/python/profiler/BUILD
index ffc090a4676..7f9e4512c5a 100644
--- a/tensorflow/python/profiler/BUILD
+++ b/tensorflow/python/profiler/BUILD
@@ -48,6 +48,8 @@ py_library(
     srcs_version = "PY2AND3",
     visibility = ["//tensorflow:internal"],
     deps = [
+        "//tensorflow/python:errors",
+        "//tensorflow/python:platform",
         "//tensorflow/python:util",
         "//tensorflow/python/profiler/internal:_pywrap_profiler",
     ],
diff --git a/tensorflow/python/profiler/internal/python_hooks.cc b/tensorflow/python/profiler/internal/python_hooks.cc
index 7c25f402f74..f367372a0ed 100644
--- a/tensorflow/python/profiler/internal/python_hooks.cc
+++ b/tensorflow/python/profiler/internal/python_hooks.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/python/profiler/internal/python_hooks.h"
 
+#include "absl/strings/string_view.h"
 #include "absl/strings/strip.h"
 #include "tensorflow/core/platform/path.h"
 
@@ -44,16 +45,30 @@ PythonHooks* PythonHooks::GetSingleton() {
   return singleton;
 }
 
-void PythonHooks::Start() {
-  PyGILState_STATE gil_state = PyGILState_Ensure();
-  SetProfilerInAllThreads();
-  PyGILState_Release(gil_state);
+void PythonHooks::Start(const PythonHooksOptions& option) {
+  if (option.enable_python_traceme || option.enable_trace_python_function) {
+    PyGILState_STATE gil_state = PyGILState_Ensure();
+    if (option.enable_trace_python_function) {
+      SetProfilerInAllThreads();
+    }
+    if (option.enable_python_traceme) {
+      EnableTraceMe(true);
+    }
+    PyGILState_Release(gil_state);
+  }
 }
 
-void PythonHooks::Stop() {
-  PyGILState_STATE gil_state = PyGILState_Ensure();
-  ClearProfilerInAllThreads();
-  PyGILState_Release(gil_state);
+void PythonHooks::Stop(const PythonHooksOptions& option) {
+  if (option.enable_python_traceme || option.enable_trace_python_function) {
+    PyGILState_STATE gil_state = PyGILState_Ensure();
+    if (option.enable_trace_python_function) {
+      ClearProfilerInAllThreads();
+    }
+    if (option.enable_python_traceme) {
+      EnableTraceMe(false);
+    }
+    PyGILState_Release(gil_state);
+  }
 }
 
 void PythonHooks::Finalize() { tracemes_.clear(); }
@@ -180,5 +195,12 @@ void PythonHooks::ClearProfilerInAllThreads() {
   ThreadingSetProfile(py::none());
 }
 
+void PythonHooks::EnableTraceMe(bool enable) {
+  const char* kModuleName =
+      "tensorflow.python.profiler.internal._pywrap_traceme";
+  auto trace_module = py::module::import(kModuleName);
+  trace_module.attr("enabled") = enable;
+}
+
 }  // namespace profiler
 }  // namespace tensorflow
diff --git a/tensorflow/python/profiler/internal/python_hooks.h b/tensorflow/python/profiler/internal/python_hooks.h
index 8a9ce645ca9..582edf4a93b 100644
--- a/tensorflow/python/profiler/internal/python_hooks.h
+++ b/tensorflow/python/profiler/internal/python_hooks.h
@@ -30,19 +30,26 @@ namespace profiler {
 
 namespace py = ::pybind11;
 
+struct PythonHooksOptions {
+  bool enable_trace_python_function = false;
+  bool enable_python_traceme = true;
+};
+
 // Singleton for tracing python function calls.
 class PythonHooks {
  public:
   static PythonHooks* GetSingleton();
 
-  void Start();
-  void Stop();
+  void Start(const PythonHooksOptions& option);
+  void Stop(const PythonHooksOptions& option);
   void Finalize();
   void ProfileSlow(const py::object& frame, const string& event,
                    const py::object& arg);
   void ProfileFast(PyFrameObject* frame, int what, PyObject* arg);
 
  private:
+  void EnableTraceMe(bool enable);
+
   void SetProfilerInAllThreads();
   void ClearProfilerInAllThreads();
 
diff --git a/tensorflow/python/profiler/internal/traceme_wrapper.cc b/tensorflow/python/profiler/internal/traceme_wrapper.cc
index 32a1f423918..bf8a9ba495a 100644
--- a/tensorflow/python/profiler/internal/traceme_wrapper.cc
+++ b/tensorflow/python/profiler/internal/traceme_wrapper.cc
@@ -23,8 +23,10 @@ namespace py = ::pybind11;
 using ::tensorflow::profiler::TraceMeWrapper;
 
 PYBIND11_MODULE(_pywrap_traceme, m) {
+  // This variable will be modified by PythonHooks::Start/Stop(). Such an
+  // arrangement reduces the number of calls through pybind11.
+  m.attr("enabled") = py::bool_(false);
   py::class_<TraceMeWrapper>(m, "TraceMe", py::module_local())
       .def(py::init<const py::str&, const py::kwargs&>())
-      .def("SetMetadata", &TraceMeWrapper::SetMetadata)
-      .def_static("IsEnabled", &TraceMeWrapper::IsEnabled);
+      .def("SetMetadata", &TraceMeWrapper::SetMetadata);
 };
diff --git a/tensorflow/python/profiler/trace.py b/tensorflow/python/profiler/trace.py
index ea4eb060488..1fdba2abe13 100644
--- a/tensorflow/python/profiler/trace.py
+++ b/tensorflow/python/profiler/trace.py
@@ -72,7 +72,7 @@ class Trace(object):
       The example above uses the keyword argument "step_num" to specify the
       training step being traced.
     """
-    if _pywrap_traceme.TraceMe.IsEnabled():
+    if _pywrap_traceme.enabled:
       # Creating _pywrap_traceme.TraceMe starts the clock.
       self._traceme = _pywrap_traceme.TraceMe(name, **kwargs)
     else: