Move stack more of the trace mappers and filters to C++. Simplify the API surface of extract_stack, and clean up its docstring.

PiperOrigin-RevId: 348459279
Change-Id: I1ae8afde96a7220df5fb6a95ace62f7c940a1e71
This commit is contained in:
Dan Moldovan 2020-12-21 06:38:42 -08:00 committed by TensorFlower Gardener
parent ad8083afac
commit 9b83b647b9
8 changed files with 209 additions and 176 deletions

View File

@ -1591,6 +1591,7 @@ py_library(
deps = [ deps = [
":platform", ":platform",
"//tensorflow/python/util", "//tensorflow/python/util",
# TODO(mdan): Remove this once the transitive dependency is fixed.
"//tensorflow/python/util:tf_stack", "//tensorflow/python/util:tf_stack",
], ],
) )

View File

@ -167,31 +167,35 @@ class StackTraceMapper(tf_stack.StackTraceMapper):
"""Remaps generated code to code it originated from.""" """Remaps generated code to code it originated from."""
def __init__(self, converted_fn): def __init__(self, converted_fn):
super().__init__()
self._source_map = converted_fn.ag_source_map self._source_map = converted_fn.ag_source_map
# This may be called repeatedly: once on entry, by the superclass, then by
# each child context manager.
self._cached_map = None
def get_effective_source_map(self): def get_effective_source_map(self):
effective_source_map = self._effective_source_map if self._cached_map is not None:
if effective_source_map is None: return self._cached_map
if self.parent is not None:
parent_map = self.parent.get_effective_source_map() parent_map = self.parent.get_effective_source_map()
effective_source_map = {}
for loc, origin in self._source_map.items():
effective_source_map[(loc.filename, loc.lineno)] = (origin.loc.filename,
origin.loc.lineno,
origin.function_name)
for key, value in parent_map.items():
filename, lineno, _ = value
value_loc = origin_info.LineLocation(filename=filename, lineno=lineno)
if value_loc in self._source_map:
origin = self._source_map[value_loc]
effective_source_map[key] = (origin.loc.filename, origin.loc.lineno,
origin.function_name)
else: else:
parent_map = {} effective_source_map[key] = value
effective_source_map = {} self._cached_map = effective_source_map
for loc, origin in self._source_map.items():
effective_source_map[(loc.filename, loc.lineno)] = (
origin.loc.filename, origin.loc.lineno, origin.function_name)
for key, value in parent_map.items():
filename, lineno, _ = value
value_loc = origin_info.LineLocation(filename=filename, lineno=lineno)
if value_loc in self._source_map:
origin = self._source_map[value_loc]
effective_source_map[key] = (
origin.loc.filename, origin.loc.lineno, origin.function_name)
else:
effective_source_map[key] = value
self._effective_source_map = effective_source_map
return effective_source_map return effective_source_map

View File

@ -23,9 +23,10 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import traceback
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat from tensorflow.python.util import compat
from tensorflow.python.util import tf_stack
# Registry mechanism below is based on mapreduce.python.mrpython.Register. # Registry mechanism below is based on mapreduce.python.mrpython.Register.
@ -65,8 +66,8 @@ class Registry(object):
logging.vlog(1, "Registering %s (%s) in %s.", name, candidate, self._name) logging.vlog(1, "Registering %s (%s) in %s.", name, candidate, self._name)
# stack trace is [this_function, Register(), user_function,...] # stack trace is [this_function, Register(), user_function,...]
# so the user function is #2. # so the user function is #2.
stack = tf_stack.extract_stack(limit=3) stack = traceback.extract_stack(limit=3)
stack_index = min(2, len(stack)-1) stack_index = min(2, len(stack) - 1)
if stack_index >= 0: if stack_index >= 0:
location_tag = stack[stack_index] location_tag = stack[stack_index]
else: else:

View File

