Do TraceMe kwargs encoding in C++
PiperOrigin-RevId: 312329330 Change-Id: I1e7a30e9953b289dece0582cd4041a2769ff1901
This commit is contained in:
parent
b7735095de
commit
34a68f2752
@ -196,19 +196,6 @@ class TraceMe {
|
||||
#endif
|
||||
}
|
||||
|
||||
// Appends new_metadata to the payload.
|
||||
// This overload should only be used by other TraceMe APIs.
|
||||
// Prefer the overload above instead.
|
||||
void AppendMetadata(absl::string_view new_metadata) {
|
||||
#if !defined(IS_MOBILE_PLATFORM)
|
||||
if (TF_PREDICT_FALSE(start_time_ != kUntracedActivity)) {
|
||||
if (TF_PREDICT_TRUE(TraceMeRecorder::Active())) {
|
||||
traceme_internal::AppendMetadata(&no_init_.name, new_metadata);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// Static API, for use when scoped objects are inconvenient.
|
||||
|
||||
// Record the start time of an activity.
|
||||
|
@ -224,10 +224,8 @@ py_library(
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
deps = [
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:tf_export",
|
||||
"//tensorflow/python/profiler/internal:_pywrap_traceme",
|
||||
"//tensorflow/python/types",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -16,9 +16,12 @@ limitations under the License.
|
||||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "absl/types/optional.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "tensorflow/core/platform/macros.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||
|
||||
@ -26,16 +29,41 @@ namespace py = pybind11;
|
||||
|
||||
namespace {
|
||||
|
||||
// Converts kwargs to strings and appends them to name encoded as TraceMe
|
||||
// metadata.
|
||||
TF_ATTRIBUTE_ALWAYS_INLINE inline void AppendMetadata(
|
||||
std::string* name, const py::kwargs& kwargs) {
|
||||
name->push_back('#');
|
||||
for (const auto& kv : kwargs) {
|
||||
absl::StrAppend(name, std::string(py::str(kv.first)), "=",
|
||||
std::string(py::str(kv.second)), ",");
|
||||
}
|
||||
name->back() = '#';
|
||||
}
|
||||
|
||||
// Helper to implement TraceMe as a context manager in Python.
|
||||
class TraceMeWrapper {
|
||||
public:
|
||||
explicit TraceMeWrapper(const std::string& name) : name_(name) {}
|
||||
explicit TraceMeWrapper(py::str name, py::kwargs kwargs)
|
||||
: name_(std::move(name)), kwargs_(std::move(kwargs)) {}
|
||||
|
||||
void Enter() { traceme_.emplace(std::move(name_)); }
|
||||
void Enter() {
|
||||
traceme_.emplace([this]() {
|
||||
std::string name(name_);
|
||||
if (!kwargs_.empty()) {
|
||||
AppendMetadata(&name, kwargs_);
|
||||
}
|
||||
return name;
|
||||
});
|
||||
}
|
||||
|
||||
void SetMetadata(const std::string& new_metadata) {
|
||||
if (TF_PREDICT_TRUE(traceme_)) {
|
||||
traceme_->AppendMetadata(absl::string_view(new_metadata));
|
||||
void SetMetadata(py::kwargs kwargs) {
|
||||
if (TF_PREDICT_TRUE(traceme_.has_value() && !kwargs.empty())) {
|
||||
traceme_->AppendMetadata([&kwargs]() {
|
||||
std::string metadata;
|
||||
AppendMetadata(&metadata, kwargs);
|
||||
return metadata;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ -44,7 +72,8 @@ class TraceMeWrapper {
|
||||
static bool IsEnabled() { return tensorflow::profiler::TraceMe::Active(); }
|
||||
|
||||
private:
|
||||
tensorflow::string name_;
|
||||
py::str name_;
|
||||
py::kwargs kwargs_;
|
||||
absl::optional<tensorflow::profiler::TraceMe> traceme_;
|
||||
};
|
||||
|
||||
@ -52,7 +81,7 @@ class TraceMeWrapper {
|
||||
|
||||
PYBIND11_MODULE(_pywrap_traceme, m) {
|
||||
py::class_<TraceMeWrapper> traceme_class(m, "TraceMe");
|
||||
traceme_class.def(py::init<const std::string&>())
|
||||
traceme_class.def(py::init<py::str, py::kwargs>())
|
||||
.def("Enter", &TraceMeWrapper::Enter)
|
||||
.def("Exit", &TraceMeWrapper::Exit)
|
||||
.def("SetMetadata", &TraceMeWrapper::SetMetadata)
|
||||
|
@ -18,29 +18,10 @@ from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import six
|
||||
|
||||
from tensorflow.python.profiler.internal import _pywrap_traceme
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
def encode_metadata(metadata):
|
||||
"""Encodes the given metadata to a string.
|
||||
|
||||
Args:
|
||||
metadata: in key-value pairs.
|
||||
|
||||
Returns:
|
||||
The encoded string.
|
||||
"""
|
||||
if not metadata:
|
||||
return ''
|
||||
content = []
|
||||
for key, value in six.iteritems(metadata):
|
||||
content.append('%s=%s'%(key, value))
|
||||
return '#' + ','.join(content) + '#'
|
||||
|
||||
|
||||
@tf_export('profiler.experimental.Trace', v1=[])
|
||||
class Trace(object):
|
||||
"""Context manager that generates a trace event in the profiler.
|
||||
@ -92,8 +73,7 @@ class Trace(object):
|
||||
training step being traced.
|
||||
"""
|
||||
if _pywrap_traceme.TraceMe.IsEnabled():
|
||||
name += encode_metadata(kwargs)
|
||||
self._traceme = _pywrap_traceme.TraceMe(name)
|
||||
self._traceme = _pywrap_traceme.TraceMe(name, **kwargs)
|
||||
else:
|
||||
self._traceme = None
|
||||
|
||||
@ -134,8 +114,7 @@ class Trace(object):
|
||||
to measure the entire duration of call()).
|
||||
"""
|
||||
if self._traceme and kwargs:
|
||||
additional_metadata = encode_metadata(kwargs)
|
||||
self._traceme.SetMetadata(additional_metadata)
|
||||
self._traceme.SetMetadata(**kwargs)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if self._traceme:
|
||||
|
Loading…
Reference in New Issue
Block a user