Use generators in traceable_stack and add new peek_top_obj method
peek_objs and peek_traceable_objs are only used in ops.py. In most cases, they are iterated over, so there is no benefit to returning a list. In the other cases, just the first obj is required, so rather than returning the entire list, I add a peek_top_obj method. Profile before: ncalls tottime percall cumtime percall filename:lineno(function) 172162 0.171 0.000 0.171 0.000 traceable_stack.py:113(peek_objs) 197806 0.317 0.000 0.317 0.000 traceable_stack.py:117(peek_traceable_objs) Profile after: ncalls tottime percall cumtime percall filename:lineno(function) 143292 0.170 0.000 0.170 0.000 traceable_stack.py:117(peek_objs) 197806 0.080 0.000 0.080 0.000 traceable_stack.py:121(peek_traceable_objs) 28870 0.008 0.000 0.008 0.000 traceable_stack.py:113(peek_top_obj) PiperOrigin-RevId: 240965521
This commit is contained in:
parent
c1b7d2be4b
commit
387bddab6a
@ -4468,11 +4468,11 @@ class Graph(object):
|
|||||||
RuntimeError: If device scopes are not properly nested.
|
RuntimeError: If device scopes are not properly nested.
|
||||||
"""
|
"""
|
||||||
self._add_device_to_stack(device_name_or_function, offset=2)
|
self._add_device_to_stack(device_name_or_function, offset=2)
|
||||||
old_top_of_stack = self._device_function_stack.peek_objs()[0]
|
old_top_of_stack = self._device_function_stack.peek_top_obj()
|
||||||
try:
|
try:
|
||||||
yield
|
yield
|
||||||
finally:
|
finally:
|
||||||
new_top_of_stack = self._device_function_stack.peek_objs()[0]
|
new_top_of_stack = self._device_function_stack.peek_top_obj()
|
||||||
if old_top_of_stack is not new_top_of_stack:
|
if old_top_of_stack is not new_top_of_stack:
|
||||||
raise RuntimeError("Exiting device scope without proper scope nesting.")
|
raise RuntimeError("Exiting device scope without proper scope nesting.")
|
||||||
self._device_function_stack.pop_obj()
|
self._device_function_stack.pop_obj()
|
||||||
@ -5042,9 +5042,8 @@ class Graph(object):
|
|||||||
the filename and lineno members point to the code location where
|
the filename and lineno members point to the code location where
|
||||||
Graph.device was called directly or indirectly by the user.
|
Graph.device was called directly or indirectly by the user.
|
||||||
"""
|
"""
|
||||||
traceable_objects = self._device_function_stack.peek_traceable_objs()
|
|
||||||
snapshot = []
|
snapshot = []
|
||||||
for obj in traceable_objects:
|
for obj in self._device_function_stack.peek_traceable_objs():
|
||||||
obj_copy = obj.copy_metadata()
|
obj_copy = obj.copy_metadata()
|
||||||
obj_copy.obj = obj.obj.display_name
|
obj_copy.obj = obj.obj.display_name
|
||||||
snapshot.append(obj_copy)
|
snapshot.append(obj_copy)
|
||||||
@ -5076,8 +5075,10 @@ class Graph(object):
|
|||||||
|
|
||||||
def _snapshot_colocation_stack_metadata(self):
|
def _snapshot_colocation_stack_metadata(self):
|
||||||
"""Return colocation stack metadata as a dictionary."""
|
"""Return colocation stack metadata as a dictionary."""
|
||||||
traceable_objects = self._colocation_stack.peek_traceable_objs()
|
return {
|
||||||
return {obj.obj.name: obj.copy_metadata() for obj in traceable_objects}
|
traceable_obj.obj.name: traceable_obj.copy_metadata()
|
||||||
|
for traceable_obj in self._colocation_stack.peek_traceable_objs()
|
||||||
|
}
|
||||||
|
|
||||||
@_colocation_stack.setter
|
@_colocation_stack.setter
|
||||||
def _colocation_stack(self, colocation_stack):
|
def _colocation_stack(self, colocation_stack):
|
||||||
|
|||||||
@ -110,13 +110,17 @@ class TraceableStack(object):
|
|||||||
"""Remove last-inserted object and return it, without filename/line info."""
|
"""Remove last-inserted object and return it, without filename/line info."""
|
||||||
return self._stack.pop().obj
|
return self._stack.pop().obj
|
||||||
|
|
||||||
|
def peek_top_obj(self):
|
||||||
|
"""Return the most recent stored object."""
|
||||||
|
return self._stack[-1].obj
|
||||||
|
|
||||||
def peek_objs(self):
|
def peek_objs(self):
|
||||||
"""Return list of stored objects ordered newest to oldest."""
|
"""Return iterator over stored objects ordered newest to oldest."""
|
||||||
return [t_obj.obj for t_obj in reversed(self._stack)]
|
return (t_obj.obj for t_obj in reversed(self._stack))
|
||||||
|
|
||||||
def peek_traceable_objs(self):
|
def peek_traceable_objs(self):
|
||||||
"""Return list of stored TraceableObjects ordered newest to oldest."""
|
"""Return iterator over stored TraceableObjects ordered newest to oldest."""
|
||||||
return list(reversed(self._stack))
|
return reversed(self._stack)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
"""Return number of items on the stack, and used for truth-value testing."""
|
"""Return number of items on the stack, and used for truth-value testing."""
|
||||||
|
|||||||
@ -82,11 +82,17 @@ class TraceableStackTest(test_util.TensorFlowTestCase):
|
|||||||
t_stack.push_obj('hope')
|
t_stack.push_obj('hope')
|
||||||
|
|
||||||
expected_lifo_peek = ['hope', 42.0]
|
expected_lifo_peek = ['hope', 42.0]
|
||||||
self.assertEqual(expected_lifo_peek, t_stack.peek_objs())
|
self.assertEqual(expected_lifo_peek, list(t_stack.peek_objs()))
|
||||||
|
|
||||||
self.assertEqual('hope', t_stack.pop_obj())
|
self.assertEqual('hope', t_stack.pop_obj())
|
||||||
self.assertEqual(42.0, t_stack.pop_obj())
|
self.assertEqual(42.0, t_stack.pop_obj())
|
||||||
|
|
||||||
|
def testPushPeekTopObj(self):
|
||||||
|
t_stack = traceable_stack.TraceableStack()
|
||||||
|
t_stack.push_obj(42.0)
|
||||||
|
t_stack.push_obj('hope')
|
||||||
|
self.assertEqual('hope', t_stack.peek_top_obj())
|
||||||
|
|
||||||
def testPushPopPreserveLifoOrdering(self):
|
def testPushPopPreserveLifoOrdering(self):
|
||||||
t_stack = traceable_stack.TraceableStack()
|
t_stack = traceable_stack.TraceableStack()
|
||||||
t_stack.push_obj(0)
|
t_stack.push_obj(0)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user