Speedup python TraceMe

PiperOrigin-RevId: 313271773
Change-Id: I6358253077190f43059fed416399852bab29dae6
This commit is contained in:
Jose Baiocchi 2020-05-26 14:48:04 -07:00 committed by TensorFlower Gardener
parent 13f50c2b7a
commit aff44e4ca1
6 changed files with 51 additions and 61 deletions

View File

@ -261,7 +261,7 @@ pybind_extension(
"//tensorflow/core/profiler/lib:profiler_backends",
"//tensorflow/core/profiler/lib:profiler_session",
"//tensorflow/core/profiler/rpc:profiler_server",
"//tensorflow/python/profiler/internal:traceme_context_manager",
"//tensorflow/python/profiler/internal:traceme_wrapper",
"//tensorflow/stream_executor:device_memory_allocator",
"//tensorflow/stream_executor:platform",
] + select({

View File

@ -64,7 +64,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/profiler/rpc/profiler_server.h"
#include "tensorflow/python/profiler/internal/traceme_context_manager.h"
#include "tensorflow/python/profiler/internal/traceme_wrapper.h"
#include "tensorflow/stream_executor/platform.h"
namespace xla {
@ -72,7 +72,7 @@ namespace {
namespace py = pybind11;
using ::tensorflow::profiler::TraceMeContextManager;
using ::tensorflow::profiler::TraceMeWrapper;
struct Uniquer {
absl::Mutex mu;
@ -637,23 +637,19 @@ void BuildProfilerSubmodule(py::module* m) {
},
py::arg("port"));
py::class_<TraceMeContextManager> traceme_class(profiler, "TraceMe",
py::class_<TraceMeWrapper> traceme_class(profiler, "TraceMe",
py::module_local());
traceme_class.def(py::init<py::str, py::kwargs>())
.def("__enter__",
[](py::object self) -> py::object {
py::cast<TraceMeContextManager*>(self)->Enter();
return self;
})
.def("__enter__", [](py::object self) -> py::object { return self; })
.def("__exit__",
[](py::object self, const py::object& ex_type,
const py::object& ex_value,
const py::object& traceback) -> py::object {
py::cast<TraceMeContextManager*>(self)->Exit();
py::cast<TraceMeWrapper*>(self)->Stop();
return py::none();
})
.def("set_metadata", &TraceMeContextManager::SetMetadata)
.def_static("is_enabled", &TraceMeContextManager::IsEnabled);
.def("set_metadata", &TraceMeWrapper::SetMetadata)
.def_static("is_enabled", &TraceMeWrapper::IsEnabled);
}
} // namespace

View File

@ -86,14 +86,14 @@ tf_python_pybind_extension(
"//tensorflow/python/profiler:__subpackages__",
],
deps = [
":traceme_context_manager",
":traceme_wrapper",
"@pybind11",
],
)
cc_library(
name = "traceme_context_manager",
hdrs = ["traceme_context_manager.h"],
name = "traceme_wrapper",
hdrs = ["traceme_wrapper.h"],
features = ["-layering_check"],
visibility = [
"//tensorflow/compiler/xla/python:__pkg__",

View File

@ -13,18 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/python/profiler/internal/traceme_wrapper.h"
#include "pybind11/attr.h"
#include "pybind11/pybind11.h"
#include "tensorflow/python/profiler/internal/traceme_context_manager.h"
using ::tensorflow::profiler::TraceMeContextManager;
namespace py = ::pybind11;
using ::tensorflow::profiler::TraceMeWrapper;
PYBIND11_MODULE(_pywrap_traceme, m) {
py::class_<TraceMeContextManager> traceme_class(m, "TraceMe",
py::module_local());
traceme_class.def(py::init<py::str, py::kwargs>())
.def("Enter", &TraceMeContextManager::Enter)
.def("Exit", &TraceMeContextManager::Exit)
.def("SetMetadata", &TraceMeContextManager::SetMetadata)
.def_static("IsEnabled", &TraceMeContextManager::IsEnabled);
py::class_<TraceMeWrapper>(m, "TraceMe", py::module_local())
.def(py::init<const py::str&, const py::kwargs&>())
.def("SetMetadata", &TraceMeWrapper::SetMetadata)
.def_static("IsEnabled", &TraceMeWrapper::IsEnabled);
};

View File

@ -12,46 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_PYTHON_PROFILER_INTERNAL_TRACEME_CONTEXT_MANAGER_
#define TENSORFLOW_PYTHON_PROFILER_INTERNAL_TRACEME_CONTEXT_MANAGER_
#ifndef TENSORFLOW_PYTHON_PROFILER_INTERNAL_TRACEME_WRAPPER_
#define TENSORFLOW_PYTHON_PROFILER_INTERNAL_TRACEME_WRAPPER_
#include <string>
#include <utility>
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "pybind11/pytypes.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h"
namespace py = pybind11;
namespace tensorflow {
namespace profiler {
// Helper to implement TraceMe as a context manager in Python.
class TraceMeContextManager {
// Wraps TraceMe with an interface that takes python types.
class TraceMeWrapper {
public:
explicit TraceMeContextManager(py::str name, py::kwargs kwargs)
: name_(std::move(name)), kwargs_(std::move(kwargs)) {}
// pybind11::str and pybind11::kwargs are taken by const reference to avoid
// python reference-counting overhead.
TraceMeWrapper(const pybind11::str& name, const pybind11::kwargs& kwargs)
: traceme_([&]() {
std::string name_and_metadata(name);
if (!kwargs.empty()) {
AppendMetadata(&name_and_metadata, kwargs);
}
return name_and_metadata;
}) {}
void Enter() {
if (IsEnabled()) {
traceme_.emplace([this]() {
std::string name(name_);
if (!kwargs_.empty()) {
AppendMetadata(&name, kwargs_);
}
return name;
});
}
}
void SetMetadata(py::kwargs kwargs) {
if (TF_PREDICT_TRUE(traceme_.has_value() && !kwargs.empty())) {
traceme_->AppendMetadata([&kwargs]() {
// pybind11::kwargs is taken by const reference to avoid python
// reference-counting overhead.
void SetMetadata(const pybind11::kwargs& kwargs) {
if (TF_PREDICT_FALSE(!kwargs.empty())) {
traceme_.AppendMetadata([&]() {
std::string metadata;
AppendMetadata(&metadata, kwargs);
return metadata;
@ -59,28 +54,27 @@ class TraceMeContextManager {
}
}
void Exit() { traceme_.reset(); }
void Stop() { traceme_.Stop(); }
static bool IsEnabled() { return tensorflow::profiler::TraceMe::Active(); }
private:
// Converts kwargs to strings and appends them to name encoded as TraceMe
// metadata.
static void AppendMetadata(std::string* name, const py::kwargs& kwargs) {
static void AppendMetadata(std::string* name,
const pybind11::kwargs& kwargs) {
name->push_back('#');
for (const auto& kv : kwargs) {
absl::StrAppend(name, std::string(py::str(kv.first)), "=",
std::string(py::str(kv.second)), ",");
absl::StrAppend(name, std::string(pybind11::str(kv.first)), "=",
std::string(pybind11::str(kv.second)), ",");
}
name->back() = '#';
}
py::str name_;
py::kwargs kwargs_;
absl::optional<tensorflow::profiler::TraceMe> traceme_;
tensorflow::profiler::TraceMe traceme_;
};
} // namespace profiler
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_PROFILER_INTERNAL_TRACEME_CONTEXT_MANAGER_
#endif // TENSORFLOW_PYTHON_PROFILER_INTERNAL_TRACEME_WRAPPER_

View File

@ -73,13 +73,13 @@ class Trace(object):
training step being traced.
"""
if _pywrap_traceme.TraceMe.IsEnabled():
# Creating _pywrap_traceme.TraceMe starts the clock.
self._traceme = _pywrap_traceme.TraceMe(name, **kwargs)
else:
self._traceme = None
def __enter__(self):
if self._traceme:
self._traceme.Enter()
# Starting the TraceMe clock here would require an extra Python->C++ call.
return self
def set_metadata(self, **kwargs):
@ -117,5 +117,5 @@ class Trace(object):
self._traceme.SetMetadata(**kwargs)
def __exit__(self, exc_type, exc_val, exc_tb):
if self._traceme:
self._traceme.Exit()
# Deallocating _pywrap_traceme.TraceMe stops the clock.
self._traceme = None