Use generators in traceable_stack and add new peek_top_obj method

peek_objs and peek_traceable_objs are only used in ops.py. In most cases, they are iterated over, so there is no benefit to returning a list. In the other cases, just the first obj is required, so rather than returning the entire list, I add a peek_top_obj method.

Profile before:
ncalls  tottime  percall  cumtime  percall filename:lineno(function)
172162    0.171    0.000    0.171    0.000 traceable_stack.py:113(peek_objs)
197806    0.317    0.000    0.317    0.000 traceable_stack.py:117(peek_traceable_objs)

Profile after:
ncalls  tottime  percall  cumtime  percall filename:lineno(function)
143292    0.170    0.000    0.170    0.000 traceable_stack.py:117(peek_objs)
197806    0.080    0.000    0.080    0.000 traceable_stack.py:121(peek_traceable_objs)
 28870    0.008    0.000    0.008    0.000 traceable_stack.py:113(peek_top_obj)

PiperOrigin-RevId: 240965521
This commit is contained in:
James Keeling 2019-03-29 06:04:40 -07:00 committed by TensorFlower Gardener
parent c1b7d2be4b
commit 387bddab6a
3 changed files with 22 additions and 11 deletions

View File

@ -4468,11 +4468,11 @@ class Graph(object):
RuntimeError: If device scopes are not properly nested. RuntimeError: If device scopes are not properly nested.
""" """
self._add_device_to_stack(device_name_or_function, offset=2) self._add_device_to_stack(device_name_or_function, offset=2)
old_top_of_stack = self._device_function_stack.peek_objs()[0] old_top_of_stack = self._device_function_stack.peek_top_obj()
try: try:
yield yield
finally: finally:
new_top_of_stack = self._device_function_stack.peek_objs()[0] new_top_of_stack = self._device_function_stack.peek_top_obj()
if old_top_of_stack is not new_top_of_stack: if old_top_of_stack is not new_top_of_stack:
raise RuntimeError("Exiting device scope without proper scope nesting.") raise RuntimeError("Exiting device scope without proper scope nesting.")
self._device_function_stack.pop_obj() self._device_function_stack.pop_obj()
@ -5042,9 +5042,8 @@ class Graph(object):
the filename and lineno members point to the code location where the filename and lineno members point to the code location where
Graph.device was called directly or indirectly by the user. Graph.device was called directly or indirectly by the user.
""" """
traceable_objects = self._device_function_stack.peek_traceable_objs()
snapshot = [] snapshot = []
for obj in traceable_objects: for obj in self._device_function_stack.peek_traceable_objs():
obj_copy = obj.copy_metadata() obj_copy = obj.copy_metadata()
obj_copy.obj = obj.obj.display_name obj_copy.obj = obj.obj.display_name
snapshot.append(obj_copy) snapshot.append(obj_copy)
@ -5076,8 +5075,10 @@ class Graph(object):
def _snapshot_colocation_stack_metadata(self): def _snapshot_colocation_stack_metadata(self):
"""Return colocation stack metadata as a dictionary.""" """Return colocation stack metadata as a dictionary."""
traceable_objects = self._colocation_stack.peek_traceable_objs() return {
return {obj.obj.name: obj.copy_metadata() for obj in traceable_objects} traceable_obj.obj.name: traceable_obj.copy_metadata()
for traceable_obj in self._colocation_stack.peek_traceable_objs()
}
@_colocation_stack.setter @_colocation_stack.setter
def _colocation_stack(self, colocation_stack): def _colocation_stack(self, colocation_stack):

View File

@ -110,13 +110,17 @@ class TraceableStack(object):
"""Remove last-inserted object and return it, without filename/line info.""" """Remove last-inserted object and return it, without filename/line info."""
return self._stack.pop().obj return self._stack.pop().obj
def peek_top_obj(self):
"""Return the most recent stored object."""
return self._stack[-1].obj
def peek_objs(self): def peek_objs(self):
"""Return list of stored objects ordered newest to oldest.""" """Return iterator over stored objects ordered newest to oldest."""
return [t_obj.obj for t_obj in reversed(self._stack)] return (t_obj.obj for t_obj in reversed(self._stack))
def peek_traceable_objs(self): def peek_traceable_objs(self):
"""Return list of stored TraceableObjects ordered newest to oldest.""" """Return iterator over stored TraceableObjects ordered newest to oldest."""
return list(reversed(self._stack)) return reversed(self._stack)
def __len__(self): def __len__(self):
"""Return number of items on the stack, and used for truth-value testing.""" """Return number of items on the stack, and used for truth-value testing."""

View File

@ -82,11 +82,17 @@ class TraceableStackTest(test_util.TensorFlowTestCase):
t_stack.push_obj('hope') t_stack.push_obj('hope')
expected_lifo_peek = ['hope', 42.0] expected_lifo_peek = ['hope', 42.0]
self.assertEqual(expected_lifo_peek, t_stack.peek_objs()) self.assertEqual(expected_lifo_peek, list(t_stack.peek_objs()))
self.assertEqual('hope', t_stack.pop_obj()) self.assertEqual('hope', t_stack.pop_obj())
self.assertEqual(42.0, t_stack.pop_obj()) self.assertEqual(42.0, t_stack.pop_obj())
def testPushPeekTopObj(self):
t_stack = traceable_stack.TraceableStack()
t_stack.push_obj(42.0)
t_stack.push_obj('hope')
self.assertEqual('hope', t_stack.peek_top_obj())
def testPushPopPreserveLifoOrdering(self): def testPushPopPreserveLifoOrdering(self):
t_stack = traceable_stack.TraceableStack() t_stack = traceable_stack.TraceableStack()
t_stack.push_obj(0) t_stack.push_obj(0)