Ported tf_stack.extract_stack to C++
This change also removes extract_stack_file_and_line because extract_stack is now efficient enough to be used ~everywhere. def f(n, callback): if n == 0: return callback() else: return f(n - 1, callback) >>> %timeit f(16, lambda: None) # Baseline 1000000 loops, best of 3: 1.09 ?s per loop Before: >>> %timeit f(16, tf_stack.extract_stack_file_and_line) 100000 loops, best of 3: 17.7 ?s per loop >>> %timeit f(16, tf_stack.extract_stack) 100000 loops, best of 3: 18.5 ?s per loop After: >>> %timeit f(16, tf_stack.extract_stack) 100000 loops, best of 3: 3.89 ?s per loop PiperOrigin-RevId: 263784818
This commit is contained in:
parent
c0219bdebd
commit
2917ad1d24
@ -20,6 +20,7 @@ visibility = [
|
||||
]
|
||||
|
||||
load("//tensorflow:tensorflow.bzl", "if_mlir", "if_not_v2", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "pybind_extension")
|
||||
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
|
||||
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
|
||||
@ -1197,7 +1198,9 @@ py_library(
|
||||
srcs = ["util/tf_stack.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [],
|
||||
deps = [
|
||||
":_tf_stack",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
@ -4542,6 +4545,19 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
pybind_extension(
|
||||
name = "_tf_stack",
|
||||
srcs = ["util/tf_stack.cc"],
|
||||
copts = ["-fexceptions"],
|
||||
features = ["-use_header_modules"],
|
||||
# TODO(b/138203821): change to "util._tf_stack" once the bug is fixed.
|
||||
module_name = "_tf_stack",
|
||||
deps = [
|
||||
"//third_party/python_runtime:headers", # buildcleaner: keep
|
||||
"@pybind11",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "util",
|
||||
srcs = glob(
|
||||
@ -4559,6 +4575,7 @@ py_library(
|
||||
"//third_party/py/tf_agents:__subpackages__",
|
||||
],
|
||||
deps = [
|
||||
":_tf_stack",
|
||||
"@org_python_pypi_backports_weakref",
|
||||
"@com_google_protobuf//:protobuf_python",
|
||||
"//third_party/py/numpy",
|
||||
@ -4567,6 +4584,15 @@ py_library(
|
||||
] + if_mlir(["//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass"]),
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "tf_stack_test",
|
||||
srcs = ["util/tf_stack_test.py"],
|
||||
additional_deps = [
|
||||
":client_testlib",
|
||||
":util",
|
||||
],
|
||||
)
|
||||
|
||||
tf_py_test(
|
||||
name = "object_identity_test",
|
||||
size = "small",
|
||||
|
@ -33,9 +33,13 @@ from tensorflow.python.util import tf_stack
|
||||
|
||||
def _make_frame_with_filename(op, idx, filename):
|
||||
"""Return a copy of an existing stack frame with a new filename."""
|
||||
stack_frame = list(op._traceback[idx])
|
||||
stack_frame[tf_stack.TB_FILENAME] = filename
|
||||
return tuple(stack_frame)
|
||||
frame = op._traceback[idx]
|
||||
return tf_stack.StackFrame(
|
||||
filename,
|
||||
frame.lineno,
|
||||
frame.name,
|
||||
frame.globals,
|
||||
frame.func_start_lineno)
|
||||
|
||||
|
||||
def _modify_op_stack_with_filenames(op, num_user_frames, user_filename,
|
||||
|
@ -54,11 +54,11 @@ class Registry(object):
|
||||
if not name:
|
||||
name = candidate.__name__
|
||||
if name in self._registry:
|
||||
(filename, line_number, function_name, _, _) = (
|
||||
self._registry[name][_LOCATION_TAG])
|
||||
raise KeyError("Registering two %s with name '%s'! "
|
||||
"(Previous registration was in %s %s:%d)" %
|
||||
(self._name, name, function_name, filename, line_number))
|
||||
frame = self._registry[name][_LOCATION_TAG]
|
||||
raise KeyError(
|
||||
"Registering two %s with name '%s'! "
|
||||
"(Previous registration was in %s %s:%d)" %
|
||||
(self._name, name, frame.name, frame.filename, frame.lineno))
|
||||
|
||||
logging.vlog(1, "Registering %s (%s) in %s.", name, candidate, self._name)
|
||||
# stack trace is [this_function, Register(), user_function,...]
|
||||
|
@ -55,19 +55,21 @@ class TraceableObject(object):
|
||||
# beyond the caller.
|
||||
local_offset = offset + 1
|
||||
|
||||
frame_records = tf_stack.extract_stack_file_and_line(
|
||||
max_length=local_offset + 1)
|
||||
frame_records = tf_stack.extract_stack(
|
||||
limit=local_offset + 1)
|
||||
if not frame_records:
|
||||
return self.FAILURE
|
||||
if len(frame_records) > local_offset:
|
||||
# Negative indexing is one-indexed instead of zero-indexed.
|
||||
negative_offset = -(local_offset + 1)
|
||||
self.filename, self.lineno = frame_records[negative_offset]
|
||||
frame = frame_records[len(frame_records) - (local_offset + 1)]
|
||||
self.filename = frame.filename
|
||||
self.lineno = frame.lineno
|
||||
return self.SUCCESS
|
||||
else:
|
||||
# If the offset is too large then we use the largest offset possible,
|
||||
# meaning we use the outermost stack frame at index 0.
|
||||
self.filename, self.lineno = frame_records[0]
|
||||
frame = frame_records[0]
|
||||
self.filename = frame.filename
|
||||
self.lineno = frame.lineno
|
||||
return self.HEURISTIC_USED
|
||||
|
||||
def copy_metadata(self):
|
||||
|
@ -99,7 +99,7 @@ def _validate_deprecation_args(date, instructions):
|
||||
|
||||
def _call_location(outer=False):
|
||||
"""Returns call location given level up from current call."""
|
||||
stack = tf_stack.extract_stack_file_and_line(max_length=4)
|
||||
stack = tf_stack.extract_stack(limit=4)
|
||||
length = len(stack)
|
||||
if length == 0: # should never happen as we're in a function
|
||||
return 'UNKNOWN'
|
||||
@ -107,7 +107,7 @@ def _call_location(outer=False):
|
||||
if index < 0:
|
||||
index = 0
|
||||
frame = stack[index]
|
||||
return '{}:{}'.format(frame.file, frame.line)
|
||||
return '{}:{}'.format(frame.filename, frame.lineno)
|
||||
|
||||
|
||||
def _wrap_decorator(wrapped_function):
|
||||
|
@ -42,11 +42,11 @@ def _call_location():
|
||||
# We want to get stack frame 3 frames up from current frame,
|
||||
# i.e. above __getattr__, _tfmw_add_deprecation_warning,
|
||||
# and _call_location calls.
|
||||
stack = tf_stack.extract_stack_file_and_line(max_length=4)
|
||||
stack = tf_stack.extract_stack(limit=4)
|
||||
if not stack: # should never happen as we're in a function
|
||||
return 'UNKNOWN'
|
||||
frame = stack[0]
|
||||
return '{}:{}'.format(frame.file, frame.line)
|
||||
return '{}:{}'.format(frame.filename, frame.lineno)
|
||||
|
||||
|
||||
def contains_deprecation_decorator(decorators):
|
||||
|
@ -85,7 +85,7 @@ def make_decorator(target,
|
||||
"""
|
||||
if decorator_name is None:
|
||||
frame = tf_stack.extract_stack(limit=2)[0]
|
||||
decorator_name = frame[2] # Caller's name
|
||||
decorator_name = frame.name
|
||||
decorator = TFDecorator(decorator_name, target, decorator_doc,
|
||||
decorator_argspec)
|
||||
setattr(decorator_func, '_tf_decorator', decorator)
|
||||
|
126
tensorflow/python/util/tf_stack.cc
Normal file
126
tensorflow/python/util/tf_stack.cc
Normal file
@ -0,0 +1,126 @@
|
||||
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <Python.h>
|
||||
#include <frameobject.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <vector>
|
||||
|
||||
#include "include/pybind11/pybind11.h"
|
||||
#include "include/pybind11/stl_bind.h"
|
||||
|
||||
struct StackFrame; // Forward declaration.
|
||||
|
||||
PYBIND11_MAKE_OPAQUE(std::vector<StackFrame>);
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
struct StackFrame {
|
||||
py::str filename;
|
||||
int lineno;
|
||||
py::str name;
|
||||
py::object globals;
|
||||
int func_start_lineno;
|
||||
};
|
||||
|
||||
std::vector<StackFrame> ExtractStack(ssize_t limit, const py::list& mappers,
|
||||
const py::list& filters) {
|
||||
const py::dict& source_map =
|
||||
mappers.size() == 0
|
||||
? py::dict()
|
||||
: mappers[mappers.size() - 1].attr("get_effective_source_map")();
|
||||
const py::set& filtered_filenames =
|
||||
filters.size() == 0
|
||||
? py::set()
|
||||
: filters[filters.size() - 1].attr("get_filtered_filenames")();
|
||||
|
||||
const auto* tstate = PyThreadState_GET();
|
||||
// Drop extract_stack() wrapper-function frame from the result.
|
||||
const PyFrameObject* f = tstate->frame->f_back; // TODO(slebedev): INCREF?
|
||||
|
||||
std::vector<StackFrame> ret;
|
||||
// 16 is somewhat arbitrary, but TensorFlow stack traces tend to be deep.
|
||||
ret.reserve(limit < 0 ? 16 : static_cast<size_t>(limit));
|
||||
for (; f != nullptr && (limit < 0 || ret.size() < limit); f = f->f_back) {
|
||||
PyCodeObject* co = f->f_code;
|
||||
int lineno = PyFrame_GetLineNumber(const_cast<PyFrameObject*>(f));
|
||||
auto filename = py::reinterpret_borrow<py::str>(co->co_filename);
|
||||
auto name = py::reinterpret_borrow<py::str>(co->co_name);
|
||||
|
||||
// TODO(slebedev): consider moving the mappers/filters to C++ as well.
|
||||
if (source_map.size() > 0) {
|
||||
const auto& key = py::make_tuple(filename, lineno);
|
||||
if (source_map.contains(key)) {
|
||||
const py::tuple& mapped = source_map[key];
|
||||
filename = mapped[0];
|
||||
lineno = py::cast<py::int_>(mapped[1]);
|
||||
name = mapped[2];
|
||||
}
|
||||
}
|
||||
|
||||
// Never filter the innermost frame.
|
||||
// TODO(slebedev): upstream py::set::contains to pybind11.
|
||||
if (!ret.empty() &&
|
||||
PySet_Contains(filtered_filenames.ptr(), filename.ptr()))
|
||||
continue;
|
||||
|
||||
const auto& globals = py::reinterpret_borrow<py::object>(f->f_globals);
|
||||
const int func_start_lineno = co->co_firstlineno;
|
||||
ret.push_back({std::move(filename), lineno, std::move(name), globals,
|
||||
func_start_lineno});
|
||||
}
|
||||
|
||||
std::reverse(ret.begin(), ret.end());
|
||||
return ret;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
PYBIND11_MODULE(_tf_stack, m) {
|
||||
// TODO(slebedev): consider dropping convert_stack in favor of
|
||||
// a lazily initialized StackFrame.code property (using linecache).
|
||||
py::class_<StackFrame>(m, "StackFrame")
|
||||
.def(py::init<const py::str&, int, const py::str&, const py::object&,
|
||||
int>())
|
||||
.def_readonly("filename", &StackFrame::filename)
|
||||
.def_readonly("lineno", &StackFrame::lineno)
|
||||
.def_readonly("name", &StackFrame::name)
|
||||
.def_readonly("globals", &StackFrame::globals)
|
||||
.def_readonly("func_start_lineno", &StackFrame::func_start_lineno)
|
||||
.def("__repr__", [](const StackFrame& self) {
|
||||
return py::str(
|
||||
"StackFrame(filename={}, lineno={}, name={}, globals={}, "
|
||||
"func_start_lineno={})")
|
||||
.format(self.filename, self.lineno, self.name, self.globals,
|
||||
self.func_start_lineno);
|
||||
});
|
||||
|
||||
py::bind_vector<std::vector<StackFrame>>(m, "Stack", py::module_local(true));
|
||||
|
||||
m.def("extract_stack", [](const py::object& limit, const py::list& mappers,
|
||||
const py::list& filters) {
|
||||
// In Python 3.X ``traceback.extract_stack`` allows ``limit`` to
|
||||
// either be None or -1.
|
||||
return ExtractStack(limit.is_none() ? -1 : py::cast<ssize_t>(limit),
|
||||
mappers, filters);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -21,11 +21,13 @@ from __future__ import print_function
|
||||
import collections
|
||||
import inspect
|
||||
import linecache
|
||||
import sys
|
||||
import threading
|
||||
|
||||
import six
|
||||
|
||||
# TODO(b/138203821): change to from ...util import ... once the bug is fixed.
|
||||
from tensorflow.python import _tf_stack
|
||||
|
||||
# Generally such lookups should be done using `threading.local()`. See
|
||||
# https://blogs.gnome.org/jamesh/2008/06/11/tls-python/ for a detailed
|
||||
# explanation of why. However the transform stacks are expected to be empty
|
||||
@ -134,11 +136,7 @@ class CurrentModuleFilter(StackTraceFilter):
|
||||
return self._filtered_filenames
|
||||
|
||||
|
||||
EMPTY_FROZEN_MAP = {}
|
||||
EMPTY_FROZEN_SET = frozenset()
|
||||
|
||||
|
||||
def extract_stack(limit=None):
|
||||
def extract_stack(limit=-1):
|
||||
"""A lightweight, extensible re-implementation of traceback.extract_stack.
|
||||
|
||||
NOTE(mrry): traceback.extract_stack eagerly retrieves the line of code for
|
||||
@ -151,88 +149,21 @@ def extract_stack(limit=None):
|
||||
limit: A limit on the number of frames to return.
|
||||
|
||||
Returns:
|
||||
A list of 5-tuples
|
||||
(filename, lineno, name, frame_globals, func_start_lineno)
|
||||
corresponding to the call stack of the current thread. The returned tuples
|
||||
have the innermost stack frame at the end, unlike the Python inspect
|
||||
module's stack() function.
|
||||
A sequence of StackFrame objects
|
||||
(filename, lineno, name, globals, func_start_lineno)
|
||||
corresponding to the call stack of the current thread. The returned
|
||||
tuples have the innermost stack frame at the end, unlike the Python
|
||||
inspect module's stack() function.
|
||||
"""
|
||||
try:
|
||||
raise ZeroDivisionError
|
||||
except ZeroDivisionError:
|
||||
f = sys.exc_info()[2].tb_frame.f_back
|
||||
ret = []
|
||||
length = 0
|
||||
|
||||
# N.B ExtractStack in tf_stack.cc will drop this frame prior to
|
||||
# traversing the stack.
|
||||
thread_key = _get_thread_key()
|
||||
source_mappers = _source_mapper_stacks[thread_key]
|
||||
# TODO(mdan): Use sentinels instead.
|
||||
if source_mappers:
|
||||
source_map = source_mappers[-1].get_effective_source_map()
|
||||
else:
|
||||
source_map = EMPTY_FROZEN_MAP
|
||||
return _tf_stack.extract_stack(
|
||||
limit,
|
||||
_source_mapper_stacks[thread_key],
|
||||
_source_filter_stacks[thread_key])
|
||||
|
||||
source_filters = _source_filter_stacks[thread_key]
|
||||
if source_filters:
|
||||
filtered_filenames = source_filters[-1].get_filtered_filenames()
|
||||
else:
|
||||
filtered_filenames = EMPTY_FROZEN_SET
|
||||
|
||||
while f is not None and (limit is None or length < limit):
|
||||
lineno = f.f_lineno
|
||||
co = f.f_code
|
||||
filename = co.co_filename
|
||||
name = co.co_name
|
||||
frame_globals = f.f_globals
|
||||
func_start_lineno = co.co_firstlineno
|
||||
|
||||
# TODO(mdan): Show some indication that the frame was translated.
|
||||
filename, lineno, name = source_map.get(
|
||||
(filename, lineno), (filename, lineno, name))
|
||||
|
||||
# Note: we never filter the innermost frame.
|
||||
if not (ret and filename in filtered_filenames):
|
||||
ret.append((filename, lineno, name, frame_globals, func_start_lineno))
|
||||
length += 1
|
||||
|
||||
f = f.f_back
|
||||
|
||||
ret.reverse()
|
||||
return ret
|
||||
|
||||
|
||||
FileAndLine = collections.namedtuple('FileAndLine', ['file', 'line'])
|
||||
|
||||
|
||||
def extract_stack_file_and_line(max_length=1000):
|
||||
"""A version of extract_stack that only returns filenames and line numbers.
|
||||
|
||||
Callers often only require filenames and line numbers, and do not need the
|
||||
additional information gathered by extract_stack, as they never call
|
||||
convert_stack.
|
||||
|
||||
As a further optimisation, we allow users to specify a limit on the number of
|
||||
frames examined.
|
||||
|
||||
Args:
|
||||
max_length: The maximum length of stack to extract.
|
||||
|
||||
Returns:
|
||||
A list of FileAndLine objects corresponding to the call stack of the current
|
||||
thread.
|
||||
"""
|
||||
try:
|
||||
raise ZeroDivisionError
|
||||
except ZeroDivisionError:
|
||||
frame = sys.exc_info()[2].tb_frame.f_back
|
||||
ret = []
|
||||
length = 0
|
||||
while frame is not None and length < max_length:
|
||||
ret.append(FileAndLine(frame.f_code.co_filename, frame.f_lineno))
|
||||
length += 1
|
||||
frame = frame.f_back
|
||||
ret.reverse()
|
||||
return ret
|
||||
StackFrame = _tf_stack.StackFrame
|
||||
|
||||
|
||||
def convert_stack(stack, include_func_start_lineno=False):
|
||||
@ -251,16 +182,18 @@ def convert_stack(stack, include_func_start_lineno=False):
|
||||
input tuple.
|
||||
"""
|
||||
def _tuple_generator(): # pylint: disable=missing-docstring
|
||||
for (filename, lineno, name, frame_globals, func_start_lineno) in stack:
|
||||
for frame in stack:
|
||||
filename = frame.filename
|
||||
lineno = frame.lineno
|
||||
linecache.checkcache(filename)
|
||||
line = linecache.getline(filename, lineno, frame_globals)
|
||||
line = linecache.getline(filename, lineno, frame.globals)
|
||||
if line:
|
||||
line = line.strip()
|
||||
else:
|
||||
line = None
|
||||
if include_func_start_lineno:
|
||||
yield (filename, lineno, name, line, func_start_lineno)
|
||||
yield (filename, lineno, frame.name, line, frame.func_start_lineno)
|
||||
else:
|
||||
yield (filename, lineno, name, line)
|
||||
yield (filename, lineno, frame.name, line)
|
||||
|
||||
return tuple(_tuple_generator())
|
||||
|
55
tensorflow/python/util/tf_stack_test.py
Normal file
55
tensorflow/python/util/tf_stack_test.py
Normal file
@ -0,0 +1,55 @@
|
||||
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for functions used to extract and analyze stacks."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import traceback
|
||||
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.util import tf_stack
|
||||
|
||||
|
||||
class TFStackTest(test.TestCase):
|
||||
|
||||
def testLimit(self):
|
||||
self.assertEmpty(tf_stack.extract_stack(limit=0))
|
||||
self.assertLen(tf_stack.extract_stack(limit=1), 1)
|
||||
self.assertEqual(
|
||||
len(tf_stack.extract_stack(limit=-1)),
|
||||
len(tf_stack.extract_stack()))
|
||||
|
||||
def testConsistencyWithTraceback(self):
|
||||
stack, expected_stack = extract_stack()
|
||||
for frame, expected in zip(stack, expected_stack):
|
||||
self.assertEqual(frame, expected)
|
||||
|
||||
def testFormatStack(self):
|
||||
stack, expected_stack = extract_stack()
|
||||
self.assertEqual(
|
||||
traceback.format_list(stack),
|
||||
traceback.format_list(expected_stack))
|
||||
|
||||
|
||||
def extract_stack(limit=None):
|
||||
convert = tf_stack.convert_stack
|
||||
# Both defined on the same line to produce identical stacks.
|
||||
return convert(tf_stack.extract_stack(limit)), traceback.extract_stack(limit)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
Loading…
Reference in New Issue
Block a user