Move 'enabled' from _pywrap_traceme to trace
PiperOrigin-RevId: 316134360 Change-Id: I256ea280e7ddcd2df3853a943082ff63a52dfffa
This commit is contained in:
parent
795362accd
commit
1e2a941351
@ -32,8 +32,7 @@ from tensorflow.python.framework import op_callbacks
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.profiler import traceme
|
||||
from tensorflow.python.profiler.internal import _pywrap_traceme
|
||||
from tensorflow.python.profiler import trace
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
|
||||
@ -270,8 +269,8 @@ def _constant_impl(
|
||||
"""Implementation of constant."""
|
||||
ctx = context.context()
|
||||
if ctx.executing_eagerly():
|
||||
if _pywrap_traceme.enabled:
|
||||
with traceme.TraceMe("tf.constant"):
|
||||
if trace.enabled:
|
||||
with trace.Trace("tf.constant"):
|
||||
return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)
|
||||
return _constant_eager_impl(ctx, value, dtype, shape, verify_shape)
|
||||
|
||||
|
@ -224,7 +224,10 @@ py_library(
|
||||
name = "trace",
|
||||
srcs = ["trace.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//tensorflow:internal"],
|
||||
visibility = [
|
||||
"//perftools/accelerators/xprof/xprofilez/integration_tests:__pkg__",
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/python:tf_export",
|
||||
"//tensorflow/python/profiler/internal:_pywrap_traceme",
|
||||
|
@ -199,9 +199,9 @@ void PythonHooks::ClearProfilerInAllThreads() {
|
||||
|
||||
void PythonHooks::EnableTraceMe(bool enable) {
|
||||
const char* kModuleName =
|
||||
"tensorflow.python.profiler.internal._pywrap_traceme";
|
||||
"tensorflow.python.profiler.trace";
|
||||
auto trace_module = py::module::import(kModuleName);
|
||||
trace_module.attr("enabled") = enable;
|
||||
trace_module.attr("enabled") = py::bool_(enable);
|
||||
}
|
||||
|
||||
} // namespace profiler
|
||||
|
@ -23,10 +23,8 @@ namespace py = ::pybind11;
|
||||
using ::tensorflow::profiler::TraceMeWrapper;
|
||||
|
||||
PYBIND11_MODULE(_pywrap_traceme, m) {
|
||||
// This variable will be modified by PythonHooks::Start/Stop(). such
|
||||
// arrangement will reduce the number of calls through pybind11.
|
||||
m.attr("enabled") = py::bool_(false);
|
||||
py::class_<TraceMeWrapper>(m, "TraceMe", py::module_local())
|
||||
.def(py::init<const py::str&, const py::kwargs&>())
|
||||
.def("SetMetadata", &TraceMeWrapper::SetMetadata);
|
||||
.def("SetMetadata", &TraceMeWrapper::SetMetadata)
|
||||
.def("Stop", &TraceMeWrapper::Stop);
|
||||
};
|
||||
|
@ -21,6 +21,10 @@ from __future__ import print_function
|
||||
from tensorflow.python.profiler.internal import _pywrap_traceme
|
||||
from tensorflow.python.util.tf_export import tf_export
|
||||
|
||||
# This variable is modified by PythonHooks::Start/Stop() in C++. Such
|
||||
# arrangement will reduce the number of calls through pybind11.
|
||||
enabled = False
|
||||
|
||||
|
||||
@tf_export('profiler.experimental.Trace', v1=[])
|
||||
class Trace(object):
|
||||
@ -72,7 +76,7 @@ class Trace(object):
|
||||
The example above uses the keyword argument "step_num" to specify the
|
||||
training step being traced.
|
||||
"""
|
||||
if _pywrap_traceme.enabled:
|
||||
if enabled:
|
||||
# Creating _pywrap_traceme.TraceMe starts the clock.
|
||||
self._traceme = _pywrap_traceme.TraceMe(name, **kwargs)
|
||||
else:
|
||||
@ -117,5 +121,5 @@ class Trace(object):
|
||||
self._traceme.SetMetadata(**kwargs)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
# Deallocating _pywrap_traceme.TraceMe stops the clock.
|
||||
self._traceme = None
|
||||
if self._traceme:
|
||||
self._traceme.Stop()
|
||||
|
Loading…
Reference in New Issue
Block a user