diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index ddad4266d68..e312a25f008 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -4468,11 +4468,11 @@ class Graph(object): RuntimeError: If device scopes are not properly nested. """ self._add_device_to_stack(device_name_or_function, offset=2) - old_top_of_stack = self._device_function_stack.peek_objs()[0] + old_top_of_stack = self._device_function_stack.peek_top_obj() try: yield finally: - new_top_of_stack = self._device_function_stack.peek_objs()[0] + new_top_of_stack = self._device_function_stack.peek_top_obj() if old_top_of_stack is not new_top_of_stack: raise RuntimeError("Exiting device scope without proper scope nesting.") self._device_function_stack.pop_obj() @@ -5042,9 +5042,8 @@ class Graph(object): the filename and lineno members point to the code location where Graph.device was called directly or indirectly by the user. """ - traceable_objects = self._device_function_stack.peek_traceable_objs() snapshot = [] - for obj in traceable_objects: + for obj in self._device_function_stack.peek_traceable_objs(): obj_copy = obj.copy_metadata() obj_copy.obj = obj.obj.display_name snapshot.append(obj_copy) @@ -5076,8 +5075,10 @@ class Graph(object): def _snapshot_colocation_stack_metadata(self): """Return colocation stack metadata as a dictionary.""" - traceable_objects = self._colocation_stack.peek_traceable_objs() - return {obj.obj.name: obj.copy_metadata() for obj in traceable_objects} + return { + traceable_obj.obj.name: traceable_obj.copy_metadata() + for traceable_obj in self._colocation_stack.peek_traceable_objs() + } @_colocation_stack.setter def _colocation_stack(self, colocation_stack): diff --git a/tensorflow/python/framework/traceable_stack.py b/tensorflow/python/framework/traceable_stack.py index f12a2927ea9..0a0cda870fc 100644 --- a/tensorflow/python/framework/traceable_stack.py +++ b/tensorflow/python/framework/traceable_stack.py @@ -110,13 +110,17 @@ class TraceableStack(object): """Remove last-inserted object and return it, without filename/line info.""" return self._stack.pop().obj + def peek_top_obj(self): + """Return the most recent stored object.""" + return self._stack[-1].obj + def peek_objs(self): - """Return list of stored objects ordered newest to oldest.""" - return [t_obj.obj for t_obj in reversed(self._stack)] + """Return iterator over stored objects ordered newest to oldest.""" + return (t_obj.obj for t_obj in reversed(self._stack)) def peek_traceable_objs(self): - """Return list of stored TraceableObjects ordered newest to oldest.""" - return list(reversed(self._stack)) + """Return iterator over stored TraceableObjects ordered newest to oldest.""" + return reversed(self._stack) def __len__(self): """Return number of items on the stack, and used for truth-value testing.""" diff --git a/tensorflow/python/framework/traceable_stack_test.py b/tensorflow/python/framework/traceable_stack_test.py index 3e7876f6318..bdc90014c0d 100644 --- a/tensorflow/python/framework/traceable_stack_test.py +++ b/tensorflow/python/framework/traceable_stack_test.py @@ -82,11 +82,17 @@ class TraceableStackTest(test_util.TensorFlowTestCase): t_stack.push_obj('hope') expected_lifo_peek = ['hope', 42.0] - self.assertEqual(expected_lifo_peek, t_stack.peek_objs()) + self.assertEqual(expected_lifo_peek, list(t_stack.peek_objs())) self.assertEqual('hope', t_stack.pop_obj()) self.assertEqual(42.0, t_stack.pop_obj()) + def testPushPeekTopObj(self): + t_stack = traceable_stack.TraceableStack() + t_stack.push_obj(42.0) + t_stack.push_obj('hope') + self.assertEqual('hope', t_stack.peek_top_obj()) + def testPushPopPreserveLifoOrdering(self): t_stack = traceable_stack.TraceableStack() t_stack.push_obj(0)