Speed up Python TraceMe
PiperOrigin-RevId: 313271773 Change-Id: I6358253077190f43059fed416399852bab29dae6
This commit is contained in:
parent
13f50c2b7a
commit
aff44e4ca1
@ -261,7 +261,7 @@ pybind_extension(
|
||||
"//tensorflow/core/profiler/lib:profiler_backends",
|
||||
"//tensorflow/core/profiler/lib:profiler_session",
|
||||
"//tensorflow/core/profiler/rpc:profiler_server",
|
||||
"//tensorflow/python/profiler/internal:traceme_context_manager",
|
||||
"//tensorflow/python/profiler/internal:traceme_wrapper",
|
||||
"//tensorflow/stream_executor:device_memory_allocator",
|
||||
"//tensorflow/stream_executor:platform",
|
||||
] + select({
|
||||
|
@ -64,7 +64,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/profiler/rpc/profiler_server.h"
|
||||
#include "tensorflow/python/profiler/internal/traceme_context_manager.h"
|
||||
#include "tensorflow/python/profiler/internal/traceme_wrapper.h"
|
||||
#include "tensorflow/stream_executor/platform.h"
|
||||
|
||||
namespace xla {
|
||||
@ -72,7 +72,7 @@ namespace {
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
using ::tensorflow::profiler::TraceMeContextManager;
|
||||
using ::tensorflow::profiler::TraceMeWrapper;
|
||||
|
||||
struct Uniquer {
|
||||
absl::Mutex mu;
|
||||
@ -637,23 +637,19 @@ void BuildProfilerSubmodule(py::module* m) {
|
||||
},
|
||||
py::arg("port"));
|
||||
|
||||
py::class_<TraceMeContextManager> traceme_class(profiler, "TraceMe",
|
||||
py::module_local());
|
||||
py::class_<TraceMeWrapper> traceme_class(profiler, "TraceMe",
|
||||
py::module_local());
|
||||
traceme_class.def(py::init<py::str, py::kwargs>())
|
||||
.def("__enter__",
|
||||
[](py::object self) -> py::object {
|
||||
py::cast<TraceMeContextManager*>(self)->Enter();
|
||||
return self;
|
||||
})
|
||||
.def("__enter__", [](py::object self) -> py::object { return self; })
|
||||
.def("__exit__",
|
||||
[](py::object self, const py::object& ex_type,
|
||||
const py::object& ex_value,
|
||||
const py::object& traceback) -> py::object {
|
||||
py::cast<TraceMeContextManager*>(self)->Exit();
|
||||
py::cast<TraceMeWrapper*>(self)->Stop();
|
||||
return py::none();
|
||||
})
|
||||
.def("set_metadata", &TraceMeContextManager::SetMetadata)
|
||||
.def_static("is_enabled", &TraceMeContextManager::IsEnabled);
|
||||
.def("set_metadata", &TraceMeWrapper::SetMetadata)
|
||||
.def_static("is_enabled", &TraceMeWrapper::IsEnabled);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
@ -86,14 +86,14 @@ tf_python_pybind_extension(
|
||||
"//tensorflow/python/profiler:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
":traceme_context_manager",
|
||||
":traceme_wrapper",
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "traceme_context_manager",
|
||||
hdrs = ["traceme_context_manager.h"],
|
||||
name = "traceme_wrapper",
|
||||
hdrs = ["traceme_wrapper.h"],
|
||||
features = ["-layering_check"],
|
||||
visibility = [
|
||||
"//tensorflow/compiler/xla/python:__pkg__",
|
||||
|
@ -13,18 +13,18 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/python/profiler/internal/traceme_wrapper.h"
|
||||
|
||||
#include "pybind11/attr.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "tensorflow/python/profiler/internal/traceme_context_manager.h"
|
||||
|
||||
using ::tensorflow::profiler::TraceMeContextManager;
|
||||
namespace py = ::pybind11;
|
||||
|
||||
using ::tensorflow::profiler::TraceMeWrapper;
|
||||
|
||||
PYBIND11_MODULE(_pywrap_traceme, m) {
|
||||
py::class_<TraceMeContextManager> traceme_class(m, "TraceMe",
|
||||
py::module_local());
|
||||
traceme_class.def(py::init<py::str, py::kwargs>())
|
||||
.def("Enter", &TraceMeContextManager::Enter)
|
||||
.def("Exit", &TraceMeContextManager::Exit)
|
||||
.def("SetMetadata", &TraceMeContextManager::SetMetadata)
|
||||
.def_static("IsEnabled", &TraceMeContextManager::IsEnabled);
|
||||
py::class_<TraceMeWrapper>(m, "TraceMe", py::module_local())
|
||||
.def(py::init<const py::str&, const py::kwargs&>())
|
||||
.def("SetMetadata", &TraceMeWrapper::SetMetadata)
|
||||
.def_static("IsEnabled", &TraceMeWrapper::IsEnabled);
|
||||
};
|
||||
|
@ -12,46 +12,41 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_PYTHON_PROFILER_INTERNAL_TRACEME_CONTEXT_MANAGER_
|
||||
#define TENSORFLOW_PYTHON_PROFILER_INTERNAL_TRACEME_CONTEXT_MANAGER_
|
||||
#ifndef TENSORFLOW_PYTHON_PROFILER_INTERNAL_TRACEME_WRAPPER_
|
||||
#define TENSORFLOW_PYTHON_PROFILER_INTERNAL_TRACEME_WRAPPER_
|
||||
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace profiler {
|
||||
|
||||
// Helper to implement TraceMe as a context manager in Python.
|
||||
class TraceMeContextManager {
|
||||
// Wraps TraceMe with an interface that takes python types.
|
||||
class TraceMeWrapper {
|
||||
public:
|
||||
explicit TraceMeContextManager(py::str name, py::kwargs kwargs)
|
||||
: name_(std::move(name)), kwargs_(std::move(kwargs)) {}
|
||||
// pybind11::str and pybind11::kwargs are taken by const reference to avoid
|
||||
// Python reference-counting overhead.
|
||||
TraceMeWrapper(const pybind11::str& name, const pybind11::kwargs& kwargs)
|
||||
: traceme_([&]() {
|
||||
std::string name_and_metadata(name);
|
||||
if (!kwargs.empty()) {
|
||||
AppendMetadata(&name_and_metadata, kwargs);
|
||||
}
|
||||
return name_and_metadata;
|
||||
}) {}
|
||||
|
||||
void Enter() {
|
||||
if (IsEnabled()) {
|
||||
traceme_.emplace([this]() {
|
||||
std::string name(name_);
|
||||
if (!kwargs_.empty()) {
|
||||
AppendMetadata(&name, kwargs_);
|
||||
}
|
||||
return name;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void SetMetadata(py::kwargs kwargs) {
|
||||
if (TF_PREDICT_TRUE(traceme_.has_value() && !kwargs.empty())) {
|
||||
traceme_->AppendMetadata([&kwargs]() {
|
||||
// pybind11::kwargs is taken by const reference to avoid Python
|
||||
// reference-counting overhead.
|
||||
void SetMetadata(const pybind11::kwargs& kwargs) {
|
||||
if (TF_PREDICT_FALSE(!kwargs.empty())) {
|
||||
traceme_.AppendMetadata([&]() {
|
||||
std::string metadata;
|
||||
AppendMetadata(&metadata, kwargs);
|
||||
return metadata;
|
||||
@ -59,28 +54,27 @@ class TraceMeContextManager {
|
||||
}
|
||||
}
|
||||
|
||||
void Exit() { traceme_.reset(); }
|
||||
void Stop() { traceme_.Stop(); }
|
||||
|
||||
static bool IsEnabled() { return tensorflow::profiler::TraceMe::Active(); }
|
||||
|
||||
private:
|
||||
// Converts kwargs to strings and appends them to name encoded as TraceMe
|
||||
// metadata.
|
||||
static void AppendMetadata(std::string* name, const py::kwargs& kwargs) {
|
||||
static void AppendMetadata(std::string* name,
|
||||
const pybind11::kwargs& kwargs) {
|
||||
name->push_back('#');
|
||||
for (const auto& kv : kwargs) {
|
||||
absl::StrAppend(name, std::string(py::str(kv.first)), "=",
|
||||
std::string(py::str(kv.second)), ",");
|
||||
absl::StrAppend(name, std::string(pybind11::str(kv.first)), "=",
|
||||
std::string(pybind11::str(kv.second)), ",");
|
||||
}
|
||||
name->back() = '#';
|
||||
}
|
||||
|
||||
py::str name_;
|
||||
py::kwargs kwargs_;
|
||||
absl::optional<tensorflow::profiler::TraceMe> traceme_;
|
||||
tensorflow::profiler::TraceMe traceme_;
|
||||
};
|
||||
|
||||
} // namespace profiler
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_PYTHON_PROFILER_INTERNAL_TRACEME_CONTEXT_MANAGER_
|
||||
#endif // TENSORFLOW_PYTHON_PROFILER_INTERNAL_TRACEME_WRAPPER_
|
@ -73,13 +73,13 @@ class Trace(object):
|
||||
training step being traced.
|
||||
"""
|
||||
if _pywrap_traceme.TraceMe.IsEnabled():
|
||||
# Creating _pywrap_traceme.TraceMe starts the clock.
|
||||
self._traceme = _pywrap_traceme.TraceMe(name, **kwargs)
|
||||
else:
|
||||
self._traceme = None
|
||||
|
||||
def __enter__(self):
|
||||
if self._traceme:
|
||||
self._traceme.Enter()
|
||||
# Starting the TraceMe clock here would require an extra Python->C++ call.
|
||||
return self
|
||||
|
||||
def set_metadata(self, **kwargs):
|
||||
@ -117,5 +117,5 @@ class Trace(object):
|
||||
self._traceme.SetMetadata(**kwargs)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self._traceme:
|
||||
self._traceme.Exit()
|
||||
# Deallocating _pywrap_traceme.TraceMe stops the clock.
|
||||
self._traceme = None
|
||||
|
Loading…
x
Reference in New Issue
Block a user