Ported tf_stack.extract_stack to C++

This change also removes extract_stack_file_and_line because extract_stack
is now efficient enough to be used ~everywhere.

def f(n, callback):
  """Invoke `callback` from underneath `n` extra recursive frames of `f`.

  Benchmark helper. The recursion is deliberate: it is what gives the
  callback a call stack of known depth to walk.

  Args:
    n: Number of additional recursive calls made before `callback` runs.
    callback: Zero-argument callable executed in the innermost frame.

  Returns:
    Whatever `callback()` returns.
  """
  if n != 0:
    return f(n - 1, callback)
  return callback()

>>> %timeit f(16, lambda: None)  # Baseline
1000000 loops, best of 3: 1.09 µs per loop

Before:

>>> %timeit f(16, tf_stack.extract_stack_file_and_line)
100000 loops, best of 3: 17.7 µs per loop
>>> %timeit f(16, tf_stack.extract_stack)
100000 loops, best of 3: 18.5 µs per loop

After:

>>> %timeit f(16, tf_stack.extract_stack)
100000 loops, best of 3: 3.89 µs per loop

PiperOrigin-RevId: 263784818
This commit is contained in:
Sergei Lebedev 2019-08-16 09:23:23 -07:00 committed by TensorFlower Gardener
parent c0219bdebd
commit 2917ad1d24
10 changed files with 255 additions and 109 deletions

View File