@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from tensorflow.python.util import tf_stack import inspect
class TraceableObject(object): class TraceableObject(object):
@ -51,26 +51,20 @@ class TraceableObject(object):
TraceableObject.HEURISTIC_USED if the offset was larger than the stack, TraceableObject.HEURISTIC_USED if the offset was larger than the stack,
and TraceableObject.FAILURE if the stack was empty. and TraceableObject.FAILURE if the stack was empty.
""" """
# Offset is defined in "Args" as relative to the caller. We are one frame retcode = self.SUCCESS
frame = inspect.currentframe()
# Offset is defined in "Args" as relative to the caller. We are one frame
# beyond the caller. # beyond the caller.
local_offset = offset + 1 for _ in range(offset + 1):
parent = frame.f_back
frame_records = tf_stack.extract_stack( if parent is None:
limit=local_offset + 1) # If the offset is too large then we use the largest offset possible.
if not frame_records: retcode = self.HEURISTIC_USED
return self.FAILURE break
if len(frame_records) > local_offset: frame = parent
frame = frame_records[len(frame_records) - (local_offset + 1)] self.filename = frame.f_code.co_filename
self.filename = frame.filename self.lineno = frame.f_lineno
self.lineno = frame.lineno return retcode
return self.SUCCESS
else:
# If the offset is too large then we use the largest offset possible,
# meaning we use the outermost stack frame at index 0.
frame = frame_records[0]
self.filename = frame.filename
self.lineno = frame.lineno
return self.HEURISTIC_USED
def copy_metadata(self): def copy_metadata(self):
"""Return a TraceableObject like this one, but without the object.""" """Return a TraceableObject like this one, but without the object."""

View File

@ -19,12 +19,12 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import importlib import importlib
import inspect
import types import types
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect from tensorflow.python.util import tf_inspect
from tensorflow.python.util import tf_stack
from tensorflow.tools.compatibility import all_renames_v2 from tensorflow.tools.compatibility import all_renames_v2
@ -41,11 +41,12 @@ def _call_location():
# We want to get stack frame 3 frames up from current frame, # We want to get stack frame 3 frames up from current frame,
# i.e. above __getattr__, _tfmw_add_deprecation_warning, # i.e. above __getattr__, _tfmw_add_deprecation_warning,
# and _call_location calls. # and _call_location calls.
stack = tf_stack.extract_stack(limit=4) frame = inspect.currentframe()
if not stack: # should never happen as we're in a function for _ in range(4):
return 'UNKNOWN' parent = frame.f_back
frame = stack[0] if parent is None:
return '{}:{}'.format(frame.filename, frame.lineno) break
return '{}:{}'.format(frame.f_code.co_filename, frame.f_lineno)
def contains_deprecation_decorator(decorators): def contains_deprecation_decorator(decorators):

View File

