Use generators in traceable_stack and add new peek_top_obj method
peek_objs and peek_traceable_objs are only used in ops.py. In most cases, they are iterated over, so there is no benefit to returning a list. In the other cases, just the first obj is required, so rather than returning the entire list, I add a peek_top_obj method. Profile before: ncalls tottime percall cumtime percall filename:lineno(function) 172162 0.171 0.000 0.171 0.000 traceable_stack.py:113(peek_objs) 197806 0.317 0.000 0.317 0.000 traceable_stack.py:117(peek_traceable_objs) Profile after: ncalls tottime percall cumtime percall filename:lineno(function) 143292 0.170 0.000 0.170 0.000 traceable_stack.py:117(peek_objs) 197806 0.080 0.000 0.080 0.000 traceable_stack.py:121(peek_traceable_objs) 28870 0.008 0.000 0.008 0.000 traceable_stack.py:113(peek_top_obj) PiperOrigin-RevId: 240965521
This commit is contained in:
parent
c1b7d2be4b
commit
387bddab6a
@ -4468,11 +4468,11 @@ class Graph(object):
|
||||
RuntimeError: If device scopes are not properly nested.
|
||||
"""
|
||||
self._add_device_to_stack(device_name_or_function, offset=2)
|
||||
old_top_of_stack = self._device_function_stack.peek_objs()[0]
|
||||
old_top_of_stack = self._device_function_stack.peek_top_obj()
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
new_top_of_stack = self._device_function_stack.peek_objs()[0]
|
||||
new_top_of_stack = self._device_function_stack.peek_top_obj()
|
||||
if old_top_of_stack is not new_top_of_stack:
|
||||
raise RuntimeError("Exiting device scope without proper scope nesting.")
|
||||
self._device_function_stack.pop_obj()
|
||||
@ -5042,9 +5042,8 @@ class Graph(object):
|
||||
the filename and lineno members point to the code location where
|
||||
Graph.device was called directly or indirectly by the user.
|
||||
"""
|
||||
traceable_objects = self._device_function_stack.peek_traceable_objs()
|
||||
snapshot = []
|
||||
for obj in traceable_objects:
|
||||
for obj in self._device_function_stack.peek_traceable_objs():
|
||||
obj_copy = obj.copy_metadata()
|
||||
obj_copy.obj = obj.obj.display_name
|
||||
snapshot.append(obj_copy)
|
||||
@ -5076,8 +5075,10 @@ class Graph(object):
|
||||
|
||||
def _snapshot_colocation_stack_metadata(self):
|
||||
"""Return colocation stack metadata as a dictionary."""
|
||||
traceable_objects = self._colocation_stack.peek_traceable_objs()
|
||||
return {obj.obj.name: obj.copy_metadata() for obj in traceable_objects}
|
||||
return {
|
||||
traceable_obj.obj.name: traceable_obj.copy_metadata()
|
||||
for traceable_obj in self._colocation_stack.peek_traceable_objs()
|
||||
}
|
||||
|
||||
@_colocation_stack.setter
|
||||
def _colocation_stack(self, colocation_stack):
|
||||
|
@ -110,13 +110,17 @@ class TraceableStack(object):
|
||||
"""Remove last-inserted object and return it, without filename/line info."""
|
||||
return self._stack.pop().obj
|
||||
|
||||
def peek_top_obj(self):
|
||||
"""Return the most recent stored object."""
|
||||
return self._stack[-1].obj
|
||||
|
||||
def peek_objs(self):
|
||||
"""Return list of stored objects ordered newest to oldest."""
|
||||
return [t_obj.obj for t_obj in reversed(self._stack)]
|
||||
"""Return iterator over stored objects ordered newest to oldest."""
|
||||
return (t_obj.obj for t_obj in reversed(self._stack))
|
||||
|
||||
def peek_traceable_objs(self):
|
||||
"""Return list of stored TraceableObjects ordered newest to oldest."""
|
||||
return list(reversed(self._stack))
|
||||
"""Return iterator over stored TraceableObjects ordered newest to oldest."""
|
||||
return reversed(self._stack)
|
||||
|
||||
def __len__(self):
|
||||
"""Return number of items on the stack, and used for truth-value testing."""
|
||||
|
@ -82,11 +82,17 @@ class TraceableStackTest(test_util.TensorFlowTestCase):
|
||||
t_stack.push_obj('hope')
|
||||
|
||||
expected_lifo_peek = ['hope', 42.0]
|
||||
self.assertEqual(expected_lifo_peek, t_stack.peek_objs())
|
||||
self.assertEqual(expected_lifo_peek, list(t_stack.peek_objs()))
|
||||
|
||||
self.assertEqual('hope', t_stack.pop_obj())
|
||||
self.assertEqual(42.0, t_stack.pop_obj())
|
||||
|
||||
def testPushPeekTopObj(self):
|
||||
t_stack = traceable_stack.TraceableStack()
|
||||
t_stack.push_obj(42.0)
|
||||
t_stack.push_obj('hope')
|
||||
self.assertEqual('hope', t_stack.peek_top_obj())
|
||||
|
||||
def testPushPopPreserveLifoOrdering(self):
|
||||
t_stack = traceable_stack.TraceableStack()
|
||||
t_stack.push_obj(0)
|
||||
|
Loading…
Reference in New Issue
Block a user