@ -20,6 +20,7 @@ visibility = [
]
load("//tensorflow:tensorflow.bzl", "if_mlir", "if_not_v2", "if_not_windows", "py_test", "py_tests", "tf_cc_shared_object", "tf_cuda_library", "tf_gen_op_wrapper_py", "tf_py_build_info_genrule", "tf_py_test")
load("//tensorflow:tensorflow.bzl", "pybind_extension")
load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_tests")
@ -1197,7 +1198,9 @@ py_library(
srcs = ["util/tf_stack.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [],
deps = [
":_tf_stack",
],
)
py_library(
@ -4542,6 +4545,19 @@ py_library(
],
)
pybind_extension(
name = "_tf_stack",
srcs = ["util/tf_stack.cc"],
copts = ["-fexceptions"],
features = ["-use_header_modules"],
# TODO(b/138203821): change to "util._tf_stack" once the bug is fixed.
module_name = "_tf_stack",
deps = [
"//third_party/python_runtime:headers", # buildcleaner: keep
"@pybind11",
],
)
py_library(
name = "util",
srcs = glob(
@ -4559,6 +4575,7 @@ py_library(
"//third_party/py/tf_agents:__subpackages__",
],
deps = [
":_tf_stack",
"@org_python_pypi_backports_weakref",
"@com_google_protobuf//:protobuf_python",
"//third_party/py/numpy",
@ -4567,6 +4584,15 @@ py_library(
] + if_mlir(["//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_pass"]),
)
tf_py_test(
name = "tf_stack_test",
srcs = ["util/tf_stack_test.py"],
additional_deps = [
":client_testlib",
":util",
],
)
tf_py_test(
name = "object_identity_test",
size = "small",

View File

@ -33,9 +33,13 @@ from tensorflow.python.util import tf_stack
def _make_frame_with_filename(op, idx, filename):
"""Return a copy of an existing stack frame with a new filename."""
stack_frame = list(op._traceback[idx])
stack_frame[tf_stack.TB_FILENAME] = filename
return tuple(stack_frame)
frame = op._traceback[idx]
return tf_stack.StackFrame(
filename,
frame.lineno,
frame.name,
frame.globals,
frame.func_start_lineno)
def _modify_op_stack_with_filenames(op, num_user_frames, user_filename,

View File

@ -54,11 +54,11 @@ class Registry(object):
if not name:
name = candidate.__name__
if name in self._registry:
(filename, line_number, function_name, _, _) = (
self._registry[name][_LOCATION_TAG])
raise KeyError("Registering two %s with name '%s'! "
"(Previous registration was in %s %s:%d)" %
(self._name, name, function_name, filename, line_number))
frame = self._registry[name][_LOCATION_TAG]
raise KeyError(
"Registering two %s with name '%s'! "
"(Previous registration was in %s %s:%d)" %
(self._name, name, frame.name, frame.filename, frame.lineno))
logging.vlog(1, "Registering %s (%s) in %s.", name, candidate, self._name)
# stack trace is [this_function, Register(), user_function,...]

View File

@ -55,19 +55,21 @@ class TraceableObject(object):
# beyond the caller.
local_offset = offset + 1
frame_records = tf_stack.extract_stack_file_and_line(
max_length=local_offset + 1)
frame_records = tf_stack.extract_stack(
limit=local_offset + 1)
if not frame_records:
return self.FAILURE
if len(frame_records) > local_offset:
# Negative indexing is one-indexed instead of zero-indexed.
negative_offset = -(local_offset + 1)
self.filename, self.lineno = frame_records[negative_offset]
frame = frame_records[len(frame_records) - (local_offset + 1)]
self.filename = frame.filename
self.lineno = frame.lineno
return self.SUCCESS
else:
# If the offset is too large then we use the largest offset possible,
# meaning we use the outermost stack frame at index 0.
self.filename, self.lineno = frame_records[0]
frame = frame_records[0]
self.filename = frame.filename
self.lineno = frame.lineno
return self.HEURISTIC_USED
def copy_metadata(self):

View File

@ -99,7 +99,7 @@ def _validate_deprecation_args(date, instructions):
def _call_location(outer=False):
"""Returns call location given level up from current call."""
stack = tf_stack.extract_stack_file_and_line(max_length=4)
stack = tf_stack.extract_stack(limit=4)
length = len(stack)
if length == 0: # should never happen as we're in a function
return 'UNKNOWN'
@ -107,7 +107,7 @@ def _call_location(outer=False):
if index < 0:
index = 0
frame = stack[index]
return '{}:{}'.format(frame.file, frame.line)
return '{}:{}'.format(frame.filename, frame.lineno)
def _wrap_decorator(wrapped_function):

View File

@ -42,11 +42,11 @@ def _call_location():
# We want to get stack frame 3 frames up from current frame,
# i.e. above __getattr__, _tfmw_add_deprecation_warning,
# and _call_location calls.
stack = tf_stack.extract_stack_file_and_line(max_length=4)
stack = tf_stack.extract_stack(limit=4)
if not stack: # should never happen as we're in a function
return 'UNKNOWN'
frame = stack[0]
return '{}:{}'.format(frame.file, frame.line)
return '{}:{}'.format(frame.filename, frame.lineno)
def contains_deprecation_decorator(decorators):

View File

@ -85,7 +85,7 @@ def make_decorator(target,
"""
if decorator_name is None:
frame = tf_stack.extract_stack(limit=2)[0]
decorator_name = frame[2] # Caller's name
decorator_name = frame.name
decorator = TFDecorator(decorator_name, target, decorator_doc,
decorator_argspec)
setattr(decorator_func, '_tf_decorator', decorator)

View File

@ -0,0 +1,126 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <Python.h>
#include <frameobject.h>
#include <algorithm>
#include <vector>
#include "include/pybind11/pybind11.h"
#include "include/pybind11/stl_bind.h"
struct StackFrame; // Forward declaration.
PYBIND11_MAKE_OPAQUE(std::vector<StackFrame>);
namespace tensorflow {
namespace {
namespace py = pybind11;
// One entry of a captured Python call stack, as produced by ExtractStack()
// and exposed to Python as `_tf_stack.StackFrame`.
//
// Keep the member order as is: ExtractStack() builds frames with a braced
// initializer list in exactly this order, and the Python-visible constructor
// takes its arguments in the same order.
//
// NOTE(review): PYBIND11_MAKE_OPAQUE earlier in this file is applied to a
// `StackFrame` forward-declared at global scope, whereas this definition
// lives in tensorflow::(anonymous namespace). These look like two distinct
// types -- confirm the opaque declaration targets the intended one.
struct StackFrame {
// File name of the frame's code object, or its source-map replacement.
py::str filename;
// Line being executed in the frame, or its source-map replacement.
int lineno;
// Name of the frame's code object, or its source-map replacement.
py::str name;
// The frame's globals (f_globals); convert_stack passes it to linecache.
py::object globals;
// First line number of the frame's code object (co_firstlineno).
int func_start_lineno;
};
// Walks the Python call stack of the current thread and returns it as a
// vector of StackFrame ordered from the outermost frame to the innermost one.
//
// The frame that is current on entry (the Python-level extract_stack()
// wrapper) is skipped; traversal starts at its caller.
//
// Args:
//   limit: maximum number of frames to return; a negative value means
//     "no limit".
//   mappers: list of source mappers. Only the last element is consulted: its
//     get_effective_source_map() result is used as a dict keyed by
//     (filename, lineno) whose values are indexed as (filename, lineno, name).
//   filters: list of stack-trace filters. Only the last element is consulted:
//     frames whose (possibly remapped) filename is in its
//     get_filtered_filenames() result are dropped, except that the innermost
//     frame is never dropped.
//
// Throws:
//   py::error_already_set if the filter-set membership test raises in Python.
std::vector<StackFrame> ExtractStack(ssize_t limit, const py::list& mappers,
                                     const py::list& filters) {
  const py::dict& source_map =
      mappers.size() == 0
          ? py::dict()
          : mappers[mappers.size() - 1].attr("get_effective_source_map")();
  const py::set& filtered_filenames =
      filters.size() == 0
          ? py::set()
          : filters[filters.size() - 1].attr("get_filtered_filenames")();

  const auto* tstate = PyThreadState_GET();
  // Drop extract_stack() wrapper-function frame from the result.
  const PyFrameObject* f = tstate->frame->f_back;  // TODO(slebedev): INCREF?

  // Convert the limit once so that the loop condition compares like-signed
  // values.
  const bool unlimited = limit < 0;
  const size_t max_frames = unlimited ? 0 : static_cast<size_t>(limit);

  std::vector<StackFrame> ret;
  // `limit` is caller-controlled, so it is only used as a capacity hint up to
  // a small bound. Reserving it verbatim would attempt a huge allocation for
  // e.g. limit=sys.maxsize and fail before a single frame is visited.
  constexpr size_t kMaxReservedFrames = 64;
  // 16 is somewhat arbitrary, but TensorFlow stack traces tend to be deep.
  ret.reserve(unlimited ? 16 : std::min(max_frames, kMaxReservedFrames));

  for (; f != nullptr && (unlimited || ret.size() < max_frames);
       f = f->f_back) {
    PyCodeObject* co = f->f_code;
    int lineno = PyFrame_GetLineNumber(const_cast<PyFrameObject*>(f));
    auto filename = py::reinterpret_borrow<py::str>(co->co_filename);
    auto name = py::reinterpret_borrow<py::str>(co->co_name);

    // TODO(slebedev): consider moving the mappers/filters to C++ as well.
    if (source_map.size() > 0) {
      const auto& key = py::make_tuple(filename, lineno);
      if (source_map.contains(key)) {
        const py::tuple& mapped = source_map[key];
        filename = mapped[0];
        lineno = py::cast<py::int_>(mapped[1]);
        name = mapped[2];
      }
    }

    // Never filter the innermost frame.
    // TODO(slebedev): upstream py::set::contains to pybind11.
    if (!ret.empty()) {
      const int filtered =
          PySet_Contains(filtered_filenames.ptr(), filename.ptr());
      // PySet_Contains returns -1 on error. Surface the Python exception
      // instead of silently treating the failure as "filtered".
      if (filtered < 0) throw py::error_already_set();
      if (filtered > 0) continue;
    }

    const auto& globals = py::reinterpret_borrow<py::object>(f->f_globals);
    const int func_start_lineno = co->co_firstlineno;
    ret.push_back({std::move(filename), lineno, std::move(name), globals,
                   func_start_lineno});
  }

  // Frames were collected innermost-first; callers expect outermost-first.
  std::reverse(ret.begin(), ret.end());
  return ret;
}
} // namespace
// Defines the `_tf_stack` extension module: the read-only `StackFrame`
// record, the `Stack` vector type, and the `extract_stack` entry point.
PYBIND11_MODULE(_tf_stack, m) {
  // TODO(slebedev): consider dropping convert_stack in favor of
  // a lazily initialized StackFrame.code property (using linecache).
  py::class_<StackFrame> frame_class(m, "StackFrame");
  frame_class.def(
      py::init<const py::str&, int, const py::str&, const py::object&, int>());

  // Every field is exposed read-only.
  frame_class.def_readonly("filename", &StackFrame::filename);
  frame_class.def_readonly("lineno", &StackFrame::lineno);
  frame_class.def_readonly("name", &StackFrame::name);
  frame_class.def_readonly("globals", &StackFrame::globals);
  frame_class.def_readonly("func_start_lineno",
                           &StackFrame::func_start_lineno);

  frame_class.def("__repr__", [](const StackFrame& self) {
    const py::str pattern(
        "StackFrame(filename={}, lineno={}, name={}, globals={}, "
        "func_start_lineno={})");
    return pattern.format(self.filename, self.lineno, self.name, self.globals,
                          self.func_start_lineno);
  });

  py::bind_vector<std::vector<StackFrame>>(m, "Stack", py::module_local(true));

  m.def("extract_stack", [](const py::object& limit, const py::list& mappers,
                            const py::list& filters) {
    // In Python 3.X ``traceback.extract_stack`` allows ``limit`` to
    // either be None or -1.
    const ssize_t frame_limit =
        limit.is_none() ? -1 : py::cast<ssize_t>(limit);
    return ExtractStack(frame_limit, mappers, filters);
  });
}
} // namespace tensorflow

View File

@ -21,11 +21,13 @@ from __future__ import print_function
import collections
import inspect
import linecache
import sys
import threading
import six
# TODO(b/138203821): change to from ...util import ... once the bug is fixed.
from tensorflow.python import _tf_stack
# Generally such lookups should be done using `threading.local()`. See
# https://blogs.gnome.org/jamesh/2008/06/11/tls-python/ for a detailed
# explanation of why. However the transform stacks are expected to be empty
@ -134,11 +136,7 @@ class CurrentModuleFilter(StackTraceFilter):
return self._filtered_filenames
EMPTY_FROZEN_MAP = {}
EMPTY_FROZEN_SET = frozenset()
def extract_stack(limit=None):
def extract_stack(limit=-1):
"""A lightweight, extensible re-implementation of traceback.extract_stack.
NOTE(mrry): traceback.extract_stack eagerly retrieves the line of code for
@ -151,88 +149,21 @@ def extract_stack(limit=None):
limit: A limit on the number of frames to return.
Returns:
A list of 5-tuples
(filename, lineno, name, frame_globals, func_start_lineno)
corresponding to the call stack of the current thread. The returned tuples
have the innermost stack frame at the end, unlike the Python inspect
module's stack() function.
A sequence of StackFrame objects
(filename, lineno, name, globals, func_start_lineno)
corresponding to the call stack of the current thread. The returned
tuples have the innermost stack frame at the end, unlike the Python
inspect module's stack() function.
"""
try:
raise ZeroDivisionError
except ZeroDivisionError:
f = sys.exc_info()[2].tb_frame.f_back
ret = []
length = 0
# N.B ExtractStack in tf_stack.cc will drop this frame prior to
# traversing the stack.
thread_key = _get_thread_key()
source_mappers = _source_mapper_stacks[thread_key]
# TODO(mdan): Use sentinels instead.
if source_mappers:
source_map = source_mappers[-1].get_effective_source_map()
else:
source_map = EMPTY_FROZEN_MAP
return _tf_stack.extract_stack(
limit,
_source_mapper_stacks[thread_key],
_source_filter_stacks[thread_key])
source_filters = _source_filter_stacks[thread_key]
if source_filters:
filtered_filenames = source_filters[-1].get_filtered_filenames()
else:
filtered_filenames = EMPTY_FROZEN_SET
while f is not None and (limit is None or length < limit):
lineno = f.f_lineno
co = f.f_code
filename = co.co_filename
name = co.co_name
frame_globals = f.f_globals
func_start_lineno = co.co_firstlineno
# TODO(mdan): Show some indication that the frame was translated.
filename, lineno, name = source_map.get(
(filename, lineno), (filename, lineno, name))
# Note: we never filter the innermost frame.
if not (ret and filename in filtered_filenames):
ret.append((filename, lineno, name, frame_globals, func_start_lineno))
length += 1
f = f.f_back
ret.reverse()
return ret
FileAndLine = collections.namedtuple('FileAndLine', ['file', 'line'])
def extract_stack_file_and_line(max_length=1000):
"""A version of extract_stack that only returns filenames and line numbers.
Callers often only require filenames and line numbers, and do not need the
additional information gathered by extract_stack, as they never call
convert_stack.
As a further optimisation, we allow users to specify a limit on the number of
frames examined.
Args:
max_length: The maximum length of stack to extract.
Returns:
A list of FileAndLine objects corresponding to the call stack of the current
thread.
"""
try:
raise ZeroDivisionError
except ZeroDivisionError:
frame = sys.exc_info()[2].tb_frame.f_back
ret = []
length = 0
while frame is not None and length < max_length:
ret.append(FileAndLine(frame.f_code.co_filename, frame.f_lineno))
length += 1
frame = frame.f_back
ret.reverse()
return ret
StackFrame = _tf_stack.StackFrame
def convert_stack(stack, include_func_start_lineno=False):
@ -251,16 +182,18 @@ def convert_stack(stack, include_func_start_lineno=False):
input tuple.
"""
def _tuple_generator(): # pylint: disable=missing-docstring
for (filename, lineno, name, frame_globals, func_start_lineno) in stack:
for frame in stack:
filename = frame.filename
lineno = frame.lineno
linecache.checkcache(filename)
line = linecache.getline(filename, lineno, frame_globals)
line = linecache.getline(filename, lineno, frame.globals)
if line:
line = line.strip()
else:
line = None
if include_func_start_lineno:
yield (filename, lineno, name, line, func_start_lineno)
yield (filename, lineno, frame.name, line, frame.func_start_lineno)
else:
yield (filename, lineno, name, line)
yield (filename, lineno, frame.name, line)
return tuple(_tuple_generator())

View File

@ -0,0 +1,55 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for functions used to extract and analyze stacks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import traceback
from tensorflow.python.platform import test
from tensorflow.python.util import tf_stack
class TFStackTest(test.TestCase):
  """Checks `tf_stack.extract_stack` limits and agreement with `traceback`."""

  def testLimit(self):
    """`limit` caps the frame count; -1 yields as many frames as the default."""
    no_frames = tf_stack.extract_stack(limit=0)
    one_frame = tf_stack.extract_stack(limit=1)
    self.assertEmpty(no_frames)
    self.assertLen(one_frame, 1)

    explicit_unlimited = tf_stack.extract_stack(limit=-1)
    default_unlimited = tf_stack.extract_stack()
    self.assertEqual(len(explicit_unlimited), len(default_unlimited))

  def testConsistencyWithTraceback(self):
    """Converted frames compare equal to `traceback` frames, pairwise."""
    actual_frames, reference_frames = extract_stack()
    for actual, reference in zip(actual_frames, reference_frames):
      self.assertEqual(actual, reference)

  def testFormatStack(self):
    """Both stacks render identically through `traceback.format_list`."""
    actual_frames, reference_frames = extract_stack()
    formatted_actual = traceback.format_list(actual_frames)
    formatted_reference = traceback.format_list(reference_frames)
    self.assertEqual(formatted_actual, formatted_reference)
def extract_stack(limit=None):
  """Capture the current stack with both `tf_stack` and `traceback`.

  Args:
    limit: Forwarded unchanged to `tf_stack.extract_stack` and to
      `traceback.extract_stack`.

  Returns:
    A pair `(converted, reference)`: the `tf_stack.extract_stack` result
    passed through `tf_stack.convert_stack`, and the
    `traceback.extract_stack` result.
  """
  tf_extract, py_extract = tf_stack.extract_stack, traceback.extract_stack
  # Both extractions must stay on the same source line so that the two
  # captured stacks are identical.
  return tf_stack.convert_stack(tf_extract(limit)), py_extract(limit)
# Run the test cases when this file is executed directly as a script.
if __name__ == "__main__":
test.main()