Do TraceMe kwargs encoding in C++

PiperOrigin-RevId: 312329330
Change-Id: I1e7a30e9953b289dece0582cd4041a2769ff1901
This commit is contained in:
Jose Baiocchi 2020-05-19 12:25:34 -07:00 committed by TensorFlower Gardener
parent b7735095de
commit 34a68f2752
4 changed files with 39 additions and 46 deletions

View File

@ -196,19 +196,6 @@ class TraceMe {
#endif #endif
} }
// Appends new_metadata to the payload.
// This overload should only be used by other TraceMe APIs.
// Prefer the overload above instead.
void AppendMetadata(absl::string_view new_metadata) {
#if !defined(IS_MOBILE_PLATFORM)
if (TF_PREDICT_FALSE(start_time_ != kUntracedActivity)) {
if (TF_PREDICT_TRUE(TraceMeRecorder::Active())) {
traceme_internal::AppendMetadata(&no_init_.name, new_metadata);
}
}
#endif
}
// Static API, for use when scoped objects are inconvenient. // Static API, for use when scoped objects are inconvenient.
// Record the start time of an activity. // Record the start time of an activity.

View File

@ -224,10 +224,8 @@ py_library(
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
visibility = ["//tensorflow:internal"], visibility = ["//tensorflow:internal"],
deps = [ deps = [
"//tensorflow/python:util", "//tensorflow/python:tf_export",
"//tensorflow/python/profiler/internal:_pywrap_traceme", "//tensorflow/python/profiler/internal:_pywrap_traceme",
"//tensorflow/python/types",
"@six_archive//:six",
], ],
) )

View File

@ -16,9 +16,12 @@ limitations under the License.
#include <string> #include <string>
#include <utility> #include <utility>
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h" #include "absl/strings/string_view.h"
#include "absl/types/optional.h" #include "absl/types/optional.h"
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h" #include "tensorflow/core/platform/types.h"
#include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/profiler/lib/traceme.h"
@ -26,16 +29,41 @@ namespace py = pybind11;
namespace { namespace {
// Converts kwargs to strings and appends them to name encoded as TraceMe
// metadata.
TF_ATTRIBUTE_ALWAYS_INLINE inline void AppendMetadata(
std::string* name, const py::kwargs& kwargs) {
name->push_back('#');
for (const auto& kv : kwargs) {
absl::StrAppend(name, std::string(py::str(kv.first)), "=",
std::string(py::str(kv.second)), ",");
}
name->back() = '#';
}
// Helper to implement TraceMe as a context manager in Python. // Helper to implement TraceMe as a context manager in Python.
class TraceMeWrapper { class TraceMeWrapper {
public: public:
explicit TraceMeWrapper(const std::string& name) : name_(name) {} explicit TraceMeWrapper(py::str name, py::kwargs kwargs)
: name_(std::move(name)), kwargs_(std::move(kwargs)) {}
void Enter() { traceme_.emplace(std::move(name_)); } void Enter() {
traceme_.emplace([this]() {
std::string name(name_);
if (!kwargs_.empty()) {
AppendMetadata(&name, kwargs_);
}
return name;
});
}
void SetMetadata(const std::string& new_metadata) { void SetMetadata(py::kwargs kwargs) {
if (TF_PREDICT_TRUE(traceme_)) { if (TF_PREDICT_TRUE(traceme_.has_value() && !kwargs.empty())) {
traceme_->AppendMetadata(absl::string_view(new_metadata)); traceme_->AppendMetadata([&kwargs]() {
std::string metadata;
AppendMetadata(&metadata, kwargs);
return metadata;
});
} }
} }
@ -44,7 +72,8 @@ class TraceMeWrapper {
static bool IsEnabled() { return tensorflow::profiler::TraceMe::Active(); } static bool IsEnabled() { return tensorflow::profiler::TraceMe::Active(); }
private: private:
tensorflow::string name_; py::str name_;
py::kwargs kwargs_;
absl::optional<tensorflow::profiler::TraceMe> traceme_; absl::optional<tensorflow::profiler::TraceMe> traceme_;
}; };
@ -52,7 +81,7 @@ class TraceMeWrapper {
PYBIND11_MODULE(_pywrap_traceme, m) { PYBIND11_MODULE(_pywrap_traceme, m) {
py::class_<TraceMeWrapper> traceme_class(m, "TraceMe"); py::class_<TraceMeWrapper> traceme_class(m, "TraceMe");
traceme_class.def(py::init<const std::string&>()) traceme_class.def(py::init<py::str, py::kwargs>())
.def("Enter", &TraceMeWrapper::Enter) .def("Enter", &TraceMeWrapper::Enter)
.def("Exit", &TraceMeWrapper::Exit) .def("Exit", &TraceMeWrapper::Exit)
.def("SetMetadata", &TraceMeWrapper::SetMetadata) .def("SetMetadata", &TraceMeWrapper::SetMetadata)

View File

@ -18,29 +18,10 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import six
from tensorflow.python.profiler.internal import _pywrap_traceme from tensorflow.python.profiler.internal import _pywrap_traceme
from tensorflow.python.util.tf_export import tf_export from tensorflow.python.util.tf_export import tf_export
def encode_metadata(metadata):
"""Encodes the given metadata to a string.
Args:
metadata: in key-value pairs.
Returns:
The encoded string.
"""
if not metadata:
return ''
content = []
for key, value in six.iteritems(metadata):
content.append('%s=%s'%(key, value))
return '#' + ','.join(content) + '#'
@tf_export('profiler.experimental.Trace', v1=[]) @tf_export('profiler.experimental.Trace', v1=[])
class Trace(object): class Trace(object):
"""Context manager that generates a trace event in the profiler. """Context manager that generates a trace event in the profiler.
@ -92,8 +73,7 @@ class Trace(object):
training step being traced. training step being traced.
""" """
if _pywrap_traceme.TraceMe.IsEnabled(): if _pywrap_traceme.TraceMe.IsEnabled():
name += encode_metadata(kwargs) self._traceme = _pywrap_traceme.TraceMe(name, **kwargs)
self._traceme = _pywrap_traceme.TraceMe(name)
else: else:
self._traceme = None self._traceme = None
@ -134,8 +114,7 @@ class Trace(object):
to measure the entire duration of call()). to measure the entire duration of call()).
""" """
if self._traceme and kwargs: if self._traceme and kwargs:
additional_metadata = encode_metadata(kwargs) self._traceme.SetMetadata(**kwargs)
self._traceme.SetMetadata(additional_metadata)
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
if self._traceme: if self._traceme: