Do TraceMe kwargs encoding in C++
PiperOrigin-RevId: 312329330 Change-Id: I1e7a30e9953b289dece0582cd4041a2769ff1901
This commit is contained in:
parent
b7735095de
commit
34a68f2752
@ -196,19 +196,6 @@ class TraceMe {
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// Appends new_metadata to the payload.
|
|
||||||
// This overload should only be used by other TraceMe APIs.
|
|
||||||
// Prefer the overload above instead.
|
|
||||||
void AppendMetadata(absl::string_view new_metadata) {
|
|
||||||
#if !defined(IS_MOBILE_PLATFORM)
|
|
||||||
if (TF_PREDICT_FALSE(start_time_ != kUntracedActivity)) {
|
|
||||||
if (TF_PREDICT_TRUE(TraceMeRecorder::Active())) {
|
|
||||||
traceme_internal::AppendMetadata(&no_init_.name, new_metadata);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
// Static API, for use when scoped objects are inconvenient.
|
// Static API, for use when scoped objects are inconvenient.
|
||||||
|
|
||||||
// Record the start time of an activity.
|
// Record the start time of an activity.
|
||||||
|
@ -224,10 +224,8 @@ py_library(
|
|||||||
srcs_version = "PY2AND3",
|
srcs_version = "PY2AND3",
|
||||||
visibility = ["//tensorflow:internal"],
|
visibility = ["//tensorflow:internal"],
|
||||||
deps = [
|
deps = [
|
||||||
"//tensorflow/python:util",
|
"//tensorflow/python:tf_export",
|
||||||
"//tensorflow/python/profiler/internal:_pywrap_traceme",
|
"//tensorflow/python/profiler/internal:_pywrap_traceme",
|
||||||
"//tensorflow/python/types",
|
|
||||||
"@six_archive//:six",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -16,9 +16,12 @@ limitations under the License.
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include "absl/strings/str_cat.h"
|
||||||
#include "absl/strings/string_view.h"
|
#include "absl/strings/string_view.h"
|
||||||
#include "absl/types/optional.h"
|
#include "absl/types/optional.h"
|
||||||
#include "pybind11/pybind11.h"
|
#include "pybind11/pybind11.h"
|
||||||
|
#include "pybind11/pytypes.h"
|
||||||
|
#include "tensorflow/core/platform/macros.h"
|
||||||
#include "tensorflow/core/platform/types.h"
|
#include "tensorflow/core/platform/types.h"
|
||||||
#include "tensorflow/core/profiler/lib/traceme.h"
|
#include "tensorflow/core/profiler/lib/traceme.h"
|
||||||
|
|
||||||
@ -26,16 +29,41 @@ namespace py = pybind11;
|
|||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
// Converts kwargs to strings and appends them to name encoded as TraceMe
|
||||||
|
// metadata.
|
||||||
|
TF_ATTRIBUTE_ALWAYS_INLINE inline void AppendMetadata(
|
||||||
|
std::string* name, const py::kwargs& kwargs) {
|
||||||
|
name->push_back('#');
|
||||||
|
for (const auto& kv : kwargs) {
|
||||||
|
absl::StrAppend(name, std::string(py::str(kv.first)), "=",
|
||||||
|
std::string(py::str(kv.second)), ",");
|
||||||
|
}
|
||||||
|
name->back() = '#';
|
||||||
|
}
|
||||||
|
|
||||||
// Helper to implement TraceMe as a context manager in Python.
|
// Helper to implement TraceMe as a context manager in Python.
|
||||||
class TraceMeWrapper {
|
class TraceMeWrapper {
|
||||||
public:
|
public:
|
||||||
explicit TraceMeWrapper(const std::string& name) : name_(name) {}
|
explicit TraceMeWrapper(py::str name, py::kwargs kwargs)
|
||||||
|
: name_(std::move(name)), kwargs_(std::move(kwargs)) {}
|
||||||
|
|
||||||
void Enter() { traceme_.emplace(std::move(name_)); }
|
void Enter() {
|
||||||
|
traceme_.emplace([this]() {
|
||||||
|
std::string name(name_);
|
||||||
|
if (!kwargs_.empty()) {
|
||||||
|
AppendMetadata(&name, kwargs_);
|
||||||
|
}
|
||||||
|
return name;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
void SetMetadata(const std::string& new_metadata) {
|
void SetMetadata(py::kwargs kwargs) {
|
||||||
if (TF_PREDICT_TRUE(traceme_)) {
|
if (TF_PREDICT_TRUE(traceme_.has_value() && !kwargs.empty())) {
|
||||||
traceme_->AppendMetadata(absl::string_view(new_metadata));
|
traceme_->AppendMetadata([&kwargs]() {
|
||||||
|
std::string metadata;
|
||||||
|
AppendMetadata(&metadata, kwargs);
|
||||||
|
return metadata;
|
||||||
|
});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -44,7 +72,8 @@ class TraceMeWrapper {
|
|||||||
static bool IsEnabled() { return tensorflow::profiler::TraceMe::Active(); }
|
static bool IsEnabled() { return tensorflow::profiler::TraceMe::Active(); }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
tensorflow::string name_;
|
py::str name_;
|
||||||
|
py::kwargs kwargs_;
|
||||||
absl::optional<tensorflow::profiler::TraceMe> traceme_;
|
absl::optional<tensorflow::profiler::TraceMe> traceme_;
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -52,7 +81,7 @@ class TraceMeWrapper {
|
|||||||
|
|
||||||
PYBIND11_MODULE(_pywrap_traceme, m) {
|
PYBIND11_MODULE(_pywrap_traceme, m) {
|
||||||
py::class_<TraceMeWrapper> traceme_class(m, "TraceMe");
|
py::class_<TraceMeWrapper> traceme_class(m, "TraceMe");
|
||||||
traceme_class.def(py::init<const std::string&>())
|
traceme_class.def(py::init<py::str, py::kwargs>())
|
||||||
.def("Enter", &TraceMeWrapper::Enter)
|
.def("Enter", &TraceMeWrapper::Enter)
|
||||||
.def("Exit", &TraceMeWrapper::Exit)
|
.def("Exit", &TraceMeWrapper::Exit)
|
||||||
.def("SetMetadata", &TraceMeWrapper::SetMetadata)
|
.def("SetMetadata", &TraceMeWrapper::SetMetadata)
|
||||||
|
@ -18,29 +18,10 @@ from __future__ import absolute_import
|
|||||||
from __future__ import division
|
from __future__ import division
|
||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
import six
|
|
||||||
|
|
||||||
from tensorflow.python.profiler.internal import _pywrap_traceme
|
from tensorflow.python.profiler.internal import _pywrap_traceme
|
||||||
from tensorflow.python.util.tf_export import tf_export
|
from tensorflow.python.util.tf_export import tf_export
|
||||||
|
|
||||||
|
|
||||||
def encode_metadata(metadata):
|
|
||||||
"""Encodes the given metadata to a string.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
metadata: in key-value pairs.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The encoded string.
|
|
||||||
"""
|
|
||||||
if not metadata:
|
|
||||||
return ''
|
|
||||||
content = []
|
|
||||||
for key, value in six.iteritems(metadata):
|
|
||||||
content.append('%s=%s'%(key, value))
|
|
||||||
return '#' + ','.join(content) + '#'
|
|
||||||
|
|
||||||
|
|
||||||
@tf_export('profiler.experimental.Trace', v1=[])
|
@tf_export('profiler.experimental.Trace', v1=[])
|
||||||
class Trace(object):
|
class Trace(object):
|
||||||
"""Context manager that generates a trace event in the profiler.
|
"""Context manager that generates a trace event in the profiler.
|
||||||
@ -92,8 +73,7 @@ class Trace(object):
|
|||||||
training step being traced.
|
training step being traced.
|
||||||
"""
|
"""
|
||||||
if _pywrap_traceme.TraceMe.IsEnabled():
|
if _pywrap_traceme.TraceMe.IsEnabled():
|
||||||
name += encode_metadata(kwargs)
|
self._traceme = _pywrap_traceme.TraceMe(name, **kwargs)
|
||||||
self._traceme = _pywrap_traceme.TraceMe(name)
|
|
||||||
else:
|
else:
|
||||||
self._traceme = None
|
self._traceme = None
|
||||||
|
|
||||||
@ -134,8 +114,7 @@ class Trace(object):
|
|||||||
to measure the entire duration of call()).
|
to measure the entire duration of call()).
|
||||||
"""
|
"""
|
||||||
if self._traceme and kwargs:
|
if self._traceme and kwargs:
|
||||||
additional_metadata = encode_metadata(kwargs)
|
self._traceme.SetMetadata(**kwargs)
|
||||||
self._traceme.SetMetadata(additional_metadata)
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
if self._traceme:
|
if self._traceme:
|
||||||
|
Loading…
Reference in New Issue
Block a user