Use generators in traceable_stack and add new peek_top_obj method

peek_objs and peek_traceable_objs are only used in ops.py. In most cases, they are iterated over, so there is no benefit to returning a list. In the other cases, just the first obj is required, so rather than returning the entire list, I add a peek_top_obj method.

Profile before:
ncalls  tottime  percall  cumtime  percall filename:lineno(function)
172162    0.171    0.000    0.171    0.000 traceable_stack.py:113(peek_objs)
197806    0.317    0.000    0.317    0.000 traceable_stack.py:117(peek_traceable_objs)

Profile after:
ncalls  tottime  percall  cumtime  percall filename:lineno(function)
143292    0.170    0.000    0.170    0.000 traceable_stack.py:117(peek_objs)
197806    0.080    0.000    0.080    0.000 traceable_stack.py:121(peek_traceable_objs)
 28870    0.008    0.000    0.008    0.000 traceable_stack.py:113(peek_top_obj)

PiperOrigin-RevId: 240965521
This commit is contained in:
James Keeling 2019-03-29 06:04:40 -07:00 committed by TensorFlower Gardener
parent c1b7d2be4b
commit 387bddab6a
3 changed files with 22 additions and 11 deletions

View File

@ -4468,11 +4468,11 @@ class Graph(object):
RuntimeError: If device scopes are not properly nested.
"""
self._add_device_to_stack(device_name_or_function, offset=2)
old_top_of_stack = self._device_function_stack.peek_objs()[0]
old_top_of_stack = self._device_function_stack.peek_top_obj()
try:
yield
finally:
new_top_of_stack = self._device_function_stack.peek_objs()[0]
new_top_of_stack = self._device_function_stack.peek_top_obj()
if old_top_of_stack is not new_top_of_stack:
raise RuntimeError("Exiting device scope without proper scope nesting.")
self._device_function_stack.pop_obj()
@ -5042,9 +5042,8 @@ class Graph(object):
the filename and lineno members point to the code location where
Graph.device was called directly or indirectly by the user.
"""
traceable_objects = self._device_function_stack.peek_traceable_objs()
snapshot = []
for obj in traceable_objects:
for obj in self._device_function_stack.peek_traceable_objs():
obj_copy = obj.copy_metadata()
obj_copy.obj = obj.obj.display_name
snapshot.append(obj_copy)
@ -5076,8 +5075,10 @@ class Graph(object):
def _snapshot_colocation_stack_metadata(self):
"""Return colocation stack metadata as a dictionary."""
traceable_objects = self._colocation_stack.peek_traceable_objs()
return {obj.obj.name: obj.copy_metadata() for obj in traceable_objects}
return {
traceable_obj.obj.name: traceable_obj.copy_metadata()
for traceable_obj in self._colocation_stack.peek_traceable_objs()
}
@_colocation_stack.setter
def _colocation_stack(self, colocation_stack):

View File

@ -110,13 +110,17 @@ class TraceableStack(object):
"""Remove last-inserted object and return it, without filename/line info."""
return self._stack.pop().obj
def peek_top_obj(self):
"""Return the most recent stored object."""
return self._stack[-1].obj
def peek_objs(self):
"""Return list of stored objects ordered newest to oldest."""
return [t_obj.obj for t_obj in reversed(self._stack)]
"""Return iterator over stored objects ordered newest to oldest."""
return (t_obj.obj for t_obj in reversed(self._stack))
def peek_traceable_objs(self):
"""Return list of stored TraceableObjects ordered newest to oldest."""
return list(reversed(self._stack))
"""Return iterator over stored TraceableObjects ordered newest to oldest."""
return reversed(self._stack)
def __len__(self):
"""Return number of items on the stack, and used for truth-value testing."""

View File

@ -82,11 +82,17 @@ class TraceableStackTest(test_util.TensorFlowTestCase):
t_stack.push_obj('hope')
expected_lifo_peek = ['hope', 42.0]
self.assertEqual(expected_lifo_peek, t_stack.peek_objs())
self.assertEqual(expected_lifo_peek, list(t_stack.peek_objs()))
self.assertEqual('hope', t_stack.pop_obj())
self.assertEqual(42.0, t_stack.pop_obj())
def testPushPeekTopObj(self):
t_stack = traceable_stack.TraceableStack()
t_stack.push_obj(42.0)
t_stack.push_obj('hope')
self.assertEqual('hope', t_stack.peek_top_obj())
def testPushPopPreserveLifoOrdering(self):
t_stack = traceable_stack.TraceableStack()
t_stack.push_obj(0)