@ -54,6 +54,30 @@ namespace {
namespace py = pybind11; namespace py = pybind11;
using SourceLoc = std::tuple<std::string, int>;
using SourceMap = absl::flat_hash_map<SourceLoc, StackFrame>;
using StringSet = absl::flat_hash_set<std::string>;
// Python wrapper for a SourceMap.
class PyBindSourceMap {
public:
PyBindSourceMap() : source_map_(std::make_shared<SourceMap>()) {}
// Shares ownership with whoever captures traces in the scope of this map.
std::shared_ptr<SourceMap> source_map_;
};
// Python wrapper for a FileSet.
class PyBindFileSet {
public:
PyBindFileSet() : file_set_(std::make_shared<StringSet>()) {}
// Shares ownership with whoever captures traces in the scope of this set.
std::shared_ptr<StringSet> file_set_;
};
// Returns contents of the line corresponding to the given frame. // Returns contents of the line corresponding to the given frame.
// //
// Precondition: must be holding Python GIL. // Precondition: must be holding Python GIL.
@ -98,36 +122,21 @@ std::string StackFrameToString(
class StackTraceWrapper : public AbstractStackTrace { class StackTraceWrapper : public AbstractStackTrace {
public: public:
StackTraceWrapper(StackTrace&& captured, const py::dict& source_map, StackTraceWrapper(StackTrace&& captured,
const py::set& filtered_filenames) const std::shared_ptr<SourceMap>& source_map,
const std::shared_ptr<StringSet>& filter)
: captured_(std::move(captured)), : captured_(std::move(captured)),
source_map_(source_map), source_map_(source_map),
filtered_filenames_(filtered_filenames) {} filter_(filter) {}
explicit StackTraceWrapper(absl::Span<StackFrame const> stack_frames) explicit StackTraceWrapper(absl::Span<StackFrame const> stack_frames)
: stack_frames_cache_(std::vector<StackFrame>(stack_frames.begin(), : stack_frames_cache_(std::vector<StackFrame>(stack_frames.begin(),
stack_frames.end())) {} stack_frames.end())) {}
static StackTraceWrapper ExtractStack(const py::object& limit, static StackTraceWrapper ExtractStack(
const py::list& mappers, const std::shared_ptr<SourceMap>& source_map,
const py::list& filters) { const std::shared_ptr<StringSet>& filter) {
// In Python 3.X ``traceback.extract_stack`` allows ``limit`` to return StackTraceWrapper{StackTrace::Capture(-1), source_map, filter};
// either be None or -1.
int casted_limit = limit.is_none() ? -1 : py::cast<ssize_t>(limit);
// Raise limit by one since we are dropping the last frame.
if (casted_limit != -1) casted_limit++;
const py::dict& source_map =
mappers.empty()
? py::dict()
: mappers[mappers.size() - 1].attr("get_effective_source_map")();
const py::set& filtered_filenames =
filters.empty()
? py::set()
: filters[filters.size() - 1].attr("get_filtered_filenames")();
return StackTraceWrapper{StackTrace::Capture(casted_limit), source_map,
filtered_filenames};
} }
absl::Span<StackFrame const> ToFrames() const override { absl::Span<StackFrame const> ToFrames() const override {
@ -169,6 +178,7 @@ class StackTraceWrapper : public AbstractStackTrace {
bool IsCacheGenerated() const { return stack_frames_cache_.has_value(); } bool IsCacheGenerated() const { return stack_frames_cache_.has_value(); }
void GenerateCache() const { void GenerateCache() const {
// TODO(mdan): We don't really need random access; this can be removed.
if (stack_frames_cache_) { if (stack_frames_cache_) {
return; return;
} }
@ -214,7 +224,7 @@ class StackTraceWrapper : public AbstractStackTrace {
PyGILState_STATE state = PyGILState_Ensure(); PyGILState_STATE state = PyGILState_Ensure();
captured_.Clear(); captured_.Clear();
source_map_.reset(); source_map_.reset();
filtered_filenames_.reset(); filter_.reset();
PyGILState_Release(state); PyGILState_Release(state);
} }
@ -237,33 +247,23 @@ class StackTraceWrapper : public AbstractStackTrace {
!absl::StrContains(file_name, "test.py"); !absl::StrContains(file_name, "test.py");
} }
absl::optional<StackFrame> StackTraceMapping( absl::optional<StackFrame> StackTraceMapping(SourceLoc loc) const {
std::pair<const char*, int> p) const { if (source_map_->contains(loc)) {
if (source_map_->empty()) { return source_map_->at(loc);
return absl::nullopt;
}
auto key = py::make_tuple(py::str(p.first), py::int_(p.second));
if (source_map_->contains(key)) {
const py::tuple& value = (*source_map_)[key];
return StackFrame{std::string(py::cast<py::str>(value[0])),
py::cast<py::int_>(value[1]),
std::string(py::cast<py::str>(value[2]))};
} }
return absl::nullopt; return absl::nullopt;
} }
bool StackTraceFiltering(const char* file_name) const { bool StackTraceFiltering(const char* file_name) const {
return filtered_filenames_->contains(file_name); return filter_->contains(file_name);
} }
StackTrace captured_; StackTrace captured_;
// Using optional to force destruction while we hold a GIL. std::shared_ptr<SourceMap> source_map_;
absl::optional<py::dict> source_map_; std::shared_ptr<StringSet> filter_;
absl::optional<py::set> filtered_filenames_;
// Using optional to force destruction while we hold a GIL.
mutable absl::optional<std::vector<StackFrame>> stack_frames_cache_; mutable absl::optional<std::vector<StackFrame>> stack_frames_cache_;
mutable absl::optional<StackFrame> last_stack_frame_cache_; mutable absl::optional<StackFrame> last_stack_frame_cache_;
}; };
@ -271,6 +271,42 @@ class StackTraceWrapper : public AbstractStackTrace {
} // namespace } // namespace
PYBIND11_MODULE(_tf_stack, m) { PYBIND11_MODULE(_tf_stack, m) {
py::class_<PyBindSourceMap>(m, "PyBindSourceMap")
.def(py::init())
.def("update_to",
[](const PyBindSourceMap& self, const py::tuple& source_map) {
self.source_map_->clear();
for (const auto& item : source_map) {
const auto& tuple_item = py::cast<py::tuple>(item);
const auto& key = py::cast<py::tuple>(tuple_item[0]);
std::string&& k_filename = py::cast<std::string>(key[0]);
int k_lineno = py::cast<int>(key[1]);
const auto& value = py::cast<py::tuple>(tuple_item[1]);
std::string&& v_filename = py::cast<std::string>(value[0]);
int v_lineno = py::cast<int>(value[1]);
const auto& function_name_val = value[2];
std::string&& v_function_name =
function_name_val.is_none()
? ""
: py::cast<std::string>(function_name_val);
self.source_map_->emplace(
SourceLoc(k_filename, k_lineno),
StackFrame({v_filename, v_lineno, v_function_name}));
}
});
py::class_<PyBindFileSet>(m, "PyBindFileSet")
.def(py::init())
.def("update_to", [](const PyBindFileSet& self, const py::set& file_set) {
self.file_set_->clear();
for (const auto& item : file_set) {
self.file_set_->insert(py::cast<std::string>(item));
}
});
py::class_<StackFrame>(m, "StackFrame") py::class_<StackFrame>(m, "StackFrame")
.def_property_readonly( .def_property_readonly(
"filename", "filename",
@ -369,22 +405,22 @@ PYBIND11_MODULE(_tf_stack, m) {
m.def( m.def(
"extract_stack_for_node", "extract_stack_for_node",
[](const py::object& limit, const py::list& mappers, [](const PyBindSourceMap& source_map, const PyBindFileSet& file_set,
const py::list& filters,
TF_Operation* op) -> const AbstractStackTrace& { TF_Operation* op) -> const AbstractStackTrace& {
Node* node = reinterpret_cast<Node*>(op); Node* node = reinterpret_cast<Node*>(op);
DCHECK(!node->GetStackTrace()) << "Should not reset the stack trace"; DCHECK(!node->GetStackTrace()) << "Should not reset the stack trace";
node->SetStackTrace(std::make_shared<StackTraceWrapper>( node->SetStackTrace(
StackTraceWrapper::ExtractStack(limit, mappers, filters))); std::make_shared<StackTraceWrapper>(StackTraceWrapper::ExtractStack(
source_map.source_map_, file_set.file_set_)));
return *node->GetStackTrace(); return *node->GetStackTrace();
}, },
py::return_value_policy::reference); py::return_value_policy::reference);
m.def( m.def(
"extract_stack", "extract_stack",
[](const py::object& limit, const py::list& mappers, [](const PyBindSourceMap& source_map, const PyBindFileSet& file_set) {
const py::list& filters) { return StackTraceWrapper::ExtractStack(source_map.source_map_,
return StackTraceWrapper::ExtractStack(limit, mappers, filters); file_set.file_set_);
}, },
py::return_value_policy::move); py::return_value_policy::move);
} }

View File

@ -40,8 +40,10 @@ else:
_get_thread_key = threading.get_ident _get_thread_key = threading.get_ident
_source_mapper_stacks = collections.defaultdict(list) # TODO(mdan): Move these to C++ as well.
_source_filter_stacks = collections.defaultdict(list) # Moving to C++ can further avoid extra copies made by get_effective_map.
_source_mapper_stacks = collections.defaultdict(lambda: [SentinelMapper()])
_source_filter_stacks = collections.defaultdict(lambda: [SentinelFilter()])
class StackTraceTransform(object): class StackTraceTransform(object):
@ -51,8 +53,6 @@ class StackTraceTransform(object):
_thread_key = None _thread_key = None
def __enter__(self): def __enter__(self):
self.reset()
# Any given instance is assumed to be used by a single thread, which reduces # Any given instance is assumed to be used by a single thread, which reduces
# expensive thread local lookups. # expensive thread local lookups.
if self._thread_key is None: if self._thread_key is None:
@ -61,48 +61,71 @@ class StackTraceTransform(object):
assert self._thread_key == _get_thread_key(), 'Shared across threads?' assert self._thread_key == _get_thread_key(), 'Shared across threads?'
stack = self._stack_dict[self._thread_key] stack = self._stack_dict[self._thread_key]
if stack: self.parent = stack[-1]
self.parent = stack[-1]
else:
self.parent = None
stack.append(self) stack.append(self)
self.update()
return self return self
def __exit__(self, unused_type, unused_value, unused_traceback): def __exit__(self, unused_type, unused_value, unused_traceback):
top = self._stack_dict[self._thread_key].pop() top = self._stack_dict[self._thread_key].pop()
assert top is self, 'Concurrent access?' assert top is self, 'Concurrent access?'
def reset(self): def update(self):
pass raise NotImplementedError('subclasses need to override this')
class StackTraceMapper(StackTraceTransform): class StackTraceMapper(StackTraceTransform):
"""Allows remapping traceback information to different source code.""" """Allows remapping traceback information to different source code."""
_stack_dict = _source_mapper_stacks _stack_dict = _source_mapper_stacks
def reset(self): def __init__(self):
self._effective_source_map = None self.internal_map = _tf_stack.PyBindSourceMap()
def update(self):
self.internal_map.update_to(tuple(self.get_effective_source_map().items()))
def get_effective_source_map(self): def get_effective_source_map(self):
"""Returns a map (filename, lineno) -> (filename, lineno, function_name).""" """Returns a map (filename, lineno) -> (filename, lineno, function_name)."""
raise NotImplementedError('subclasses need to override this') raise NotImplementedError('subclasses need to override this')
EMPTY_DICT = {}
class SentinelMapper(StackTraceMapper):
def get_effective_source_map(self):
return EMPTY_DICT
class StackTraceFilter(StackTraceTransform): class StackTraceFilter(StackTraceTransform):
"""Allows filtering traceback information by removing superfluous frames.""" """Allows filtering traceback information by removing superfluous frames."""
_stack_dict = _source_filter_stacks _stack_dict = _source_filter_stacks
def reset(self): def __init__(self):
self._filtered_filenames = None self.internal_set = _tf_stack.PyBindFileSet()
def update(self):
self.internal_set.update_to(set(self.get_filtered_filenames()))
def get_filtered_filenames(self): def get_filtered_filenames(self):
raise NotImplementedError('subclasses need to override this') raise NotImplementedError('subclasses need to override this')
EMPTY_SET = frozenset()
class SentinelFilter(StackTraceFilter):
def get_filtered_filenames(self):
return EMPTY_SET
class CurrentModuleFilter(StackTraceFilter): class CurrentModuleFilter(StackTraceFilter):
"""Filters stack frames from the module where this is used (best effort).""" """Filters stack frames from the module where this is used (best effort)."""
def __init__(self): def __init__(self):
super().__init__()
filter_filename = None filter_filename = None
outer_f = None outer_f = None
f = inspect.currentframe() f = inspect.currentframe()
@ -114,6 +137,9 @@ class CurrentModuleFilter(StackTraceFilter):
if outer_f is not None: if outer_f is not None:
filter_filename = inspect.getsourcefile(outer_f) filter_filename = inspect.getsourcefile(outer_f)
self._filename = filter_filename self._filename = filter_filename
# This may be called repeatedly: once on entry by the superclass, then by
# each child context manager.
self._cached_set = None
finally: finally:
# Avoid reference cycles, see: # Avoid reference cycles, see:
# https://docs.python.org/3.7/library/inspect.html#the-interpreter-stack # https://docs.python.org/3.7/library/inspect.html#the-interpreter-stack
@ -121,58 +147,52 @@ class CurrentModuleFilter(StackTraceFilter):
del outer_f del outer_f
def get_filtered_filenames(self): def get_filtered_filenames(self):
if self._filtered_filenames is None: if self._cached_set is not None:
self._filtered_filenames = frozenset((self._filename,)) return self._cached_set
if self.parent is not None:
self._filtered_filenames |= self.parent.get_filtered_filenames() filtered_filenames = frozenset((self._filename,))
return self._filtered_filenames if self.parent is not None:
filtered_filenames |= self.parent.get_filtered_filenames()
self._cached_set = filtered_filenames
return filtered_filenames
def extract_stack(limit=-1): def extract_stack():
"""A lightweight, extensible re-implementation of traceback.extract_stack. """An eager-friendly alternative to traceback.extract_stack.
NOTE(mrry): traceback.extract_stack eagerly retrieves the line of code for
each stack frame using linecache, which results in an abundance of stat()
calls. This implementation does not retrieve the code, and any consumer
should apply _convert_stack to the result to obtain a traceback that can
be formatted etc. using traceback methods.
Args:
limit: A limit on the number of frames to return.
Returns: Returns:
An object wrapping the sequence of StackFrame objects (filename, lineno, A list-like FrameSummary containing StackFrame-like objects, which are
name, line) corresponding to the call stack of the current thread. The namedtuple-like objects with the following fields: filename, lineno, name,
returned object can be indexed as a Python list. line, meant to masquerade as traceback.FrameSummary objects.
""" """
# N.B ExtractStack in tf_stack.cc will drop this frame prior to # N.B ExtractStack in tf_stack.cc will drop this frame prior to
# traversing the stack. # traversing the stack.
# TODO(cheshire): Remove this function, use extract_stack_for_node or Python # TODO(cheshire): Remove this function, use extract_stack_for_node or Python
# traceback module. # traceback module.
thread_key = _get_thread_key() thread_key = _get_thread_key()
return _tf_stack.extract_stack(limit, _source_mapper_stacks[thread_key], return _tf_stack.extract_stack(
_source_filter_stacks[thread_key]) _source_mapper_stacks[thread_key][-1].internal_map,
_source_filter_stacks[thread_key][-1].internal_set)
def extract_stack_for_node(node, limit=-1): # TODO(mdan): Revisit these - a single location is almost always sufficient.
"""Same as extract_stack, but also saves the retrieved stack in `node`. def extract_stack_for_node(node):
"""Attaches the current stack trace to `node`.
Args: Args:
node: Pointer to the Node object. node: a Node object.
limit: A limit on the number of frames to return.
Returns: Returns:
An object wrapping the sequence of StackFrame objects (filename, lineno, A list-like FrameSummary containing StackFrame-like objects, which are
name, line) corresponding to the call stack of the current thread. The namedtuple-like objects with the following fields: filename, lineno, name,
returned object can be indexed as a Python list. line, meant to masquerade as traceback.FrameSummary objects.
""" """
# N.B ExtractStack in tf_stack.cc will drop this frame prior to # N.B ExtractStack in tf_stack.cc will drop this frame prior to
# traversing the stack. # traversing the stack.
thread_key = _get_thread_key() thread_key = _get_thread_key()
return _tf_stack.extract_stack_for_node(limit, return _tf_stack.extract_stack_for_node(
_source_mapper_stacks[thread_key], _source_mapper_stacks[thread_key][-1].internal_map,
_source_filter_stacks[thread_key], _source_filter_stacks[thread_key][-1].internal_set, node)
node)
StackSummary = _tf_stack.StackTraceWrapper StackSummary = _tf_stack.StackTraceWrapper

View File

@ -26,31 +26,19 @@ from tensorflow.python.util import tf_stack
class TFStackTest(test.TestCase): class TFStackTest(test.TestCase):
def testLimit(self): def testFormatStackSelfConsistency(self):
self.assertEmpty(tf_stack.extract_stack(limit=0)) # Both defined on the same line to produce identical stacks.
self.assertLen(tf_stack.extract_stack(limit=1), 1) stacks = tf_stack.extract_stack(), traceback.extract_stack()
self.assertEqual( self.assertEqual(
len(tf_stack.extract_stack(limit=-1)), traceback.format_list(stacks[0]), traceback.format_list(stacks[1]))
len(tf_stack.extract_stack()))
def testConsistencyWithTraceback(self):
stack, expected_stack = extract_stack()
for frame, expected in zip(stack, expected_stack):
self.assertEqual(convert_stack_frame(frame), expected)
def testFormatStack(self):
stack, expected_stack = extract_stack()
self.assertEqual(
traceback.format_list(stack),
traceback.format_list(expected_stack))
def testFrameSummaryEquality(self): def testFrameSummaryEquality(self):
frame0, frame1 = tf_stack.extract_stack(limit=2) frames1 = tf_stack.extract_stack()
self.assertNotEqual(frame0, frame1) frames2 = tf_stack.extract_stack()
self.assertEqual(frame0, frame0)
another_frame0, _ = tf_stack.extract_stack(limit=2) self.assertNotEqual(frames1[0], frames1[1])
self.assertEqual(frame0, another_frame0) self.assertEqual(frames1[0], frames1[0])
self.assertEqual(frames1[0], frames2[0])
def testFrameSummaryEqualityAndHash(self): def testFrameSummaryEqualityAndHash(self):
# Both defined on the same line to produce identical stacks. # Both defined on the same line to produce identical stacks.
@ -74,17 +62,5 @@ def extract_stack(limit=None):
return tf_stack.extract_stack(limit), traceback.extract_stack(limit) return tf_stack.extract_stack(limit), traceback.extract_stack(limit)
def convert_stack_frame(frame):
"""Converts a TF stack frame into Python's."""
# TODO(mihaimaruseac): Remove except case when dropping suport for py2
try:
return traceback.FrameSummary(
frame.filename, frame.lineno, frame.name, line=frame.line)
except AttributeError:
# On Python < 3.5 (i.e., Python2), we don't have traceback.FrameSummary so
# we don't need to match with that class. Instead, just a tuple is enough.
return tuple(frame)
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()